# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multi-process runner for testing purposes."""

import collections
import contextlib
import json
import os
import signal
import sys
import threading
import time
import unittest
import weakref

from absl import logging
import six
from six.moves import queue as Queue

from tensorflow.python import tf2
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import multi_process_lib
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util
from tensorflow.python.util.tf_export import tf_export

multiprocessing = multi_process_lib.multiprocessing

# pylint: disable=g-import-not-at-top
try:
  # `faulthandler` is not available in py2.
  import faulthandler
except ImportError:
  faulthandler = None

# TODO(b/150264776): Remove after resolving CI issue.
try:
  import dill
except ImportError:
  dill = None

# TODO(b/150264776): Remove after resolving CI issue.
try:
  import tblib.pickling_support
  # For pickling traceback objects.
  tblib.pickling_support.install()
except ImportError:
  pass


# _ProcessStatusInfo contains process status information. When the
# is_successful attribute is True, the subprocess has ended successfully; if
# False, the exception stack trace info is stored in exc_info to be passed on
# to the parent process and re-raised there.
_ProcessStatusInfo = collections.namedtuple(
    '_ProcessStatusInfo',
    ['task_type', 'task_id', 'is_successful', 'exc_info', 'return_value'])

# Information returned from a successful MultiProcessRunner run.
MultiProcessRunnerResult = collections.namedtuple('MultiProcessRunnerResult',
                                                  ['return_value', 'stdout'])

# visible_gpus: If not None, CUDA_VISIBLE_DEVICES is set to visible_gpus.
TestEnvironment = collections.namedtuple('TestEnvironment', [
    'task_type', 'task_id', 'cluster_spec', 'rpc_layer', 'grpc_fail_fast',
    'v2_enabled', 'executing_eagerly', 'visible_gpus'
])

# Resources for communication between worker processes and the main process.
#
# `process_status_queue` is used by `multi_process_runner` internally for
# communication from subprocesses to the parent process, reporting whether a
# subprocess has been successful and, if not, what the error stack trace is.
# `parent_to_sub_queue` is used for communication from parent to subprocesses.
# Currently this is only used to terminate subprocesses.
# TODO(rchao): Remove this once subprocesses are terminated by SIGKILL.
# `streaming_pipe_w` is used to stream stdout and stderr from subprocesses to
# the parent process.
# `barrier` is a barrier for the party of all subprocesses.
Resources = collections.namedtuple('Resources', [
    'process_status_queue', 'parent_to_sub_queue', 'streaming_pipe_w', 'barrier'
])

# The default timeout in seconds is selected so that it is hit before the
# default "medium" timeout of the test runs.
_DEFAULT_TIMEOUT_SEC = 200

# The timeout in seconds to wait before force-killing a child process. When a
# child process times out we first try to SIGTERM it so that it has a chance
# to dump stacktraces. However, dumping stacktraces can take a long time.
_FORCE_KILL_WAIT_SEC = 30


class MultiProcessRunner(object):
  """A utility class to start multiple processes to simulate a cluster.

  We need to use multiple processes to simulate a cluster in TF 2.0 tests
  because TF 2.0 has some process-global data structures that have to be
  separated by processes. We also need child processes to test out our fault
  tolerance because shutting down a standard TensorFlow server within its
  process is not supported.

  Note: the main test program that uses this runner class must run the main
  program via `test_main` defined in this file. Using this runner in non-test
  binaries is not supported yet.

  This class is not thread-safe. Child processes will inherit the TF2 behavior
  flag.
  """

  def __init__(self,
               fn,
               cluster_spec,
               rpc_layer=None,
               max_run_time=None,
               grpc_fail_fast=None,
               stream_output=True,
               return_output=False,
               use_dill_for_args=True,
               daemon=False,
               dependence_on_chief=True,
               auto_restart=False,
               share_gpu=True,
               args=None,
               kwargs=None):
    """Instantiation of a `MultiProcessRunner`.

    Args:
      fn: Function to be run on child processes. This will be run on processes
        for all task types.
      cluster_spec: Dict for cluster spec. The utility function
        `tf.__internal__.distribute.multi_process_runner.create_cluster_spec`
        can be conveniently used to create such dict. The following is an
        example of a cluster with three workers and two ps's.
        {"worker": ["worker0.example.com:2222",
                    "worker1.example.com:2222",
                    "worker2.example.com:2222"],
         "ps": ["ps0.example.com:2222",
                "ps1.example.com:2222"]}
      rpc_layer: RPC layer to use. Default value is 'grpc'.
      max_run_time: `None` or integer. If not `None`, child processes are
        forced to exit at approximately this many seconds after this utility
        is called. We achieve this through the `signal.alarm()` API. Note that
        this is best effort at the Python level, since the Python signal
        handler does not get executed while lower-level C/C++ code is running,
        so it can be delayed for an arbitrarily long time. If any of the child
        processes are still running when `max_run_time` is up, they will be
        force-terminated and an `UnexpectedSubprocessExitError` may be raised.
        If `None`, child processes are not forced to exit.
      grpc_fail_fast: Whether the gRPC connection between processes should fail
        without retrying. Defaults to None, in which case the environment
        variable is not explicitly set.
      stream_output: True if the output/error from the subprocesses should be
        streamed to be printed in the parent process' log. Defaults to True.
      return_output: If True, the output/error from the subprocesses should be
        collected to be attached to the resulting namedtuple returned from
        `join()`. The list of output can be retrieved via the `stdout`
        attribute. Defaults to False.
      use_dill_for_args: Whether to use dill to pickle `args` and `kwargs`.
        dill can pickle more object types, but does not work with types from
        the `multiprocessing` library such as `Mutex`.
      daemon: Whether to start processes as daemons.
      dependence_on_chief: Whether to terminate the cluster if the chief exits.
        If auto_restart is True, it only terminates the cluster if the chief
        exits with a zero exit code.
      auto_restart: Whether to automatically restart processes that exit with
        non-zero exit code.
      share_gpu: Whether to share GPUs among workers. If False, each worker is
        assigned different GPUs in a round-robin fashion. This should be True
        whenever possible for better test execution coverage; some situations
        that need it to be False are tests that run NCCL.
      args: Positional arguments to be sent to `fn` run on subprocesses.
      kwargs: Keyword arguments to be sent to `fn` run on subprocesses.

    Raises:
      RuntimeError: if `multi_process_runner.test_main()` is not called.
      ValueError: if there are more than one chief in the `cluster_spec`.
      SkipTest: if thread sanitizer is enabled (which is incompatible with MPR).
    """
    if test_util.is_tsan_enabled():
      raise unittest.SkipTest(
          'ThreadSanitizer is not compatible with MultiProcessRunner.')

    assert cluster_spec is not None
    if 'chief' in cluster_spec and len(cluster_spec['chief']) > 1:
      raise ValueError('If chief exists in the cluster, there must be at most '
                       'one chief. Current `cluster_spec` has {} chiefs.'
                       .format(len(cluster_spec['chief'])))
    _check_initialization()
    if not callable(fn):
      raise ValueError('fn is not a callable')

    self._fn = fn
    self._cluster_spec = cluster_spec
    self._rpc_layer = rpc_layer or 'grpc'
    self._max_run_time = max_run_time
    self._grpc_fail_fast = grpc_fail_fast
    self._stream_output = stream_output
    # TODO(rchao): Revisit return_output argument to consider other solution.
    self._return_output = return_output
    self._dependence_on_chief = dependence_on_chief
    self._use_dill_for_args = use_dill_for_args
    self._daemon = daemon
    self._auto_restart = auto_restart
    self._args = args or ()
    self._kwargs = kwargs or {}

    self._share_gpu = share_gpu
    self._total_gpu = len(context.context().list_physical_devices('GPU'))

    # Child processes should have the same v2 and eager behavior.
    self._v2_enabled = tf2.enabled()
    self._executing_eagerly = context.executing_eagerly()

    self._joined = False
    self._process_lock = threading.Lock()
    # Guarded by self._process_lock.
    self._processes = {}
    # Record which processes are terminated. Due to a bug in Python<3.7,
    # terminated processes return a 255 exit code, which would otherwise cause
    # an exception in join().
    # https://bugs.python.org/issue30589
    # Guarded by self._process_lock.
    self._terminated = set()
    self._reading_threads = []

    self._manager = manager()
    self._process_status_queue = self._manager.Queue()
    self._parent_to_sub_queue = self._manager.Queue()
    parties = sum(len(addresses) for addresses in self._cluster_spec.values())
    self._barrier = self._manager.Barrier(parties)

    # We use a queue to collect outputs from worker processes since it's
    # thread-safe.
    self._streaming_queue = self._manager.Queue()

    self._watchdog_thread = None

  def set_args(self, args=None, kwargs=None):
    self._args = args or self._args
    self._kwargs = kwargs or self._kwargs

  def _continuously_readline_from_sub(self, pipe_r, task_type, task_id):
    """Function to continuously read lines from subprocesses."""
    with os.fdopen(pipe_r.fileno(), 'r', closefd=False) as reader:
      for line in reader:
        task_string = '[{}-{}]:'.format(task_type, task_id)
        formatted_line = '{} {}'.format(task_string.ljust(14), line)
        if self._stream_output:
          # TODO(rchao): Use a lock here to ensure the printed lines are not
          # broken.
          print(formatted_line, end='', flush=True)
        if self._return_output:
          self._streaming_queue.put(formatted_line)

  def _start_subprocess_and_reading_thread(self,
                                           task_type,
                                           task_id,
                                           cluster_spec=None,
                                           fn=None,
                                           args=None,
                                           kwargs=None):
    """Starts a subprocess and a thread that reads lines from the subprocess."""

    if dill is None:
      raise unittest.SkipTest(
          'TODO(b/150264776): Resolve dependency issue in CI')

    cluster_spec = cluster_spec or self._cluster_spec
    visible_gpus = None
    if not self._share_gpu and self._total_gpu > 0:
      # Assign GPUs in a round-robin fashion.
      id_in_cluster = multi_worker_util.id_in_cluster(cluster_spec, task_type,
                                                      task_id)
      worker_count = multi_worker_util.worker_count(cluster_spec, task_type)
      visible_gpus = list(range(id_in_cluster, self._total_gpu, worker_count))

    test_env = TestEnvironment(
        task_type=task_type,
        task_id=task_id,
        cluster_spec=cluster_spec,
        rpc_layer=self._rpc_layer,
        grpc_fail_fast=self._grpc_fail_fast,
        v2_enabled=self._v2_enabled,
        executing_eagerly=self._executing_eagerly,
        visible_gpus=visible_gpus,
    )
    pipe_r, pipe_w = multiprocessing.Pipe(duplex=False)
    resources = Resources(
        process_status_queue=self._process_status_queue,
        parent_to_sub_queue=self._parent_to_sub_queue,
        streaming_pipe_w=pipe_w,
        barrier=self._barrier,
    )
    if fn is None:
      fn, args, kwargs = self._fn, self._args, self._kwargs
    # Always use dill to pickle fn so that we support more callable
    # types, e.g. lambda.
    fn = dill.dumps(fn, dill.HIGHEST_PROTOCOL)
    if self._use_dill_for_args:
      args = dill.dumps(args, dill.HIGHEST_PROTOCOL)
      kwargs = dill.dumps(kwargs, dill.HIGHEST_PROTOCOL)

    p = _Process(
        test_env=test_env,
        target=_ProcFunc(),
        args=(resources, test_env, fn, args, kwargs, self._use_dill_for_args),
        daemon=self._daemon)
    p.start()
    self._processes[(task_type, task_id)] = p
    self._terminated.discard((task_type, task_id))

    # For each subprocess, we dedicate a thread continuously reading lines
    # from it.
    thread = threading.Thread(  # pylint: disable=unexpected-keyword-arg
        target=self._continuously_readline_from_sub,
        args=(pipe_r, task_type, task_id))
    thread.start()
    self._reading_threads.append(thread)

    if self._watchdog_thread is None or not self._watchdog_thread.is_alive():
      self._watchdog_thread = threading.Thread(target=self._process_watchdog)
      self._watchdog_thread.start()

  def start(self):
    """Starts processes, one for each task in `cluster_spec`.

    Note that this is best effort by the applicable multiprocessing library,
    and it may take several seconds for a subprocess to be successfully
    started.
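
    Example (a minimal sketch following the conventions used elsewhere in this
    module; `fn` and the cluster layout are placeholders for illustration):

    ```python
    def fn():
      # Code to be run in every task, e.g. building and training a model.
      pass

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=2))
    mpr.start()
    result = mpr.join()
    ```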
    """
    with self._process_lock:
      if self._processes:
        raise ValueError('MultiProcessRunner already started.')
      if self._joined:
        raise ValueError('cannot start new processes after '
                         'MultiProcessRunner.join() is called')

      for task_type, addresses in self._cluster_spec.items():
        for task_id, _ in enumerate(addresses):
          self._start_subprocess_and_reading_thread(task_type, task_id)

    # TODO(rchao): Remove the need of using SIGALRM if possible. At this time,
    # without this the tests become very flaky.
    if self._max_run_time is not None:

      def handler(signum, frame):
        del signum, frame
        self.terminate_all()

      signal.signal(signal.SIGALRM, handler)
      signal.alarm(self._max_run_time)

  def start_in_process_as(self, as_task_type, as_task_id):
    """Start the processes, with the specified task run in the main process.

    This is similar to `start()` except that the task with task_type
    `as_task_type` and task_id `as_task_id` is run in the main process.
    This method is particularly useful when a debugging tool such as `pdb` is
    needed in some specific task. Note that since this method blocks until
    that specific task exits, additional actions would need to be run in a
    separate thread:

    ```python
    def fn():
      # user code to be run
      import pdb; pdb.set_trace()

    def follow_ups():
      time.sleep(5)
      mpr.start_single_process(
          task_type='evaluator',
          task_id=0)

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=1))
    threading.Thread(target=follow_ups).start()
    mpr.start_in_process_as(as_task_type='chief', as_task_id=0)
    mpr.join()
    ```

    Note that if `return_output=True`, the logs/stdout from the task
    run by the main process are not available in result.stdout.

    Args:
      as_task_type: The task type to be run in the main process.
      as_task_id: The task id to be run in the main process.
    """
    if self._processes:
      raise ValueError('MultiProcessRunner already started.')
    with self._process_lock:
      if self._joined:
        raise ValueError('cannot start new processes after '
                         'MultiProcessRunner.join() is called')
      for task_type, addresses in self._cluster_spec.items():
        for task_id, _ in enumerate(addresses):
          if not (task_type == as_task_type and task_id == as_task_id):
            self._start_subprocess_and_reading_thread(task_type, task_id)

    _set_tf_config(as_task_type, as_task_id, self._cluster_spec,
                   self._rpc_layer)
    self._fn(*self._args, **self._kwargs)

  def start_single_process(self,
                           task_type,
                           task_id,
                           cluster_spec=None,
                           fn=None,
                           args=None,
                           kwargs=None):
    """Starts a single process.

    This starts a process in the cluster with the task type, task id, and the
    process function (`fn`). If the process function is `None`, the function
    provided at `__init__` will be used. If `cluster_spec` is `None`, the
    cluster spec provided at `__init__` will be used.

    TODO(rchao): It is meant that all subprocesses will be updated with the new
    cluster spec, but this has yet to be implemented. At this time only the
    newly started subprocess picks up this updated cluster spec.

    Args:
      task_type: The task type.
      task_id: The task id.
      cluster_spec: The cluster spec to be used on the newly started
        process. If `None`, the cluster spec provided at `__init__` will be
        used.
      fn: The process function to be run on the newly started
        process. If specified, specify `args` and `kwargs` as well. If `None`,
        the function provided at `__init__` will be used.
      args: Optional positional arguments to be supplied to `fn`.
      kwargs: Optional keyword arguments to be supplied to `fn`.
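
    Example (a sketch; assumes a runner `mpr` created and started as in the
    `start()` example, and a hypothetical `eval_fn` for an 'evaluator' task
    present in the cluster spec):

    ```python
    def eval_fn():
      # Code specific to the evaluator task.
      pass

    mpr.start_single_process(task_type='evaluator', task_id=0, fn=eval_fn)
    ```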
    """
    with self._process_lock:
      if self._joined:
        raise ValueError('cannot start new processes after '
                         'MultiProcessRunner.join() is called')
      self._start_subprocess_and_reading_thread(
          task_type,
          task_id,
          cluster_spec=cluster_spec,
          fn=fn,
          args=args or (),
          kwargs=kwargs or {})

  def _queue_to_list(self, queue_to_convert):
    """Convert `queue.Queue` to `list`."""
    list_to_return = []
    # Calling `queue.empty()` is not reliable.
    while True:
      try:
        list_to_return.append(queue_to_convert.get(block=False))
      except Queue.Empty:
        break
    return list_to_return

  def _get_process_statuses(self):
    # One worker may have multiple statuses. We only keep the last one.
    statuses = {}
    for status in self._queue_to_list(self._process_status_queue):
      statuses[(status.task_type, status.task_id)] = status
    return statuses

  def get_process_id(self, task_type, task_id):
    """Returns the subprocess id given the task type and task id."""
    with self._process_lock:
      p = self._processes.get((task_type, task_id), None)
      return p.pid if p else None

  def get_process_exit_code(self, task_type, task_id):
    """Returns the subprocess exit code given the task type and task id.

    Args:
      task_type: The task type.
      task_id: The task id.

    Returns:
      The subprocess exit code; `None` if the subprocess has not exited yet.

    Raises:
      KeyError: If the corresponding subprocess is not found with `task_type`
        and `task_id`.
    """
    with self._process_lock:
      p = self._processes[(task_type, task_id)]
      return p.exitcode if p else None

  def process_exists(self, task_type, task_id):
    """Returns whether the subprocess still exists given the task type and id.

    Args:
      task_type: The task type.
      task_id: The task id.

    Returns:
      Boolean; whether the subprocess still exists. If the subprocess has
      exited, this returns False.
    """
    return self.get_process_exit_code(task_type, task_id) is None

  def _process_watchdog(self):
    """Simulates a cluster management system.

    - If auto_restart is True, it restarts processes that exit with a non-zero
      exit code. Note that when join() times out it overrides auto_restart to
      False.
    - If dependence_on_chief is True, it terminates all processes once the
      chief exits. If auto_restart is also True, it only terminates all
      processes if the chief exits with a zero exit code; otherwise it restarts
      the chief.

    This runs in self._watchdog_thread.
    """
    while True:
      time.sleep(1)
      with self._process_lock:
        chief = self._processes.get(('chief', 0), None)
        # Terminate the cluster when _dependence_on_chief is True if either:
        # - chief has exited with zero exit code.
        # - chief has exited with non-zero exit code and self._auto_restart is
        #   False.
        if chief and self._dependence_on_chief and chief.exitcode is not None:
          if chief.exitcode == 0 or (not self._auto_restart):
            for p in self._processes.values():
              # Give other processes a chance to exit on their own.
              p.join(timeout=3)
            self._terminate_all()
            for p in self._processes.values():
              p.join()
            return

        # Auto restart failed processes if self._auto_restart is True.
        if self._auto_restart:
          has_failure = False
          for (task_type, task_id), p in self._processes.items():
            if p.exitcode is not None and p.exitcode != 0:
              has_failure = True
              logging.info('Restarting failed %s-%d', task_type, task_id)
              self._start_subprocess_and_reading_thread(task_type, task_id)
          if has_failure:
            continue

        # Exit the thread if all processes have exited at this point.
        if all(p.exitcode is not None for p in self._processes.values()):
          return

  def _reraise_if_subprocess_error(self, process_statuses):
    for process_status in process_statuses.values():
      assert isinstance(process_status, _ProcessStatusInfo)
      if not process_status.is_successful:
        process_status.exc_info[1].mpr_result = self._get_mpr_result(
            process_statuses)
        six.reraise(*process_status.exc_info)

  def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
    """Joins all the processes with timeout.

    If any of the subprocesses has not exited approximately `timeout` seconds
    after the `join` call, this raises a `SubprocessTimeoutError`.

    Note: At timeout, it uses SIGTERM to terminate the subprocesses, in order
    to log the stack traces of the subprocesses when they exit. However, this
    results in timeout when the test runs with tsan (thread sanitizer); if tsan
    is being run on the test targets that rely on timeout to assert
    information, `MultiProcessRunner.terminate_all()` must be called after
    `join()`, before the test exits, so that the subprocesses are terminated
    with SIGKILL and the data race is removed.

    Args:
      timeout: optional integer or `None`. If provided as an integer, and not
        all processes report status within roughly `timeout` seconds, a
        `SubprocessTimeoutError` exception will be raised. If `None`, `join`
        never times out.

    Returns:
      A `MultiProcessRunnerResult` object, which has two attributes,
      `return_value` and `stdout`. `return_value` always contains a list of
      return values from the subprocesses, although the order is not
      meaningful. If the `return_output` argument is True at `__init__`,
      `stdout` is available, containing a list of all messages from the
      subprocesses' stdout and stderr.

    Raises:
      SubprocessTimeoutError: if not all processes report status approximately
        within `timeout` seconds. When this is raised, a
        `MultiProcessRunnerResult` object can be retrieved by
        `SubprocessTimeoutError`'s mpr_result attribute, which has the same
        structure as the above 'Returns' section describes.
      UnexpectedSubprocessExitError: If any of the subprocesses did not exit
        properly (for example, they exit on SIGTERM or SIGKILL signal). When
        this is raised, a `MultiProcessRunnerResult` object can be retrieved by
        `UnexpectedSubprocessExitError`'s mpr_result attribute, which has the
        same structure as the above 'Returns' section describes. If
        `max_run_time` is not `None`, it is expected that some subprocesses may
        be force-killed when `max_run_time` is up, and this is raised in those
        cases.
      Exception: if there is an Exception propagated from any subprocess. When
        this is raised, a `MultiProcessRunnerResult` object can be retrieved by
        the raised exception's mpr_result attribute, which has the same
        structure as the above 'Returns' section describes.
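
    Example of handling a timeout (a sketch; `fn`, the cluster layout, and the
    timeout value are placeholders):

    ```python
    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        return_output=True)
    mpr.start()
    try:
      result = mpr.join(timeout=60)
    except multi_process_runner.SubprocessTimeoutError as e:
      # The partial result is still available on the exception.
      result = e.mpr_result
    logging.info('Collected %d lines of subprocess output.',
                 len(result.stdout))
    ```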
    """
    if timeout and not isinstance(timeout, int):
      raise ValueError('`timeout` must be an integer or `None`.')
    with self._process_lock:
      if self._joined:
        raise ValueError("MultiProcessRunner can't be joined twice.")
      self._joined = True

    self._watchdog_thread.join(timeout)
    if self._watchdog_thread.is_alive():
      # Timeout. Force termination to dump worker processes' stack traces.
      with self._process_lock:
        self._auto_restart = False
      logging.error('Timeout when joining for child processes. Terminating...')
      self.terminate_all(sig=signal.SIGTERM)
      # Wait for the processes to terminate by themselves first, so they have a
      # chance to dump stacktraces. After _FORCE_KILL_WAIT_SEC, we SIGKILL them.
      self._watchdog_thread.join(_FORCE_KILL_WAIT_SEC)
      if self._watchdog_thread.is_alive():
        logging.error('Timeout when waiting for child processes to '
                      'print stacktrace. Sending SIGKILL...')
        self.terminate_all()
        self._watchdog_thread.join()
      process_statuses = self._get_process_statuses()
      self._reraise_if_subprocess_error(process_statuses)
      raise SubprocessTimeoutError(
          'One or more subprocesses timed out, where timeout was set to {}s. '
          'Please change the `timeout` argument for '
          '`MultiProcessRunner.join()` or `multi_process_runner.run()` '
          'if it should be adjusted.'.format(timeout),
          self._get_mpr_result(process_statuses))

    for (task_type, task_id), p in self._processes.items():
      logging.info('%s-%d exit code: %s', task_type, task_id, p.exitcode)

    process_statuses = self._get_process_statuses()
    self._reraise_if_subprocess_error(process_statuses)

    # Check that all the processes that are expected to exit properly did so.
    for (task_type, task_id), p in self._processes.items():
      # A successfully exiting process has exit code 0. We ignore processes
      # that were terminated.
      assert p.exitcode is not None
      if (p.exitcode > 0 and (task_type, task_id) not in self._terminated):
        raise UnexpectedSubprocessExitError(
            'Subprocess %s-%d exited with exit code %s. See logs for details.'
            % (task_type, task_id, p.exitcode),
            self._get_mpr_result(process_statuses))

    logging.info('Joining log reading threads.')
    for thread in self._reading_threads:
      thread.join()
    logging.info('Joined log reading threads.')

    # Clear the alarm.
    signal.alarm(0)

    return self._get_mpr_result(process_statuses)

  def _get_mpr_result(self, process_statuses):
    stdout = self._queue_to_list(self._streaming_queue)
    return_values = []
    for process_status in process_statuses.values():
      if process_status.return_value is not None:
        return_values.append(process_status.return_value)
    return MultiProcessRunnerResult(stdout=stdout, return_value=return_values)

  def terminate(self, task_type, task_id):
    """Terminates the process with `task_type` and `task_id`.

    If auto_restart=True, the terminated task will be restarted unless the
    chief has already exited with zero exit code.

    Args:
      task_type: the task type.
      task_id: the task id.
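
    Example of simulating a worker failure (a sketch; assumes a runner `mpr`
    constructed with `auto_restart=True` and already started):

    ```python
    mpr.terminate('worker', 0)
    # With auto_restart=True, the watchdog restarts worker-0, so the cluster
    # is expected to recover and `join()` to eventually succeed.
    mpr.join()
    ```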
    """
    with self._process_lock:
      p = self._processes.get((task_type, task_id), None)
      if p is None:
        raise ValueError('{}-{} does not exist'.format(task_type, task_id))
      self._terminated.add((task_type, task_id))
      # TODO(crccw): change to use Process.terminate() as well.
      self._parent_to_sub_queue.put('terminate {} {}'.format(
          task_type, task_id))
      p.join()

  def _terminate_all(self, sig=None):
    """Terminates all subprocesses.

    The caller is required to hold self._process_lock.

    Args:
      sig: the signal used to terminate the process. The default is SIGKILL.
    """

    # Use SIGKILL as default. In systems where that's unavailable, such as
    # Windows, use SIGTERM.
    sig = sig or getattr(signal, 'SIGKILL', signal.SIGTERM)
    for (task_type, task_id), p in self._processes.items():
      if p.exitcode is not None:
        logging.info('%s-%d has already exited. Not terminating.', task_type,
                     task_id)
        continue
      try:
        os.kill(p.pid, sig)
        self._terminated.add((task_type, task_id))
        logging.info('%s-%d terminated with signal %r.', task_type, task_id,
                     sig)
      except ProcessLookupError:
        logging.info('Attempting to kill %s-%d but it does not exist.',
                     task_type, task_id)

  def terminate_all(self, sig=None):
    """Terminates all subprocesses."""
    with self._process_lock:
      self._terminate_all(sig)


class _Process(multi_process_lib.Process):
  """A modified `multiprocessing.Process` that can set up environment variables."""

  # TODO(crccw): consider moving other logic in _ProcFunc to _Process.

  def __init__(self, test_env, **kwargs):
    super(_Process, self).__init__(**kwargs)
    self._test_env = test_env
    self._actual_run = getattr(self, 'run')
    self.run = self._run_with_setenv

  def _run_with_setenv(self):
    # We need to set environment variables before doing anything because
    # setenv() is not thread-safe.
    test_env = self._test_env
    if test_env.grpc_fail_fast is not None:
      os.environ['GRPC_FAIL_FAST'] = str(test_env.grpc_fail_fast)
    if test_env.visible_gpus:
      os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
          [str(i) for i in test_env.visible_gpus])
    _set_tf_config(test_env.task_type, test_env.task_id, test_env.cluster_spec,
                   test_env.rpc_layer)
    return self._actual_run()


class _ProcFunc(object):
  """Represents a callable to run in a subprocess."""

  @contextlib.contextmanager
  def _runtime_mode(self, executing_eagerly):
    if executing_eagerly:
      with context.eager_mode():
        yield
    else:
      with context.graph_mode():
        yield

  def _message_checking_func(self, task_type, task_id):
    """A function that regularly checks messages from the parent process."""
    # TODO(rchao): Remove this once parent uses SIGKILL to terminate subprocess.
    while True:
      try:
        message = self._resources.parent_to_sub_queue.get(block=False)

        # Currently the only possible message is termination.
        if not message.startswith('terminate'):
          raise ValueError('Unrecognized message: {}'.format(message))

        if message == 'terminate {} {}'.format(task_type, task_id):
          break
        else:
          # If the message is not targeting this process, put it back into the
          # queue.
          self._resources.parent_to_sub_queue.put(message)
          time.sleep(1)
      except Queue.Empty:
        time.sleep(0.1)
    self._resources.process_status_queue.put(
        _ProcessStatusInfo(
            task_type=task_type,
            task_id=task_id,
            is_successful=True,
            exc_info=None,
            return_value=None))
    # `os._exit(1)` is used to more reliably terminate a subprocess.
    os._exit(1)  # pylint: disable=protected-access

  def _close_streaming(self):
    """Closes stdout, stderr and the streaming pipe.

    We need to close them explicitly so that the reading threads in the main
    process can exit promptly, since TensorFlow itself may take a while to
    exit.
    """
    sys.stdout.flush()
    sys.stderr.flush()
    sys.stdout.close()
    sys.stderr.close()
    self._resources.streaming_pipe_w.close()

  def __call__(self, resources, test_env, fn, args, kwargs, use_dill_for_args):
    """The wrapper function that actually gets run in child process(es)."""

    global _barrier

    self._resources = resources
    _barrier = self._resources.barrier
    fn = dill.loads(fn)
    if use_dill_for_args:
      args = dill.loads(args)
      kwargs = dill.loads(kwargs)

    if faulthandler is not None:
      faulthandler.enable()
      faulthandler.register(signal.SIGTERM, chain=True)

    # All logging should go to stderr to be streamed to the main process.
    logging.set_stderrthreshold(logging.DEBUG)

    # Assign sys.stdout and sys.stderr as duplicates of `streaming_pipe_w` so
    # print() and logging.*() write directly to `streaming_pipe_w`.
    # Unfortunately since we cannot prepend task_type and task_id information
    # to the streamed logs, we need a thread per subprocess to distinguish
    # which subprocess each message comes from.
    os.dup2(resources.streaming_pipe_w.fileno(), sys.stdout.fileno())
    os.dup2(resources.streaming_pipe_w.fileno(), sys.stderr.fileno())

    pid = os.getpid()
    logging.info('Subprocess with PID %d (%s, %d) is now being started.', pid,
                 test_env.task_type, test_env.task_id)
    logging.info('TF_CONFIG: %r', os.environ['TF_CONFIG'])

    # This thread is dedicated to checking messages from the parent process.
    threading.Thread(  # pylint: disable=unexpected-keyword-arg
        target=self._message_checking_func,
        args=(test_env.task_type, test_env.task_id),
        daemon=True).start()

    if test_env.v2_enabled:
      v2_compat.enable_v2_behavior()

    with self._runtime_mode(test_env.executing_eagerly):
      info = _run_contained(test_env.task_type, test_env.task_id, fn, args,
                            kwargs)
      self._resources.process_status_queue.put(info)

      # Re-raise the exception in addition to reporting it to the parent
      # process, so that even if the `--test_timeout` flag is set and the
      # error doesn't make it to be shown in the parent process before bazel's
      # timeout, the log would still show what happens in this subprocess,
      # instead of silently suppressing the error due to early bazel
      # timeout. Raising an error in the subprocess produces a stack trace in
      # the log, but the program continues running.
      if not info.is_successful:
        six.reraise(*info.exc_info)

      self._close_streaming()

    # Exit with code 0 as it's considered successful exit at this point.
    sys.exit(0)


# Active MultiProcessPoolRunners. We need to shut them down when the program
# exits, and this is done by setting the `tearDownModule` of the module
# containing `__main__`. Note this is set in both the parent process and the
# subprocesses.
_active_pool_runners = weakref.WeakSet()


def _shutdown_all_pool_runners():
  for pool in _active_pool_runners:
    pool.shutdown()


def is_oss():
  """Returns whether the test is run under OSS."""
  return len(sys.argv) >= 1 and 'bazel' in sys.argv[0]


class MultiProcessPoolRunner(object):
  """A utility class to start a process pool to simulate a cluster.

  It's similar to MultiProcessRunner, but uses a pool of processes to avoid the
  expensive initialization cost of TensorFlow.
  """

  def __init__(self, cluster_spec, initializer=None, share_gpu=True):
    """Creates a multi-process pool runner.

    Args:
      cluster_spec: Dict for cluster spec. The following is an example of a
        cluster with three workers.
        {"worker": ["worker0.example.com:2222",
                    "worker1.example.com:2222",
                    "worker2.example.com:2222"]}
      initializer: a callable to be called at the startup of worker processes.
      share_gpu: Whether to share GPUs among workers. If False, each worker is
        assigned different GPUs in a round-robin fashion.

    Raises:
      RuntimeError: if `multi_process_runner.test_main()` is not called.
      ValueError: if there are more than one chief in the `cluster_spec`.
    """
    _active_pool_runners.add(self)
    self._cluster_spec = cluster_spec
    self._initializer = initializer
    self._share_gpu = share_gpu
    self._conn = {}
    self._runner = None

  def __del__(self):
    self.shutdown()

  def shutdown(self):
    """Shuts down the worker pool."""
    for conn in self._conn.values():
      conn.close()
    self._conn = {}
    if self._runner is not None:
      try:
        self._runner.join()
      except Exception as e:  # pylint: disable=broad-except
        logging.error(
            'Ignoring exception when shutting down MultiProcessPoolRunner: %s',
            e)
      self._runner = None

  def _start(self):
    """Starts the worker pool."""
    # We need different arguments for different processes, so we're passing a
    # no-op fn here and use start_single_process instead.

    if dill is None:
      raise unittest.SkipTest(
          'TODO(b/150264776): Resolve dependency issue in CI')

    self._runner = MultiProcessRunner(
        fn=lambda: None,
        cluster_spec=self._cluster_spec,
        use_dill_for_args=False,
        share_gpu=self._share_gpu)
    if self._initializer:
      initializer = dill.dumps(self._initializer, dill.HIGHEST_PROTOCOL)
    else:
      initializer = None
    for task_type, addresses in self._cluster_spec.items():
      for task_id, _ in enumerate(addresses):
        conn1, conn2 = multiprocessing.Pipe(duplex=True)
        self._conn[(task_type, task_id)] = conn1
        self._runner.start_single_process(
            task_type,
            task_id,
            fn=_pool_runner_worker,
            args=(task_type, task_id, initializer, conn2))

  def run(self, fn, args=None, kwargs=None):
    """Runs `fn` with `args` and `kwargs` on all jobs.

    Args:
      fn: The function to be run.
      args: Optional positional arguments to be supplied to `fn`.
      kwargs: Optional keyword arguments to be supplied to `fn`.

    Returns:
      A list of return values.
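
    Example (a sketch following the other examples in this module; `fn` and
    the cluster layout are placeholders):

    ```python
    pool_runner = multi_process_runner.MultiProcessPoolRunner(
        multi_worker_test_base.create_cluster_spec(num_workers=2))

    def fn():
      return tf.distribute.cluster_resolver.TFConfigClusterResolver().task_id

    # Worker processes are started lazily on the first `run` call and are
    # reused by subsequent calls, avoiding TensorFlow startup cost each time.
    task_ids = pool_runner.run(fn)
    ```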
    """
    _check_initialization()
    # TODO(b/150264776): skip in OSS until it's implemented.
    multi_process_lib.Process()
    if self._runner is None:
      self._start()

    fn = dill.dumps(fn, dill.HIGHEST_PROTOCOL)
    for conn in self._conn.values():
      conn.send((fn, args or [], kwargs or {}))

    process_statuses = []
    for (task_type, task_id), conn in self._conn.items():
      logging.info('Waiting for the result from %s-%d', task_type, task_id)
      try:
        process_statuses.append(conn.recv())
      except EOFError:
        # This shouldn't happen due to exceptions in fn. This usually
        # means bugs in the runner.
        self.shutdown()
        raise RuntimeError('Unexpected EOF. Worker process may have died. '
                           'Please report a bug')

    return_values = []
    for process_status in process_statuses:
      assert isinstance(process_status, _ProcessStatusInfo)
      if not process_status.is_successful:
        six.reraise(*process_status.exc_info)
      if process_status.return_value is not None:
        return_values.append(process_status.return_value)

    return return_values


def _pool_runner_worker(task_type, task_id, initializer, conn):
  """Function that runs on the workers in a pool.

  It listens for callables to run and returns their results until `conn` is
  closed. It captures exceptions raised while executing the callable and
  returns them through `conn`.

  Args:
    task_type: the task type.
    task_id: the task index.
    initializer: a callable to execute during startup.
    conn: a multiprocessing.Connection object to listen for tasks and send
      results.
  """
  if initializer:
    initializer = dill.loads(initializer)
    initializer()
  while True:
    try:
      fn, args, kwargs = conn.recv()
    except EOFError:
      break
    fn = dill.loads(fn)
    info = _run_contained(task_type, task_id, fn, args, kwargs)
    sys.stdout.flush()
    sys.stderr.flush()
    conn.send(info)


def _run_contained(task_type, task_id, fn, args, kwargs):
  """Runs `fn` with `args` and `kwargs`.

  The function returns a `_ProcessStatusInfo` which captures the return value
  and the exception.

  Args:
    task_type: the task type.
    task_id: the task index.
    fn: the function to be run.
    args: optional positional arguments to be supplied to `fn`.
    kwargs: optional keyword arguments to be supplied to `fn`.

  Returns:
    a _ProcessStatusInfo.

  """
  is_successful = False
  return_value = None
  exc_info = None
  try:
    return_value = fn(*args, **kwargs)
    is_successful = True
    return _ProcessStatusInfo(
        task_type=task_type,
        task_id=task_id,
        is_successful=is_successful,
        exc_info=exc_info,
        return_value=return_value)

  # If `fn` ends up exiting with `sys.exit()`, the `SystemExit` is not
  # handled here.
  except Exception:  # pylint: disable=broad-except
    exc_info = sys.exc_info()
    return _ProcessStatusInfo(
        task_type=task_type,
        task_id=task_id,
        is_successful=is_successful,
        exc_info=exc_info,
        return_value=return_value)


@tf_export('__internal__.distribute.multi_process_runner'
           '.SubprocessTimeoutError',
           v1=[])
class SubprocessTimeoutError(RuntimeError):
  """An error that indicates there is at least one subprocess timing out.

  When this is raised, a namedtuple object representing the multi-process run
  result can be retrieved by
  `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError`'s
  `mpr_result` attribute. See
  `tf.__internal__.distribute.multi_process_runner.run` for more information.
  """

  def __init__(self, msg, mpr_result):
    super(SubprocessTimeoutError, self).__init__(msg)
    self.mpr_result = mpr_result


@tf_export('__internal__.distribute.multi_process_runner'
           '.UnexpectedSubprocessExitError',
           v1=[])
class UnexpectedSubprocessExitError(RuntimeError):
  """An error indicating there is at least one subprocess with unexpected exit.

  When this is raised, a namedtuple object representing the multi-process run
  result can be retrieved by
  `tf.__internal__.distribute.multi_process_runner
  .UnexpectedSubprocessExitError`'s
  `mpr_result` attribute. See
  `tf.__internal__.distribute.multi_process_runner.run` for more information.
  """

  def __init__(self, msg, mpr_result):
    super(UnexpectedSubprocessExitError, self).__init__(msg)
    self.mpr_result = mpr_result


@tf_export(
    '__internal__.distribute.multi_process_runner.NotInitializedError', v1=[])
class NotInitializedError(RuntimeError):
  """An error indicating `multi_process_runner.run` is used without init.

  When this is raised, the user is supposed to call
  `tf.__internal__.distribute.multi_process_runner.test_main()` within
  `if __name__ == '__main__':` block to properly initialize
  `multi_process_runner.run`.
  """
  pass


def _check_initialization():
  if not multi_process_lib.initialized():
    raise NotInitializedError(
        '`multi_process_runner` is not initialized. '
        'Please call `tf.__internal__.distribute.multi_process_runner.'
        'test_main()` within `if __name__ == \'__main__\':` block '
        'in your python module to properly initialize '
        '`multi_process_runner`.')


def _set_tf_config(task_type, task_id, cluster_spec, rpc_layer=None):
  """Set TF_CONFIG environment variable."""
  tf_config_dict = {
      'cluster': cluster_spec,
      'task': {
          'type': task_type,
          'index': task_id,
      },
  }
  if rpc_layer is not None:
    tf_config_dict['rpc_layer'] = rpc_layer
  os.environ['TF_CONFIG'] = json.dumps(tf_config_dict)


@tf_export('__internal__.distribute.multi_process_runner.run', v1=[])
def run(fn,
        cluster_spec,
        rpc_layer=None,
        max_run_time=None,
        return_output=False,
        timeout=_DEFAULT_TIMEOUT_SEC,
        args=None,
        kwargs=None):
  """Run `fn` in multiple processes according to `cluster_spec`.

  Given a callable `fn`, `tf.__internal__.distribute.multi_process_runner.run`
  launches multiple processes, each of which runs `fn`. These processes are
  referred to as "subprocesses" or "child processes". Each of those
  subprocesses will have their `TF_CONFIG` environment variable set according
  to `cluster_spec` and their task types. The stdout of the subprocesses is
  streamed to the main process and is thus available in the logs (if
  `stream_output` is True), with a [type-id] prefix.

  `tf.__internal__.distribute.multi_process_runner.run` will block until all
  subprocesses have successfully exited, and return a namedtuple object that
  represents the run result.
  This object has a `return_value` attribute, which is a list containing the
  return values of `fn` from the subprocesses, for those subprocesses that
  successfully returned from `fn`. The order of the `return_value` list is not
  meaningful. If an optional arg `return_output` (defaults to False) is set to
  True, the namedtuple object will have an additional attribute `stdout`,
  which is a list containing the stdout of the subprocesses. If any
  subprocess' `fn` ends up raising an error, that error will be reraised from
  `tf.__internal__.distribute.multi_process_runner.run`, and the aforementioned
  namedtuple object will be available through the exception's
  `mpr_result` attribute.

  This utility is used for simulating running TensorFlow programs across
  multiple task types, and each of the task types may contain more than one
  task (except for "chief" where more than one task is prohibited). Test
  coverage of multi-worker training is the main application of this utility,
  where code written for multi-worker training can be realistically covered in
  unit tests.

  Any test module that uses
  `tf.__internal__.distribute.multi_process_runner.run()` must call
  `tf.__internal__.distribute.multi_process_runner.test_main()` instead of
  regular `test.main()` inside `if __name__ == '__main__':` block for proper
  initialization.

  Args:
    fn: Function to be run on child processes. This will be run on processes
      for all task types.
    cluster_spec: Dict for cluster spec. The utility function
      `tf.__internal__.distribute.multi_process_runner.create_cluster_spec` can
      be conveniently used to create such dict. The following is an example of
      a cluster with three workers and two ps's.
      {"worker": ["worker0.example.com:2222",
                  "worker1.example.com:2222",
                  "worker2.example.com:2222"],
       "ps": ["ps0.example.com:2222",
              "ps1.example.com:2222"]}
    rpc_layer: RPC layer to use. Default value is 'grpc'.
    max_run_time: `None` or integer. If not `None`, child processes are forced
      to exit at approximately this many seconds after this utility is called.
      We achieve this through the `signal.alarm()` API. Note that this is best
      effort at the Python level, since the Python signal handler does not get
      executed while lower-level C/C++ code is running, so it can be delayed
      for an arbitrarily long time. If any of the child processes are still
      running when `max_run_time` is up, they will be force-terminated and a
      `tf.__internal__.distribute.multi_process_runner
      .UnexpectedSubprocessExitError`
      may be raised. If `None`, child processes are not forced to exit.
    return_output: If True, the output/error from the subprocesses should be
      collected to be attached to the resulting namedtuple returned from this
      utility. The list of output can be retrieved via the `stdout` attribute.
      Defaults to False.
    timeout: optional integer or `None`. If provided as an integer, and not all
      processes report status within roughly `timeout` seconds, a
      `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError`
      exception will be raised. If `None`,
      `tf.__internal__.distribute.multi_process_runner.run` never times out.
      Defaults to the constant `_DEFAULT_TIMEOUT_SEC` defined in the
      `multi_process_runner` module.
    args: Positional arguments to be sent to `fn` run on subprocesses.
    kwargs: Keyword arguments to be sent to `fn` run on subprocesses.

  Returns:
    A namedtuple object, which has two attributes,
    `return_value` and `stdout`. `return_value` always contains a list of
    return values from the subprocesses, although the order is not meaningful.
    If the `return_output` argument is True, `stdout` is available, containing
    a list of all messages from the subprocesses' stdout and stderr, and the
    order is mostly chronological.

  Raises:
    RuntimeError: if
      `tf.__internal__.distribute.multi_process_runner.test_main()` is
      not called in test's `if __name__ == '__main__':` block.
    ValueError: if there are more than one chief in the `cluster_spec`.
    tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError: if
      not all processes report status approximately
      within `timeout` seconds. When this is raised, a
      namedtuple object can be retrieved by
      `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError`'s
      `mpr_result` attribute, which has the same
      structure as the above 'Returns' section describes.
    tf.__internal__.distribute.multi_process_runner
    .UnexpectedSubprocessExitError:
      If any of the subprocesses did not exit
      properly (for example, they exit on SIGTERM or SIGKILL signal). When
      this is raised, a namedtuple object can be retrieved by
      `tf.__internal__.distribute.multi_process_runner
      .UnexpectedSubprocessExitError`'s
      `mpr_result` attribute, which has the
      same structure as the above 'Returns' section describes. If
      `max_run_time` is not `None`, it is expected that some subprocesses may
      be force-killed when `max_run_time` is up, and this is raised in those
      cases.
    Exception: if there is an Exception propagated from any subprocess. When
      this is raised, a namedtuple object can be retrieved by the raised
      exception's `mpr_result` attribute, which has the
      same structure as the above 'Returns' section describes.

  Examples:

  ```python
  class SimpleMultiProcessTest(tf.test.TestCase):

    def test_simple_printing_and_return(self):

      def fn():
        resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

        # This will print "[chief-0]: Task type: chief , task id: 0"
        # for chief, for example.
        logging.info('Task type: %s, task id: %d',
                     resolver.task_type, resolver.task_id)

        return resolver.task_type

      result = tf.__internal__.distribute.multi_process_runner.run(
          fn=fn,
          cluster_spec=(
              tf.__internal__
              .distribute.multi_process_runner.create_cluster_spec(
                  has_chief=True, num_workers=2)))
      assert sorted(result.return_value) == ['chief', 'worker', 'worker']

    def test_error_from_fn(self):

      def fn():
        resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
        raise ValueError('Task type {}, task id {} errored out'.format(
            resolver.task_type, resolver.task_id))

      with self.assertRaisesRegex(ValueError,
                                  'Task type worker, task id 0 errored out'):
        cluster_spec = (
            tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
                num_workers=1))
        tf.__internal__.distribute.multi_process_runner.run(
            fn=fn, cluster_spec=cluster_spec)


  if __name__ == '__main__':
    tf.__internal__.distribute.multi_process_runner.test_main()
  ```
  """
  runner = MultiProcessRunner(
      fn,
      cluster_spec,
      rpc_layer,
      max_run_time=max_run_time,
      return_output=return_output,
      args=args,
      kwargs=kwargs)
  runner.start()
  return runner.join(timeout)


# This is set by MultiProcessRunner in worker processes.
_barrier = None


@tf_export('__internal__.distribute.multi_process_runner.get_barrier', v1=[])
def get_barrier():
  """Returns a `multiprocessing.Barrier` for `multi_process_runner.run`.

  `tf.__internal__.distribute.multi_process_runner.get_barrier()` returns
  a `multiprocessing.Barrier` object which can be used within `fn` of
  `tf.__internal__.distribute.multi_process_runner` to wait with
  `barrier.wait()` call until all other tasks have also reached the
  `barrier.wait()` call, before they can proceed individually.

  Note that all tasks (subprocesses) have to reach the `barrier.wait()` call to
  proceed. Currently it is not supported to block on only a subset of tasks
  in the cluster.

  Example:
  ```python

  def fn():
    some_work_to_be_done_by_all_tasks()

    tf.__internal__.distribute.multi_process_runner.get_barrier().wait()

    # The barrier guarantees that at this point, all tasks have finished
    # `some_work_to_be_done_by_all_tasks()`
    some_other_work_to_be_done_by_all_tasks()

  result = tf.__internal__.distribute.multi_process_runner.run(
      fn=fn,
      cluster_spec=(
          tf.__internal__
          .distribute.multi_process_runner.create_cluster_spec(
              num_workers=2)))
  ```


  Returns:
    A `multiprocessing.Barrier` for `multi_process_runner.run`.
  """
  if _barrier is None:
    raise ValueError(
        'barrier is not defined. It is likely because you are calling '
        'get_barrier() in the main process. get_barrier() can only be called '
        'in the subprocesses.'
    )
  return _barrier


_manager = None
_manager_lock = threading.Lock()


def manager():
  """Returns the multiprocessing manager object for concurrency tools.

  The manager object is useful as it controls a server process that holds
  the Python objects that can be shared across processes.
  This can be used for parent-subprocess communication:

  ```python
  manager = multi_process_runner.manager()
  some_event_happening_in_subprocess = manager.Event()
  mpr = multi_process_runner.MultiProcessRunner(
      fn, cluster_spec, args=(some_event_happening_in_subprocess,))
  mpr.start()
  some_event_happening_in_subprocess.wait()
  # Do something that should only happen after some event happens in the
  # subprocess.
  ```

  Note that the user of multi_process_runner should not create additional
  `multiprocessing.Manager()` objects; doing so can result in a segfault in
  some cases.

  This method should only be called after multi_process_runner.test_main() is
  called.
  """
  _check_initialization()
  global _manager
  with _manager_lock:
    if _manager is None:
      _manager = multiprocessing.Manager()
    return _manager


@tf_export('__internal__.distribute.multi_process_runner.test_main', v1=[])
def test_main():
  """Main function to be called within `__main__` of a test file.

  Any test module that uses
  `tf.__internal__.distribute.multi_process_runner.run()`
  must call this instead of regular `test.main()` inside
  `if __name__ == '__main__':` block, or an error will be raised when
  `tf.__internal__.distribute.multi_process_runner.run()` is used. This method
  takes care of the initialization needed to launch multiple subprocesses.

  Example:
  ```python
  class MyTestClass(tf.test.TestCase):
    def testSomething(self):
      # Testing code making use of
      # `tf.__internal__.distribute.multi_process_runner.run()`.

  if __name__ == '__main__':
    tf.__internal__.distribute.multi_process_runner.test_main()
  ```
  """
  # Inject tearDownModule() to shut down all pool runners. Active pool runners
  # will block the program from exiting. This is necessary for global pool
  # runners. We tried atexit in the past, and it doesn't work in some
  # deployments.
  old_tear_down_module = getattr(sys.modules['__main__'], 'tearDownModule',
                                 None)

  def tear_down_module():
    _shutdown_all_pool_runners()
    if old_tear_down_module is not None:
      old_tear_down_module()

  setattr(sys.modules['__main__'], 'tearDownModule', tear_down_module)
  multi_process_lib.test_main()