# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base testing class for strategies that require multiple nodes."""

import contextlib
import copy
import json
import os
import subprocess
import sys
import threading
import unittest

import six


# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import remote
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import server_lib
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


original_run_std_server = dc._run_std_server  # pylint: disable=protected-access
pick_unused_port = test_util.pick_unused_port


def _create_cluster(num_workers,
                    num_ps,
                    has_chief=False,
                    has_eval=False,
                    protocol='grpc',
                    worker_config=None,
                    ps_config=None,
                    eval_config=None,
                    worker_name='worker',
                    ps_name='ps',
                    chief_name='chief'):
  """Creates and starts local servers and returns the cluster_spec dict."""

  worker_ports = [pick_unused_port() for _ in range(num_workers)]
  ps_ports = [pick_unused_port() for _ in range(num_ps)]

  cluster_dict = {}
  if num_workers > 0:
    cluster_dict[worker_name] = [
        'localhost:%s' % port for port in worker_ports
    ]
  if num_ps > 0:
    cluster_dict[ps_name] = ['localhost:%s' % port for port in ps_ports]
  if has_eval:
    cluster_dict['evaluator'] = ['localhost:%s' % pick_unused_port()]
  if has_chief:
    cluster_dict[chief_name] = ['localhost:%s' % pick_unused_port()]

  cs = server_lib.ClusterSpec(cluster_dict)

  for i in range(num_workers):
    server_lib.Server(
        cs,
        job_name=worker_name,
        protocol=protocol,
        task_index=i,
        config=worker_config,
        start=True)

  for i in range(num_ps):
    server_lib.Server(
        cs,
        job_name=ps_name,
        protocol=protocol,
        task_index=i,
        config=ps_config,
        start=True)

  if has_chief:
    server_lib.Server(
        cs,
        job_name=chief_name,
        protocol=protocol,
        task_index=0,
        config=worker_config,
        start=True)

  if has_eval:
    server_lib.Server(
        cs,
        job_name='evaluator',
        protocol=protocol,
        task_index=0,
        config=eval_config,
        start=True)

  return cluster_dict


def create_in_process_cluster(num_workers,
                              num_ps,
                              has_chief=False,
                              has_eval=False,
                              rpc_layer='grpc'):
  """Creates an in-process cluster that consists only of standard servers."""
  # Leave some memory for the CUDA runtime.
  gpu_mem_frac = 0.7 / (num_workers + int(has_chief) + int(has_eval))
  worker_config = config_pb2.ConfigProto()
  worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac

  # The cluster may hang if workers don't have enough inter_op threads. See
  # b/172296720 for more details.
  if worker_config.inter_op_parallelism_threads < num_workers + 1:
    worker_config.inter_op_parallelism_threads = num_workers + 1

  # Enable collective ops; this has no impact on non-collective ops.
  # TODO(yuefengz, tucker): remove this after we move the initialization of
  # collective mgr to the session level.
  if has_chief:
    worker_config.experimental.collective_group_leader = (
        '/job:chief/replica:0/task:0')
  else:
    worker_config.experimental.collective_group_leader = (
        '/job:worker/replica:0/task:0')

  ps_config = config_pb2.ConfigProto()
  ps_config.device_count['GPU'] = 0

  eval_config = config_pb2.ConfigProto()
  eval_config.experimental.collective_group_leader = ''

  # Create in-process servers. Once an in-process TensorFlow server is
  # created, there is no way to terminate it, so we create one cluster per
  # test process. We could have started each server in a separate process and
  # killed that process to terminate the server, but we avoid multiple
  # processes because
  # 1) it is more difficult to manage these processes;
  # 2) there is something global in CUDA such that if we initialize CUDA in
  # the parent process, the child process cannot initialize it again and thus
  # cannot use GPUs (https://stackoverflow.com/questions/22950047).
  cluster = None
  try:
    cluster = _create_cluster(
        num_workers,
        num_ps=num_ps,
        has_chief=has_chief,
        has_eval=has_eval,
        worker_config=worker_config,
        ps_config=ps_config,
        eval_config=eval_config,
        protocol=rpc_layer)
  except errors.UnknownError as e:
    if 'Could not start gRPC server' in e.message:
      raise unittest.SkipTest('Cannot start std servers.')
    else:
      raise
  return cluster
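

# Illustrative sketch (not part of the original API surface): how a test
# might consume the cluster_spec dict returned by `create_in_process_cluster`.
# The helper name is hypothetical; `MultiWorkerTestBase.setUpClass` below does
# essentially this.
def _example_in_process_cluster_target():
  """Sketch: builds a 2-worker/1-ps cluster and returns a session target."""
  cluster_spec = create_in_process_cluster(num_workers=2, num_ps=1)
  # The dict maps task types to lists of 'localhost:<port>' addresses.
  return 'grpc://' + cluster_spec['worker'][0]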


class MultiProcessCluster(object):
  """A cluster of TensorFlow servers in separate processes.

  This class is not thread-safe.
  """

  def __init__(self,
               cluster_resolver,
               stream_output=False,
               collective_leader=None):
    self._cluster_resolver = cluster_resolver
    self._cluster_spec = cluster_resolver.cluster_spec().as_dict()
    self._rpc_layer = cluster_resolver.rpc_layer
    self._stream_output = stream_output
    self._start_events = {}
    self._finish_events = {}
    self._mpr_manager = multi_process_runner.manager()

    def task_function(start_events, finish_events):
      cluster_resolver = TFConfigClusterResolver()
      cluster_spec = cluster_resolver.cluster_spec()
      task_type = cluster_resolver.task_type
      task_id = cluster_resolver.task_id
      rpc_layer = cluster_resolver.rpc_layer

      # TODO(yuefengz): support GPU clusters.
      server_config = config_pb2.ConfigProto()
      server_config.device_count['GPU'] = 0

      if collective_leader:
        server_config.experimental.collective_group_leader = collective_leader
        server_config.experimental.collective_nccl = False

        logging.info(
            'Enabling collective ops with cluster_spec = %r, task_type = %r, '
            'task_id = %r, rpc_layer = %r, collective_leader = %s',
            cluster_spec, task_type, task_id, rpc_layer, collective_leader)
      else:
        logging.info(
            'Starting server with cluster_spec = %r, task_type = %r, '
            'task_id = %r, rpc_layer = %r', cluster_spec, task_type, task_id,
            rpc_layer)

      server_lib.Server(
          cluster_spec,
          job_name=task_type,
          protocol=rpc_layer,
          task_index=task_id,
          config=server_config,
          start=True)

      start_event = start_events[task_type][task_id]
      start_event.set()

      finish_event = finish_events[task_type][task_id]
      finish_event.wait()

      os._exit(0)  # pylint: disable=protected-access

    self._task_function = task_function
    self._mpr = None

  def start(self):
    """Starts one TensorFlow server for each task in the cluster_resolver.

    It waits until all the servers are up before returning.
    """
    if self._mpr:
      raise ValueError('The cluster has already been started.')
    for task_type, task_addresses in self._cluster_spec.items():
      self._start_events[task_type] = []
      self._finish_events[task_type] = []
      for _ in task_addresses:
        self._start_events[task_type].append(self._mpr_manager.Event())
        self._finish_events[task_type].append(self._mpr_manager.Event())

    self._mpr = multi_process_runner.MultiProcessRunner(
        self._task_function,
        self._cluster_spec,
        args=(self._start_events, self._finish_events),
        rpc_layer=self._rpc_layer,
        stream_output=self._stream_output,
        return_output=False,
        use_dill_for_args=False)
    self._mpr.start()
    for task_type, task_addresses in self._cluster_spec.items():
      for i in range(len(task_addresses)):
        self._start_events[task_type][i].wait()

  def stop(self):
    """Stops all the servers."""
    for task_type, task_addresses in self._cluster_spec.items():
      for i in range(len(task_addresses)):
        self._finish_events[task_type][i].set()
    try:
      self._mpr.join()
    except multi_process_runner.UnexpectedSubprocessExitError:
      # TODO(yuefengz): investigate why processes exit with 255.
      pass
    self._mpr = None
    self._start_events = {}
    self._finish_events = {}

  def kill_task(self, task_type, task_id):
    """Kills a server given task_type and task_id.

    Args:
      task_type: the type of the task, such as "worker".
      task_id: the id of the task, such as 1.
    """
    assert self._mpr
    if (not self._start_events[task_type][task_id].is_set() or
        self._finish_events[task_type][task_id].is_set()):
      raise ValueError("The task %s:%d doesn't exist." % (task_type, task_id))

    self._finish_events[task_type][task_id].set()
    self._mpr._processes[(task_type, task_id)].join()  # pylint: disable=protected-access

  def start_task(self, task_type, task_id):
    """Starts a server given task_type and task_id.

    Args:
      task_type: the type of the task, such as "worker".
      task_id: the id of the task, such as 1.

    Raises:
      ValueError: if the task is still alive.
    """
    assert self._mpr

    if (not self._start_events[task_type][task_id].is_set() or
        not self._finish_events[task_type][task_id].is_set()):
      raise ValueError(
          'The task %s:%d is still alive. You cannot start another one.' %
          (task_type, task_id))
    self._start_events[task_type][task_id] = self._mpr_manager.Event()
    self._finish_events[task_type][task_id] = self._mpr_manager.Event()
    self._mpr.start_single_process(task_type=task_type, task_id=task_id)
    self._start_events[task_type][task_id].wait()

  @property
  def cluster_resolver(self):
    return copy.deepcopy(self._cluster_resolver)
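

# Illustrative sketch: the expected lifecycle of a `MultiProcessCluster`,
# including killing and restarting a worker. The helper name is hypothetical;
# the resolver construction mirrors `create_multi_process_cluster` below.
def _example_multi_process_cluster_lifecycle():
  """Sketch: starts a cluster, restarts worker 0, then shuts down."""
  resolver = SimpleClusterResolver(
      server_lib.ClusterSpec(create_cluster_spec(num_workers=2)),
      rpc_layer='grpc')
  cluster = MultiProcessCluster(resolver)
  cluster.start()  # Blocks until both worker servers are up.
  cluster.kill_task('worker', 0)  # Sets the finish event and joins worker 0.
  cluster.start_task('worker', 0)  # Brings worker 0 back in a new process.
  cluster.stop()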
310 """ 311 assert self._mpr 312 313 if (not self._start_events[task_type][task_id].is_set() or 314 not self._finish_events[task_type][task_id].is_set()): 315 raise ValueError( 316 'The task %s:%d is still alive. You cannot start another one.' % 317 (task_type, task_id)) 318 self._start_events[task_type][task_id] = self._mpr_manager.Event() 319 self._finish_events[task_type][task_id] = self._mpr_manager.Event() 320 self._mpr.start_single_process(task_type=task_type, task_id=task_id) 321 self._start_events[task_type][task_id].wait() 322 323 @property 324 def cluster_resolver(self): 325 return copy.deepcopy(self._cluster_resolver) 326 327 328def create_multi_process_cluster(num_workers, 329 num_ps, 330 has_chief=False, 331 has_eval=False, 332 rpc_layer='grpc', 333 stream_output=False, 334 collective_leader=None): 335 logging.info('Now creating a MultiProcessCluster with ' 336 f'num_workers={num_workers}, num_ps={num_ps}.') 337 cluster_spec = create_cluster_spec( 338 has_chief=has_chief, 339 num_workers=num_workers, 340 num_ps=num_ps, 341 has_eval=has_eval) 342 343 cluster = MultiProcessCluster( 344 SimpleClusterResolver( 345 server_lib.ClusterSpec(cluster_spec), rpc_layer=rpc_layer), 346 stream_output=stream_output, 347 collective_leader=collective_leader) 348 cluster.start() 349 return cluster 350 351 352@tf_export( 353 '__internal__.distribute.multi_process_runner.create_cluster_spec', v1=[]) 354def create_cluster_spec(has_chief=False, 355 num_workers=1, 356 num_ps=0, 357 has_eval=False): 358 """Create a cluster spec with tasks with unused local ports. 359 360 This utility finds available ports at localhost, and returns a dict that 361 represents the cluster spec that utilizes those ports, according to the 362 arguments. The dict representing the cluster spec contains task types, and 363 their instances' addresses. Note that this is usually only for testing purpose 364 using multiple processes in the local machine, and should not be used for real 365 multi-worker TensorFlow programs, where the addresses need to point to the 366 processes at separate machines. 367 368 This util is useful when creating the `cluster_spec` arg for 369 `tf.__internal__.distribute.multi_process_runner.run`. 370 371 Args: 372 has_chief: Whether the generated cluster spec should contain "chief" task 373 type. 374 num_workers: Number of workers to use in the cluster spec. 375 num_ps: Number of parameter servers to use in the cluster spec. 376 has_eval: Whether this cluster spec has evaluator. 377 378 Returns: 379 A dict that represents the cluster spec using localhost ports for the tasks. 


@tf_export(
    '__internal__.distribute.multi_process_runner.create_cluster_spec', v1=[])
def create_cluster_spec(has_chief=False,
                        num_workers=1,
                        num_ps=0,
                        has_eval=False):
  """Creates a cluster spec whose tasks use unused local ports.

  This utility finds available ports on localhost and returns a dict that
  represents a cluster spec using those ports, according to the arguments.
  The dict maps task types to lists of their instances' addresses. Note that
  this is usually only for testing purposes, using multiple processes on the
  local machine, and should not be used for real multi-worker TensorFlow
  programs, where the addresses need to point to processes on separate
  machines.

  This util is useful when creating the `cluster_spec` arg for
  `tf.__internal__.distribute.multi_process_runner.run`.

  Args:
    has_chief: Whether the generated cluster spec should contain a "chief"
      task type.
    num_workers: Number of workers to use in the cluster spec.
    num_ps: Number of parameter servers to use in the cluster spec.
    has_eval: Whether this cluster spec has an evaluator.

  Returns:
    A dict that represents the cluster spec using localhost ports for the
    tasks.

  Example:

  ```python
  cluster_spec = (
      tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
          has_chief=True, num_workers=2, num_ps=2))
  # An example of cluster_spec is
  # {'chief': ['localhost:23381'],
  #  'worker': ['localhost:19197', 'localhost:22903'],
  #  'ps': ['localhost:16912', 'localhost:21535']}

  cluster_spec = (
      tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
          has_chief=False, num_workers=0, num_ps=0, has_eval=True))
  # An example of cluster_spec is
  # {'evaluator': ['localhost:23381']}
  ```
  """
  cluster_spec = {}
  if has_chief:
    cluster_spec['chief'] = ['localhost:%s' % pick_unused_port()]
  if num_workers:
    cluster_spec['worker'] = [
        'localhost:%s' % pick_unused_port() for _ in range(num_workers)
    ]
  if num_ps:
    cluster_spec['ps'] = [
        'localhost:%s' % pick_unused_port() for _ in range(num_ps)
    ]
  if has_eval:
    cluster_spec['evaluator'] = ['localhost:%s' % pick_unused_port()]
  return cluster_spec


@contextlib.contextmanager
def skip_if_grpc_server_cant_be_started(test_obj):
  """Skips the test if a std gRPC server cannot be started."""
  try:
    yield
  except errors.UnknownError as e:
    if 'Could not start gRPC server' in e.message:
      reason = 'Cannot start std servers.'
      test_obj.test_skipped_reason = reason
      test_obj.skipTest(reason)
    else:
      raise
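

# Illustrative sketch: wrapping cluster startup in
# `skip_if_grpc_server_cant_be_started` so environments that cannot bind gRPC
# ports skip a test instead of failing it. `test_obj` stands in for a
# `test.TestCase` instance; the helper name is hypothetical.
def _example_skip_when_grpc_unavailable(test_obj):
  """Sketch: skips `test_obj` if std servers cannot be started."""
  with skip_if_grpc_server_cant_be_started(test_obj):
    # Any code that may raise UnknownError('Could not start gRPC server ...')
    # can go inside the context.
    create_in_process_cluster(num_workers=1, num_ps=0)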


class MultiWorkerTestBase(test.TestCase):
  """Base class for testing multi-node strategies and datasets."""

  @classmethod
  def setUpClass(cls, num_workers=2, num_ps=1):  # pylint: disable=g-missing-super-call
    """Creates a local cluster with the given numbers of workers and PS."""
    cls._cluster_spec = create_in_process_cluster(num_workers=num_workers,
                                                  num_ps=num_ps)
    cls._default_target = 'grpc://' + cls._cluster_spec['worker'][0]

  def setUp(self):
    super(MultiWorkerTestBase, self).setUp()
    # We only cache the session in one test because another test may have a
    # different session config or master target.
    self._thread_local = threading.local()
    self._thread_local.cached_session = None
    self._coord = coordinator.Coordinator()

  @contextlib.contextmanager
  def session(self, graph=None, config=None, target=None):
    """Creates a test session with master target set to the testing cluster.

    Creates a test session that connects to the local testing cluster.

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      target: the target of the session to connect to.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    config = self._create_config(config)

    if target is None:
      target = self._default_target
    with session.Session(graph=graph, config=config, target=target) as sess:
      yield sess

  @contextlib.contextmanager
  # TODO(b/117573461): Overwrite self.evaluate() to use this function.
  def cached_session(self, graph=None, config=None, target=None):
    """Creates a test session with master target set to the testing cluster.

    Creates a test session that connects to the local testing cluster.
    The session is only created once per test and then reused.

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      target: the target of the session to connect to.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case. Note that the
      session will live until the end of the test.
    """
    config = self._create_config(config)

    if target is None:
      target = self._default_target
    if getattr(self._thread_local, 'cached_session', None) is None:
      self._thread_local.cached_session = session.Session(
          graph=None, config=config, target=target)
    sess = self._thread_local.cached_session
    with sess.graph.as_default(), sess.as_default():
      yield sess

  def _create_config(self, config):
    if config is None:
      config = config_pb2.ConfigProto(allow_soft_placement=True)
    else:
      config = copy.deepcopy(config)
    # Don't perform optimizations for tests so we don't inadvertently run
    # GPU ops on CPU.
    config.graph_options.optimizer_options.opt_level = -1
    config.graph_options.rewrite_options.constant_folding = (
        rewriter_config_pb2.RewriterConfig.OFF)

    return config

  def _run_client(self, client_fn, task_type, task_id, num_gpus, eager_mode,
                  *args, **kwargs):

    def wrapped_client_fn():
      with self._coord.stop_on_exception():
        client_fn(task_type, task_id, num_gpus, *args, **kwargs)

    if eager_mode:
      with context.eager_mode():
        wrapped_client_fn()
    else:
      with context.graph_mode():
        wrapped_client_fn()

  def _run_between_graph_clients(self, client_fn, cluster_spec, num_gpus,
                                 *args, **kwargs):
    """Runs several clients for between-graph replication.

    Args:
      client_fn: a function that needs to accept `task_type`, `task_id`, and
        `num_gpus`.
      cluster_spec: a dict specifying jobs in a cluster.
      num_gpus: number of GPUs per worker.
      *args: will be passed to `client_fn`.
      **kwargs: will be passed to `client_fn`.
536 """ 537 threads = [] 538 for task_type in ['chief', 'worker']: 539 for task_id in range(len(cluster_spec.get(task_type, []))): 540 t = threading.Thread( 541 target=self._run_client, 542 args=(client_fn, task_type, task_id, num_gpus, 543 context.executing_eagerly()) + args, 544 kwargs=kwargs) 545 t.start() 546 threads.append(t) 547 self._coord.join(threads) 548 549 550class SingleWorkerTestBaseGraph(MultiWorkerTestBase): 551 """Base class for testing remote single worker strategy graph and dataset.""" 552 553 @classmethod 554 def setUpClass(cls): 555 super(SingleWorkerTestBaseGraph, cls).setUpClass(num_workers=1) 556 557 558class SingleWorkerTestBaseEager(test.TestCase): 559 """Base class for testing remote single worker strategy eager and dataset.""" 560 561 def setUp(self): 562 super(SingleWorkerTestBaseEager, self).setUp() 563 workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0) 564 remote.connect_to_remote_host(workers[0].target) 565 566 def cached_session(self): 567 return DummySession() 568 569 570class DummySession(object): 571 572 def __enter__(self): 573 return 574 575 def __exit__(self, exception_type, exception_value, traceback): 576 pass 577 578 579class MockOsEnv(collections_abc.Mapping): 580 """A class that allows per-thread TF_CONFIG.""" 581 582 def __init__(self, *args): 583 self._dict = dict() 584 self._thread_local = threading.local() 585 super(MockOsEnv, self).__init__(*args) 586 587 def get(self, key, default=None): 588 if not hasattr(self._thread_local, 'dict'): 589 self._thread_local.dict = dict() 590 if key == 'TF_CONFIG': 591 return dict.get(self._thread_local.dict, key, default) 592 else: 593 return dict.get(self._dict, key, default) 594 595 def __getitem__(self, key): 596 if not hasattr(self._thread_local, 'dict'): 597 self._thread_local.dict = dict() 598 if key == 'TF_CONFIG': 599 return dict.__getitem__(self._thread_local.dict, key) 600 else: 601 return dict.__getitem__(self._dict, key) 602 603 def __setitem__(self, key, val): 604 if not hasattr(self._thread_local, 'dict'): 605 self._thread_local.dict = dict() 606 if key == 'TF_CONFIG': 607 return dict.__setitem__(self._thread_local.dict, key, val) 608 else: 609 return dict.__setitem__(self._dict, key, val) 610 611 def __iter__(self): 612 if not hasattr(self._thread_local, 'dict'): 613 self._thread_local.dict = dict() 614 for x in self._thread_local.dict: 615 yield x 616 for x in self._dict: 617 yield x 618 619 def __len__(self): 620 if not hasattr(self._thread_local, 'dict'): 621 self._thread_local.dict = dict() 622 return self._thread_local.dict.__len__() + self._dict.__len__() 623 624 625class IndependentWorkerTestBase(test.TestCase): 626 """Testing infra for independent workers.""" 627 628 def _make_mock_run_std_server(self): 629 630 def _mock_run_std_server(*args, **kwargs): 631 """Returns the std server once all threads have started it.""" 632 with skip_if_grpc_server_cant_be_started(self): 633 ret = original_run_std_server(*args, **kwargs) 634 # Wait for all std servers to be brought up in order to reduce the chance 635 # of remote sessions taking local ports that have been assigned to std 636 # servers. Only call this barrier the first time this function is run for 637 # each thread. 


class SingleWorkerTestBaseGraph(MultiWorkerTestBase):
  """Base class for testing remote single-worker strategies and datasets.

  Tests run in graph mode.
  """

  @classmethod
  def setUpClass(cls):
    super(SingleWorkerTestBaseGraph, cls).setUpClass(num_workers=1)


class SingleWorkerTestBaseEager(test.TestCase):
  """Base class for testing remote single-worker strategies and datasets.

  Tests run in eager mode.
  """

  def setUp(self):
    super(SingleWorkerTestBaseEager, self).setUp()
    workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0)
    remote.connect_to_remote_host(workers[0].target)

  def cached_session(self):
    return DummySession()


class DummySession(object):

  def __enter__(self):
    return

  def __exit__(self, exception_type, exception_value, traceback):
    pass


class MockOsEnv(collections_abc.Mapping):
  """A class that allows per-thread TF_CONFIG."""

  def __init__(self, *args):
    self._dict = dict()
    self._thread_local = threading.local()
    super(MockOsEnv, self).__init__(*args)

  def get(self, key, default=None):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    if key == 'TF_CONFIG':
      return dict.get(self._thread_local.dict, key, default)
    else:
      return dict.get(self._dict, key, default)

  def __getitem__(self, key):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    if key == 'TF_CONFIG':
      return dict.__getitem__(self._thread_local.dict, key)
    else:
      return dict.__getitem__(self._dict, key)

  def __setitem__(self, key, val):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    if key == 'TF_CONFIG':
      return dict.__setitem__(self._thread_local.dict, key, val)
    else:
      return dict.__setitem__(self._dict, key, val)

  def __iter__(self):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    for x in self._thread_local.dict:
      yield x
    for x in self._dict:
      yield x

  def __len__(self):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    return self._thread_local.dict.__len__() + self._dict.__len__()
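

# Illustrative sketch: `MockOsEnv` scopes TF_CONFIG per thread while sharing
# all other keys, which is what lets threads simulate independent workers
# inside one process. The helper name is hypothetical.
def _example_mock_os_env_is_per_thread():
  """Sketch: two threads observe different TF_CONFIG values."""
  env = MockOsEnv()
  env['SHARED_KEY'] = 'shared'  # stored in the dict common to all threads
  results = {}

  def record(name, tf_config):
    env['TF_CONFIG'] = tf_config  # stored in this thread's local dict
    results[name] = (env.get('TF_CONFIG'), env.get('SHARED_KEY'))

  threads = [
      threading.Thread(target=record, args=('a', '{"task": 0}')),
      threading.Thread(target=record, args=('b', '{"task": 1}')),
  ]
  for t in threads:
    t.start()
  for t in threads:
    t.join()
  # results == {'a': ('{"task": 0}', 'shared'), 'b': ('{"task": 1}', 'shared')}
  return results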


class IndependentWorkerTestBase(test.TestCase):
  """Testing infra for independent workers."""

  def _make_mock_run_std_server(self):

    def _mock_run_std_server(*args, **kwargs):
      """Returns the std server once all threads have started it."""
      with skip_if_grpc_server_cant_be_started(self):
        ret = original_run_std_server(*args, **kwargs)
      # Wait for all std servers to be brought up in order to reduce the
      # chance of remote sessions taking local ports that have been assigned
      # to std servers. Only call this barrier the first time this function
      # is run for each thread.
      if not getattr(self._thread_local, 'server_started', False):
        self._barrier.wait()
      self._thread_local.server_started = True
      return ret

    return _mock_run_std_server

  def setUp(self):
    self._mock_os_env = MockOsEnv()
    self._mock_context = test.mock.patch.object(os, 'environ',
                                                self._mock_os_env)
    self._coord = coordinator.Coordinator()
    super(IndependentWorkerTestBase, self).setUp()
    self._mock_context.__enter__()
    # Thread-local object to be shared by all threads.
    self._thread_local = threading.local()

  def tearDown(self):
    self._mock_context.__exit__(None, None, None)
    super(IndependentWorkerTestBase, self).tearDown()

  def _task_thread(self, task_fn, tf_config, executing_eagerly, *args,
                   **kwargs):
    with self._coord.stop_on_exception():
      os.environ['TF_CONFIG'] = json.dumps(tf_config)
      # Force the new thread simulating a worker to run in the same context
      # mode as the parent thread does.
      if executing_eagerly:
        with context.eager_mode():
          task_fn(*args, **kwargs)
      else:
        with ops.Graph().as_default(), context.graph_mode():
          task_fn(*args, **kwargs)

  def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id,
                          *args, **kwargs):
    """Runs a task in a new thread.

    If `tf_config` is provided in `kwargs`, it is used for the new thread; if
    not, one is constructed from `cluster_spec`, `task_type`, and `task_id`,
    and provided to the new thread to be set as the `TF_CONFIG` environment
    variable.

    Args:
      task_fn: The function to run in the new thread.
      cluster_spec: The cluster spec.
      task_type: The task type.
      task_id: The task id.
      *args: Additional positional arguments to provide to the thread's
        task_fn.
      **kwargs: Additional keyword arguments to provide to the thread's
        task_fn. If `tf_config` is provided, that dict will be used for the
        TF_CONFIG of the new thread.

    Returns:
      The thread that has started.
    """
    tf_config = kwargs.pop('tf_config', None)
    if tf_config is None:
      if task_type:
        tf_config = {
            'cluster': cluster_spec,
            'task': {
                'type': task_type,
                'index': task_id
            }
        }
      else:
        tf_config = {
            'cluster': cluster_spec,
        }
    t = threading.Thread(
        target=self._task_thread,
        args=(task_fn, tf_config, context.executing_eagerly()) + args,
        kwargs=kwargs)
    t.start()
    return t

  def run_multiple_tasks_in_threads(self, task_fn, cluster_spec, *args,
                                    **kwargs):
    """Runs `task_fn` in one thread per task in `cluster_spec`.

    The `task_fn` is expected to create a std server by itself.
    """
    threads = {}
    for task_type in cluster_spec.keys():
      threads[task_type] = []
      for task_id in range(len(cluster_spec[task_type])):
        t = self._run_task_in_thread(task_fn, cluster_spec, task_type,
                                     task_id, *args, **kwargs)
        threads[task_type].append(t)
    return threads

  def join_independent_workers(self, worker_threads):
    with skip_if_grpc_server_cant_be_started(self):
      self._coord.join(worker_threads)
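

# Illustrative sketch: driving the `IndependentWorkerTestBase` helpers.
# `test_case` stands in for an instance of that class inside a test method;
# the helper and `task_fn` names are hypothetical.
def _example_independent_workers(test_case):
  """Sketch: runs a no-op task per worker and joins the threads."""

  def task_fn(*unused_args, **unused_kwargs):
    # A real task_fn would start a std server and run worker code; each
    # thread sees its own TF_CONFIG via the mocked os.environ.
    pass

  cluster_spec = create_cluster_spec(num_workers=2)
  threads = test_case.run_multiple_tasks_in_threads(task_fn, cluster_spec)
  test_case.join_independent_workers(
      [t for ts in threads.values() for t in ts])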
795 """ 796 797 def _stream_stderr_single_process(process, type_string, index, 798 print_to_stdout): 799 """Consume a single process's stderr and optionally print to stdout.""" 800 while True: 801 output = process.stderr.readline() 802 if not output and process.poll() is not None: 803 break 804 if output and print_to_stdout: 805 print('{}{} {}'.format(type_string, index, output.strip())) 806 sys.stdout.flush() 807 808 stream_threads = [] 809 for process_type, process_list in six.iteritems(processes): 810 for i in range(len(process_list)): 811 print_to_stdout = (not print_only_first) or (i == 0) 812 thread = threading.Thread( 813 target=_stream_stderr_single_process, 814 args=(process_list[i], process_type, i, print_to_stdout)) 815 thread.start() 816 stream_threads.append(thread) 817 for thread in stream_threads: 818 thread.join() 819 820 821def get_tf_config_task(): 822 return json.loads(os.environ['TF_CONFIG'])['task'] 823 824 825def get_tf_config_cluster_spec(): 826 return json.loads(os.environ['TF_CONFIG'])['cluster'] 827 828 829def get_task_type(): 830 return get_tf_config_task()['type'] 831 832 833def get_task_index(): 834 return get_tf_config_task()['index'] 835 836 837def is_chief(): 838 return ('chief' not in get_tf_config_cluster_spec() 839 and get_task_type() == 'worker' 840 and get_task_index() == 0) 841