# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base testing class for strategies that require multiple nodes."""

import contextlib
import copy
import json
import os
import subprocess
import sys
import threading
import unittest

import six


# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import remote
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import server_lib
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


original_run_std_server = dc._run_std_server  # pylint: disable=protected-access
pick_unused_port = test_util.pick_unused_port


def _create_cluster(num_workers,
                    num_ps,
                    has_chief=False,
                    has_eval=False,
                    protocol='grpc',
                    worker_config=None,
                    ps_config=None,
                    eval_config=None,
                    worker_name='worker',
                    ps_name='ps',
                    chief_name='chief'):
  """Creates and starts local servers and returns the cluster_spec dict."""

  worker_ports = [pick_unused_port() for _ in range(num_workers)]
  ps_ports = [pick_unused_port() for _ in range(num_ps)]

  cluster_dict = {}
  if num_workers > 0:
    cluster_dict[worker_name] = ['localhost:%s' % port for port in worker_ports]
  if num_ps > 0:
    cluster_dict[ps_name] = ['localhost:%s' % port for port in ps_ports]
  if has_eval:
    cluster_dict['evaluator'] = ['localhost:%s' % pick_unused_port()]
  if has_chief:
    cluster_dict[chief_name] = ['localhost:%s' % pick_unused_port()]

  cs = server_lib.ClusterSpec(cluster_dict)

  for i in range(num_workers):
    server_lib.Server(
        cs,
        job_name=worker_name,
        protocol=protocol,
        task_index=i,
        config=worker_config,
        start=True)

  for i in range(num_ps):
    server_lib.Server(
        cs,
        job_name=ps_name,
        protocol=protocol,
        task_index=i,
        config=ps_config,
        start=True)

  if has_chief:
    server_lib.Server(
        cs,
        job_name=chief_name,
        protocol=protocol,
        task_index=0,
        config=worker_config,
        start=True)

  if has_eval:
    server_lib.Server(
        cs,
        job_name='evaluator',
        protocol=protocol,
        task_index=0,
        config=eval_config,
        start=True)

  return cluster_dict


def create_in_process_cluster(num_workers,
                              num_ps,
                              has_chief=False,
                              has_eval=False,
                              rpc_layer='grpc'):
  """Create an in-process cluster that consists of only standard servers."""
  # Leave some memory for the CUDA runtime.
  gpu_mem_frac = 0.7 / (num_workers + int(has_chief) + int(has_eval))
  worker_config = config_pb2.ConfigProto()
  worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac

  # The cluster may hang if workers don't have enough inter_op threads. See
  # b/172296720 for more details.
  if worker_config.inter_op_parallelism_threads < num_workers + 1:
    worker_config.inter_op_parallelism_threads = num_workers + 1

  # Enable collective ops, which have no impact on non-collective ops.
  # TODO(yuefengz, tucker): remove this after we move the initialization of
  # collective mgr to the session level.
  if has_chief:
    worker_config.experimental.collective_group_leader = (
        '/job:chief/replica:0/task:0')
  else:
    worker_config.experimental.collective_group_leader = (
        '/job:worker/replica:0/task:0')

  ps_config = config_pb2.ConfigProto()
  ps_config.device_count['GPU'] = 0

  eval_config = config_pb2.ConfigProto()
  eval_config.experimental.collective_group_leader = ''

  # Create in-process servers. Once an in-process TensorFlow server is created,
  # there is no way to terminate it, so we create one cluster per test process.
  # We could have started the servers in another process and then killed that
  # process to terminate them, but we avoid multiple processes because
  # 1) it is more difficult to manage these processes;
  # 2) there is something global in CUDA such that if we initialize CUDA in the
  # parent process, the child process cannot initialize it again and thus cannot
  # use GPUs (https://stackoverflow.com/questions/22950047).
  cluster = None
  try:
    cluster = _create_cluster(
        num_workers,
        num_ps=num_ps,
        has_chief=has_chief,
        has_eval=has_eval,
        worker_config=worker_config,
        ps_config=ps_config,
        eval_config=eval_config,
        protocol=rpc_layer)
  except errors.UnknownError as e:
    if 'Could not start gRPC server' in e.message:
      raise unittest.SkipTest('Cannot start std servers.')
    else:
      raise
  return cluster


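# Illustrative sketch (not part of the original module): how a test might use
# `create_in_process_cluster` to stand up a 2-worker / 1-ps cluster and build a
# gRPC session target from the returned spec. The helper name
# `_example_in_process_cluster_target` is hypothetical.
def _example_in_process_cluster_target():
  cluster_spec = create_in_process_cluster(num_workers=2, num_ps=1)
  # The returned dict maps job names ('worker', 'ps', and optionally 'chief'
  # and 'evaluator') to lists of 'localhost:<port>' addresses.
  return 'grpc://' + cluster_spec['worker'][0]

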
class MultiProcessCluster(object):
  """A cluster of TensorFlow servers in separate processes.

  This class is not thread-safe.
  """

  def __init__(self,
               cluster_resolver,
               stream_output=False,
               collective_leader=None):
    self._cluster_resolver = cluster_resolver
    self._cluster_spec = cluster_resolver.cluster_spec().as_dict()
    self._rpc_layer = cluster_resolver.rpc_layer
    self._stream_output = stream_output
    self._start_events = {}
    self._finish_events = {}
    self._mpr_manager = multi_process_runner.manager()

    def task_function(start_events, finish_events):
      cluster_resolver = TFConfigClusterResolver()
      cluster_spec = cluster_resolver.cluster_spec()
      task_type = cluster_resolver.task_type
      task_id = cluster_resolver.task_id
      rpc_layer = cluster_resolver.rpc_layer

      # TODO(yuefengz): support GPU clusters.
      server_config = config_pb2.ConfigProto()
      server_config.device_count['GPU'] = 0

      if collective_leader:
        server_config.experimental.collective_group_leader = collective_leader
        server_config.experimental.collective_nccl = False

        logging.info(
            'Enabling collective ops with cluster_spec = %r, task_type = %r, '
            'task_id = %r, rpc_layer = %r, collective_leader = %s',
            cluster_spec, task_type, task_id, rpc_layer, collective_leader)
      else:
        logging.info(
            'Starting server with cluster_spec = %r, task_type = %r, '
            'task_id = %r, rpc_layer = %r', cluster_spec, task_type, task_id,
            rpc_layer)

      server_lib.Server(
          cluster_spec,
          job_name=task_type,
          protocol=rpc_layer,
          task_index=task_id,
          config=server_config,
          start=True)

      start_event = start_events[task_type][task_id]
      start_event.set()

      finish_event = finish_events[task_type][task_id]
      finish_event.wait()

      os._exit(0)  # pylint: disable=protected-access

    self._task_function = task_function
    self._mpr = None

  def start(self):
    """Starts one TensorFlow server for each task in the cluster_resolver.

    It will wait until all the servers are up before returning.
    """
    if self._mpr:
      raise ValueError('The cluster has already been started.')
    for task_type, task_addresses in self._cluster_spec.items():
      self._start_events[task_type] = []
      self._finish_events[task_type] = []
      for _ in task_addresses:
        self._start_events[task_type].append(self._mpr_manager.Event())
        self._finish_events[task_type].append(self._mpr_manager.Event())

    self._mpr = multi_process_runner.MultiProcessRunner(
        self._task_function,
        self._cluster_spec,
        args=(self._start_events, self._finish_events),
        rpc_layer=self._rpc_layer,
        stream_output=self._stream_output,
        return_output=False,
        use_dill_for_args=False)
    self._mpr.start()
    for task_type, task_addresses in self._cluster_spec.items():
      for i in range(len(task_addresses)):
        self._start_events[task_type][i].wait()

  def stop(self):
    """Stops all the servers."""
    for task_type, task_addresses in self._cluster_spec.items():
      for i in range(len(task_addresses)):
        self._finish_events[task_type][i].set()
    try:
      self._mpr.join()
    except multi_process_runner.UnexpectedSubprocessExitError:
      # TODO(yuefengz): investigate why processes exit with 255.
      pass
    self._mpr = None
    self._start_events = {}
    self._finish_events = {}

  def kill_task(self, task_type, task_id):
    """Kills a server given task_type and task_id.

    Args:
      task_type: the type of the task, such as "worker".
      task_id: the id of the task, such as 1.
    """
    assert self._mpr
    if (not self._start_events[task_type][task_id].is_set() or
        self._finish_events[task_type][task_id].is_set()):
      raise ValueError("The task %s:%d doesn't exist." % (task_type, task_id))

    self._finish_events[task_type][task_id].set()
    self._mpr._processes[(task_type, task_id)].join()

  def start_task(self, task_type, task_id):
    """Starts a server given task_type and task_id.

    Args:
      task_type: the type of the task, such as "worker".
      task_id: the id of the task, such as 1.

    Raises:
      ValueError: if the server already exists.
    """
    assert self._mpr

    if (not self._start_events[task_type][task_id].is_set() or
        not self._finish_events[task_type][task_id].is_set()):
      raise ValueError(
          'The task %s:%d is still alive. You cannot start another one.' %
          (task_type, task_id))
    self._start_events[task_type][task_id] = self._mpr_manager.Event()
    self._finish_events[task_type][task_id] = self._mpr_manager.Event()
    self._mpr.start_single_process(task_type=task_type, task_id=task_id)
    self._start_events[task_type][task_id].wait()

  @property
  def cluster_resolver(self):
    return copy.deepcopy(self._cluster_resolver)


def create_multi_process_cluster(num_workers,
                                 num_ps,
                                 has_chief=False,
                                 has_eval=False,
                                 rpc_layer='grpc',
                                 stream_output=False,
                                 collective_leader=None):
  logging.info('Now creating a MultiProcessCluster with '
               f'num_workers={num_workers}, num_ps={num_ps}.')
  cluster_spec = create_cluster_spec(
      has_chief=has_chief,
      num_workers=num_workers,
      num_ps=num_ps,
      has_eval=has_eval)

  cluster = MultiProcessCluster(
      SimpleClusterResolver(
          server_lib.ClusterSpec(cluster_spec), rpc_layer=rpc_layer),
      stream_output=stream_output,
      collective_leader=collective_leader)
  cluster.start()
  return cluster


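# Illustrative sketch (not part of the original module): how a test might drive
# `create_multi_process_cluster`, simulate a worker failure, restart the same
# task, and shut the cluster down. The helper name
# `_example_multi_process_cluster_restart` is hypothetical.
def _example_multi_process_cluster_restart():
  cluster = create_multi_process_cluster(num_workers=2, num_ps=0)
  try:
    # Kill worker 1 and bring the same task back up in a fresh process.
    cluster.kill_task('worker', 1)
    cluster.start_task('worker', 1)
    return cluster.cluster_resolver.cluster_spec().as_dict()
  finally:
    cluster.stop()

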
@tf_export(
    '__internal__.distribute.multi_process_runner.create_cluster_spec', v1=[])
def create_cluster_spec(has_chief=False,
                        num_workers=1,
                        num_ps=0,
                        has_eval=False):
  """Create a cluster spec with tasks using unused local ports.

  This utility finds available ports on localhost, and returns a dict that
  represents the cluster spec that uses those ports, according to the
  arguments. The dict representing the cluster spec contains task types and
  their instances' addresses. Note that this is usually only for testing
  purposes using multiple processes on the local machine, and should not be
  used for real multi-worker TensorFlow programs, where the addresses need to
  point to processes on separate machines.

  This util is useful when creating the `cluster_spec` arg for
  `tf.__internal__.distribute.multi_process_runner.run`.

  Args:
    has_chief: Whether the generated cluster spec should contain "chief" task
      type.
    num_workers: Number of workers to use in the cluster spec.
    num_ps: Number of parameter servers to use in the cluster spec.
    has_eval: Whether this cluster spec has an evaluator.

  Returns:
    A dict that represents the cluster spec using localhost ports for the tasks.

  Example:

  ```python
  cluster_spec = (
      tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
          has_chief=True, num_workers=2, num_ps=2))
  # An example of cluster_spec is
  # {'chief': ['localhost:23381'],
  # 'worker': ['localhost:19197', 'localhost:22903'],
  # 'ps': ['localhost:16912', 'localhost:21535']}

  cluster_spec = (
      tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
          has_chief=False, num_workers=0, num_ps=0, has_eval=True))
  # An example of cluster_spec is
  # {'evaluator': ['localhost:23381']}
  ```
  """
  cluster_spec = {}
  if has_chief:
    cluster_spec['chief'] = ['localhost:%s' % pick_unused_port()]
  if num_workers:
    cluster_spec['worker'] = [
        'localhost:%s' % pick_unused_port() for _ in range(num_workers)
    ]
  if num_ps:
    cluster_spec['ps'] = [
        'localhost:%s' % pick_unused_port() for _ in range(num_ps)
    ]
  if has_eval:
    cluster_spec['evaluator'] = ['localhost:%s' % pick_unused_port()]
  return cluster_spec


@contextlib.contextmanager
def skip_if_grpc_server_cant_be_started(test_obj):
  try:
    yield
  except errors.UnknownError as e:
    if 'Could not start gRPC server' in e.message:
      reason = 'Cannot start std servers.'
      test_obj.test_skipped_reason = reason
      test_obj.skipTest(reason)
    else:
      raise


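# Illustrative sketch (not part of the original module): wrapping server
# creation in `skip_if_grpc_server_cant_be_started` so a test is skipped,
# rather than failed, when the environment cannot bind gRPC ports. The helper
# name `_example_start_server_or_skip` is hypothetical; `test_obj` is expected
# to be a `test.TestCase` instance.
def _example_start_server_or_skip(test_obj, cluster_spec, task_type, task_id):
  with skip_if_grpc_server_cant_be_started(test_obj):
    return server_lib.Server(
        server_lib.ClusterSpec(cluster_spec),
        job_name=task_type,
        task_index=task_id,
        protocol='grpc',
        start=True)

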
class MultiWorkerTestBase(test.TestCase):
  """Base class for testing multi-node strategy and dataset."""

  @classmethod
  def setUpClass(cls, num_workers=2, num_ps=1):  # pylint: disable=g-missing-super-call
    """Create a local cluster with num_workers workers and num_ps ps tasks."""
    cls._cluster_spec = create_in_process_cluster(num_workers=num_workers,
                                                  num_ps=num_ps)
    cls._default_target = 'grpc://' + cls._cluster_spec['worker'][0]

  def setUp(self):
    # We only cache the session in one test because another test may have a
    # different session config or master target.
    self._thread_local = threading.local()
    self._thread_local.cached_session = None
    self._coord = coordinator.Coordinator()

  @contextlib.contextmanager
  def session(self, graph=None, config=None, target=None):
    """Create a test session with master target set to the testing cluster.

    Creates a test session that connects to the local testing cluster.

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      target: the target of the session to connect to.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    config = self._create_config(config)

    if target is None:
      target = self._default_target
    with session.Session(graph=graph, config=config, target=target) as sess:
      yield sess

  @contextlib.contextmanager
  # TODO(b/117573461): Overwrite self.evaluate() to use this function.
  def cached_session(self, graph=None, config=None, target=None):
    """Create a test session with master target set to the testing cluster.

    Creates a test session that connects to the local testing cluster.
    The session is only created once per test and then reused.

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      target: the target of the session to connect to.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case. Note that the
      session will live until the end of the test.
    """
    config = self._create_config(config)

    if target is None:
      target = self._default_target
    if getattr(self._thread_local, 'cached_session', None) is None:
      self._thread_local.cached_session = session.Session(
          graph=None, config=config, target=target)
    sess = self._thread_local.cached_session
    with sess.graph.as_default(), sess.as_default():
      yield sess

  def _create_config(self, config):
    if config is None:
      config = config_pb2.ConfigProto(allow_soft_placement=True)
    else:
      config = copy.deepcopy(config)
    # Don't perform optimizations for tests so we don't inadvertently run
    # GPU ops on CPU.
    config.graph_options.optimizer_options.opt_level = -1
    config.graph_options.rewrite_options.constant_folding = (
        rewriter_config_pb2.RewriterConfig.OFF)

    return config

  def _run_client(self, client_fn, task_type, task_id, num_gpus, eager_mode,
                  *args, **kwargs):

    def wrapped_client_fn():
      with self._coord.stop_on_exception():
        client_fn(task_type, task_id, num_gpus, *args, **kwargs)

    if eager_mode:
      with context.eager_mode():
        wrapped_client_fn()
    else:
      with context.graph_mode():
        wrapped_client_fn()

  def _run_between_graph_clients(self, client_fn, cluster_spec, num_gpus, *args,
                                 **kwargs):
    """Runs several clients for between-graph replication.

    Args:
      client_fn: a function that needs to accept `task_type`, `task_id`,
        `num_gpus`.
      cluster_spec: a dict specifying jobs in a cluster.
      num_gpus: number of GPUs per worker.
      *args: will be passed to `client_fn`.
      **kwargs: will be passed to `client_fn`.
    """
    threads = []
    for task_type in ['chief', 'worker']:
      for task_id in range(len(cluster_spec.get(task_type, []))):
        t = threading.Thread(
            target=self._run_client,
            args=(client_fn, task_type, task_id, num_gpus,
                  context.executing_eagerly()) + args,
            kwargs=kwargs)
        t.start()
        threads.append(t)
    self._coord.join(threads)


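# Illustrative sketch (not part of the original module): a helper showing how a
# `MultiWorkerTestBase` subclass typically talks to the in-process cluster
# through `cached_session`. `test_obj` is expected to be an instance of such a
# subclass; the helper name `_example_run_simple_graph` is hypothetical.
def _example_run_simple_graph(test_obj):
  with test_obj.cached_session() as sess:
    # The cached session targets the first in-process worker created in
    # setUpClass, so the trivial op below runs on the testing cluster.
    return sess.run(ops.convert_to_tensor(3.0))

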
class SingleWorkerTestBaseGraph(MultiWorkerTestBase):
  """Base class for testing remote single worker strategy graph and dataset."""

  @classmethod
  def setUpClass(cls):
    super(SingleWorkerTestBaseGraph, cls).setUpClass(num_workers=1)


class SingleWorkerTestBaseEager(test.TestCase):
  """Base class for testing remote single worker strategy eager and dataset."""

  def setUp(self):
    super(SingleWorkerTestBaseEager, self).setUp()
    workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0)
    remote.connect_to_remote_host(workers[0].target)

  def cached_session(self):
    return DummySession()


class DummySession(object):

  def __enter__(self):
    return

  def __exit__(self, exception_type, exception_value, traceback):
    pass


class MockOsEnv(collections_abc.Mapping):
  """A class that allows per-thread TF_CONFIG."""

  def __init__(self, *args):
    self._dict = dict()
    self._thread_local = threading.local()
    super(MockOsEnv, self).__init__(*args)

  def get(self, key, default=None):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    if key == 'TF_CONFIG':
      return dict.get(self._thread_local.dict, key, default)
    else:
      return dict.get(self._dict, key, default)

  def __getitem__(self, key):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    if key == 'TF_CONFIG':
      return dict.__getitem__(self._thread_local.dict, key)
    else:
      return dict.__getitem__(self._dict, key)

  def __setitem__(self, key, val):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    if key == 'TF_CONFIG':
      return dict.__setitem__(self._thread_local.dict, key, val)
    else:
      return dict.__setitem__(self._dict, key, val)

  def __iter__(self):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    for x in self._thread_local.dict:
      yield x
    for x in self._dict:
      yield x

  def __len__(self):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    return self._thread_local.dict.__len__() + self._dict.__len__()


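# Illustrative sketch (not part of the original module): demonstrating that
# `MockOsEnv` keeps 'TF_CONFIG' thread-local while other keys are shared across
# threads. The helper name `_example_mock_os_env` is hypothetical.
def _example_mock_os_env():
  mock_env = MockOsEnv()
  mock_env['CUDA_VISIBLE_DEVICES'] = '0'  # shared across threads
  mock_env['TF_CONFIG'] = '{"task": {"type": "chief", "index": 0}}'

  seen_in_thread = []

  def reader():
    # Another thread sees the shared key but not this thread's TF_CONFIG.
    seen_in_thread.append(mock_env.get('CUDA_VISIBLE_DEVICES'))
    seen_in_thread.append(mock_env.get('TF_CONFIG'))

  t = threading.Thread(target=reader)
  t.start()
  t.join()
  return seen_in_thread  # ['0', None]

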
class IndependentWorkerTestBase(test.TestCase):
  """Testing infra for independent workers."""

  def _make_mock_run_std_server(self):

    def _mock_run_std_server(*args, **kwargs):
      """Returns the std server once all threads have started it."""
      with skip_if_grpc_server_cant_be_started(self):
        ret = original_run_std_server(*args, **kwargs)
      # Wait for all std servers to be brought up in order to reduce the chance
      # of remote sessions taking local ports that have been assigned to std
      # servers. Only call this barrier the first time this function is run for
      # each thread.
      if not getattr(self._thread_local, 'server_started', False):
        self._barrier.wait()
      self._thread_local.server_started = True
      return ret

    return _mock_run_std_server

  def setUp(self):
    self._mock_os_env = MockOsEnv()
    self._mock_context = test.mock.patch.object(os, 'environ',
                                                self._mock_os_env)
    self._coord = coordinator.Coordinator()
    super(IndependentWorkerTestBase, self).setUp()
    self._mock_context.__enter__()
    # threading local object to be shared by all threads
    self._thread_local = threading.local()

  def tearDown(self):
    self._mock_context.__exit__(None, None, None)
    super(IndependentWorkerTestBase, self).tearDown()

  def _task_thread(self, task_fn, tf_config, executing_eagerly, *args,
                   **kwargs):
    with self._coord.stop_on_exception():
      os.environ['TF_CONFIG'] = json.dumps(tf_config)
      # Force the new thread simulating a worker to run in the same context
      # mode as the parent thread does.
      if executing_eagerly:
        with context.eager_mode():
          task_fn(*args, **kwargs)
      else:
        with ops.Graph().as_default(), context.graph_mode():
          task_fn(*args, **kwargs)

  def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id,
                          *args, **kwargs):
    """Run a task in a new thread.

    If `tf_config` is provided, use it for the new thread; if not, construct
    one from `cluster_spec`, `task_type`, and `task_id`, and provide it to the
    new thread to be set as the `TF_CONFIG` environment variable.

    Args:
      task_fn: The function to run in the new thread.
      cluster_spec: The cluster spec.
      task_type: The task type.
      task_id: The task id.
      *args: Additional positional arguments to provide to the thread's
        task_fn.
      **kwargs: Additional keyword arguments to provide to the thread's
        task_fn. If `tf_config` is provided, that dict will be used for the
        TF_CONFIG of the new thread.

    Returns:
      The thread that has started.
    """
    tf_config = kwargs.pop('tf_config', None)
    if tf_config is None:
      if task_type:
        tf_config = {
            'cluster': cluster_spec,
            'task': {
                'type': task_type,
                'index': task_id
            }
        }
      else:
        tf_config = {
            'cluster': cluster_spec,
        }
    t = threading.Thread(
        target=self._task_thread,
        args=(task_fn, tf_config, context.executing_eagerly()) + args,
        kwargs=kwargs)
    t.start()
    return t

  def run_multiple_tasks_in_threads(self, task_fn, cluster_spec, *args,
                                    **kwargs):
    # The task_fn should create std_server by itself.
    threads = {}
    for task_type in cluster_spec.keys():
      threads[task_type] = []
      for task_id in range(len(cluster_spec[task_type])):
        t = self._run_task_in_thread(task_fn, cluster_spec, task_type, task_id,
                                     *args, **kwargs)
        threads[task_type].append(t)
    return threads

  def join_independent_workers(self, worker_threads):
    with skip_if_grpc_server_cant_be_started(self):
      self._coord.join(worker_threads)


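# Illustrative sketch (not part of the original module): how a test built on
# `IndependentWorkerTestBase` might fan a task function out over a cluster spec
# and wait for every worker thread. `test_obj` is expected to be an instance of
# such a subclass; `_example_run_tasks` and `task_fn` are hypothetical names.
def _example_run_tasks(test_obj, task_fn, num_workers=2):
  cluster_spec = create_cluster_spec(num_workers=num_workers)
  threads = test_obj.run_multiple_tasks_in_threads(task_fn, cluster_spec)
  # Flatten the {task_type: [thread, ...]} dict and join every thread.
  test_obj.join_independent_workers(
      [t for worker_threads in threads.values() for t in worker_threads])
  return cluster_spec

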
class MultiWorkerMultiProcessTest(test.TestCase):
  """Testing infra for independent workers using multiple processes."""

  def _run_task_in_process(self, cmd_args, cluster_spec, task_type, task_id):
    env = os.environ.copy()
    env['TF_CONFIG'] = json.dumps({
        'cluster': cluster_spec,
        'task': {
            'type': task_type,
            'index': task_id
        }
    })
    return subprocess.Popen(
        cmd_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)

  @deprecation.deprecated(
      None, '`run_multiple_tasks_in_processes` is deprecated; any new test '
      'requiring multiple processes should use `multi_process_runner` for '
      'better support of log printing, streaming, and more functionality.')
  def run_multiple_tasks_in_processes(self, cmd_args, cluster_spec):
    """Run `cmd_args` in a process for each task in `cluster_spec`."""
    processes = {}
    for task_type in cluster_spec.keys():
      processes[task_type] = []
      for task_id in range(len(cluster_spec[task_type])):
        p = self._run_task_in_process(cmd_args, cluster_spec, task_type,
                                      task_id)
        processes[task_type].append(p)
    return processes

  @deprecation.deprecated(
      None, '`join_independent_workers` is deprecated; any new test '
      'requiring multiple processes should use `multi_process_runner` for '
      'better support of log printing, streaming, and more functionality.')
  def join_independent_workers(self, worker_processes):
    return_codes = []
    for p in nest.flatten(worker_processes):
      try:
        # Calling p.wait() will hang if we don't consume its output.
        p.communicate()
      except ValueError:
        # The output of the process may have been consumed, in which case
        # calling `p.communicate()` will raise a ValueError.
        pass
      finally:
        return_codes.append(p.returncode)
    for return_code in return_codes:
      self.assertEqual(return_code, 0)

  @deprecation.deprecated(
      None, '`stream_stderr` is deprecated; any new test '
      'requiring multiple processes should use `multi_process_runner` for '
      'better support of log printing, streaming, and more functionality.')
  def stream_stderr(self, processes, print_only_first=False):
    """Consume stderr of all processes and print to stdout.

    To reduce the amount of logging, the caller can set print_only_first to
    True. In that case, this function only prints stderr from the first
    process of each type.

    Args:
      processes: A dictionary from process type string -> list of processes.
      print_only_first: If true, only print output from first process of each
        type.
    """

    def _stream_stderr_single_process(process, type_string, index,
                                      print_to_stdout):
      """Consume a single process's stderr and optionally print to stdout."""
      while True:
        output = process.stderr.readline()
        if not output and process.poll() is not None:
          break
        if output and print_to_stdout:
          print('{}{} {}'.format(type_string, index, output.strip()))
          sys.stdout.flush()

    stream_threads = []
    for process_type, process_list in six.iteritems(processes):
      for i in range(len(process_list)):
        print_to_stdout = (not print_only_first) or (i == 0)
        thread = threading.Thread(
            target=_stream_stderr_single_process,
            args=(process_list[i], process_type, i, print_to_stdout))
        thread.start()
        stream_threads.append(thread)
    for thread in stream_threads:
      thread.join()


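# Illustrative sketch (not part of the original module): the (deprecated)
# process-based flow, shown only for completeness; new tests should prefer
# `multi_process_runner`. `test_obj` is expected to be a
# `MultiWorkerMultiProcessTest` instance and `_example_run_processes` is a
# hypothetical name.
def _example_run_processes(test_obj, cmd_args):
  cluster_spec = create_cluster_spec(num_workers=2)
  processes = test_obj.run_multiple_tasks_in_processes(cmd_args, cluster_spec)
  test_obj.stream_stderr(processes, print_only_first=True)
  test_obj.join_independent_workers(processes)

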
def get_tf_config_task():
  return json.loads(os.environ['TF_CONFIG'])['task']


def get_tf_config_cluster_spec():
  return json.loads(os.environ['TF_CONFIG'])['cluster']


def get_task_type():
  return get_tf_config_task()['type']


def get_task_index():
  return get_tf_config_task()['index']


def is_chief():
  return ('chief' not in get_tf_config_cluster_spec()
          and get_task_type() == 'worker'
          and get_task_index() == 0)
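
# Illustrative sketch (not part of the original module): what the TF_CONFIG
# helpers above return for a hypothetical two-worker cluster whose TF_CONFIG
# names this process as worker 0 (with no 'chief' job, worker 0 acts as chief).
# The helper name `_example_tf_config_helpers` is hypothetical.
def _example_tf_config_helpers():
  os.environ['TF_CONFIG'] = json.dumps({
      'cluster': {
          'worker': ['localhost:12345', 'localhost:23456']
      },
      'task': {
          'type': 'worker',
          'index': 0
      }
  })
  # get_task_type() -> 'worker'; get_task_index() -> 0; is_chief() -> True.
  return get_task_type(), get_task_index(), is_chief()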