1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Module for `ClusterCoordinator` and relevant cluster-worker related library. 16 17This is currently under development and the API is subject to change. 18""" 19 20import collections 21import contextlib 22import os 23import re 24import threading 25import time 26import weakref 27 28from six.moves import queue 29 30from tensorflow.python.distribute import parameter_server_strategy_v2 31from tensorflow.python.distribute.coordinator import coordinator_context 32from tensorflow.python.distribute.coordinator import metric_utils 33from tensorflow.python.distribute.coordinator import values as values_lib 34from tensorflow.python.distribute.coordinator import watchdog 35from tensorflow.python.eager import cancellation 36from tensorflow.python.eager import context 37from tensorflow.python.eager import def_function 38from tensorflow.python.eager import executor 39from tensorflow.python.eager import function as tf_function 40from tensorflow.python.framework import errors 41from tensorflow.python.framework import func_graph 42from tensorflow.python.framework import ops 43from tensorflow.python.platform import tf_logging as logging 44from tensorflow.python.util import nest 45from tensorflow.python.util.tf_export import tf_export 46 47# Maximum time for failed worker to come back is 1 hour 48_WORKER_MAXIMUM_RECOVERY_SEC = 3600 49 50# Maximum size for queued closures, "infinite" if set to 0. 51# When the maximum queue size is reached, further schedule calls will become 52# blocking until some previously queued closures are executed on workers. 53# Note that using an "infinite" queue size can take a non-trivial portion of 54# memory, and even lead to coordinator OOM. Modify the size to a smaller value 55# for coordinator with constrained memory resource (only recommended for 56# advanced users). Also used in unit tests to ensure the correctness when the 57# queue is full. 58_CLOSURE_QUEUE_MAX_SIZE = 256 * 1024 59 60# RPC error message from PS 61_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps" 62 63# InvalidArgumentError (unknown device) will not have "GRPC error..." string. 64_JOB_WORKER_STRING_IDENTIFIER = "/job:worker" 65 66 67RemoteValueStatus = values_lib.RemoteValueStatus 68RemoteValue = values_lib.RemoteValue 69RemoteValueImpl = values_lib.RemoteValueImpl 70PerWorkerValues = values_lib.PerWorkerValues 71 72 73class ClosureInputError(Exception): 74 """Wrapper for errors from resource building. 75 76 When a closure starts, it first checks for errors in any of its inputs, which 77 are RemoteValues from resource closures. If there were any errors, it wraps 78 the exception in this class and raises so it can be handled by the worker 79 failure handler. 80 81 Attributes: 82 original_exception: 83 """ 84 85 def __init__(self, original_exception): 86 # Avoid doubly-nested errors 87 if isinstance(original_exception, 88 (ClosureInputError, ClosureAbortedError)): 89 self.original_exception = original_exception.original_exception 90 else: 91 self.original_exception = original_exception 92 message = ("Input has an error, the original exception is %r, " 93 "error message is %s." % 94 (self.original_exception, str(self.original_exception))) 95 super().__init__(message) 96 self.with_traceback(original_exception.__traceback__) 97 98 99class ClosureAbortedError(Exception): 100 """Wrapper for errors from training closures, to attach to resource closures. 101 102 This wrapper is used when a dependent training closure fails to set errors on 103 its required resource closures. 104 105 Attributes: 106 original_exception: The Exception to wrap 107 """ 108 109 def __init__(self, original_exception): 110 # Avoid doubly-nested errors 111 if isinstance(original_exception, 112 (ClosureInputError, ClosureAbortedError)): 113 self.original_exception = original_exception.original_exception 114 else: 115 self.original_exception = original_exception 116 message = ("Other function has an execution error, as a result, the " 117 "current value is not available. The original exception is %r, " 118 "error message is %s." % 119 (self.original_exception, str(self.original_exception))) 120 super().__init__(message) 121 self.with_traceback(original_exception.__traceback__) 122 123 124def _get_error_from_remote_values(structure): 125 """Attempts to return errors from `RemoteValue`s. Rebuilds them if needed.""" 126 errors_in_structure = [] 127 128 def _get_error(val): 129 if isinstance(val, RemoteValue): 130 error = val._get_error() # pylint: disable=protected-access 131 if error: 132 errors_in_structure.append(error) 133 134 nest.map_structure(_get_error, structure) 135 if errors_in_structure: 136 return errors_in_structure[0] 137 else: 138 return None 139 140 141def _maybe_get_remote_value(val): 142 """Gets the value of `val` if it is a `RemoteValue`.""" 143 if isinstance(val, RemoteValue): 144 error = val._get_error() # pylint: disable=protected-access 145 if error: 146 raise AssertionError( 147 "RemoteValue doesn't have a value because it has error %r:%s" % 148 (error, error)) 149 elif val._status is not RemoteValueStatus.READY: # pylint: disable=protected-access 150 raise AssertionError("The input RemoteValue has not been executed.") 151 else: 152 return val._get_values() # pylint: disable=protected-access 153 else: 154 return val 155 156 157def _maybe_as_type_spec(val): 158 if isinstance(val, (RemoteValue, PerWorkerValues)): 159 if val._type_spec is None: # pylint: disable=protected-access 160 raise ValueError("Output of a scheduled function that is not " 161 "tf.function cannot be the input of another function.") 162 return val._type_spec # pylint: disable=protected-access 163 else: 164 return val 165 166 167def _select_worker_slice(worker_id, structured): 168 """Selects the worker slice of each of the items in `structured`.""" 169 170 def _get(x): 171 return x._values[worker_id] if isinstance(x, PerWorkerValues) else x # pylint: disable=protected-access 172 173 return nest.map_structure(_get, structured) 174 175 176def _disallow_remote_value_as_input(structured): 177 """Raises if any element of `structured` is a RemoteValue.""" 178 179 def _raise_if_remote_value(x): 180 if isinstance(x, RemoteValue): 181 raise ValueError( 182 "`tf.distribute.experimental.coordinator.RemoteValue` used " 183 "as an input to scheduled function is not yet " 184 "supported.") 185 186 nest.map_structure(_raise_if_remote_value, structured) 187 188 189class Closure(object): 190 """Hold a function to be scheduled and its arguments.""" 191 192 def __init__(self, function, cancellation_mgr, args=None, kwargs=None): 193 if not callable(function): 194 raise ValueError("Function passed to `ClusterCoordinator.schedule` must " 195 "be a callable object.") 196 self._args = args or () 197 self._kwargs = kwargs or {} 198 199 _disallow_remote_value_as_input(self._args) 200 _disallow_remote_value_as_input(self._kwargs) 201 202 if isinstance(function, def_function.Function): 203 replica_args = _select_worker_slice(0, self._args) 204 replica_kwargs = _select_worker_slice(0, self._kwargs) 205 206 # Note: no need to handle function registration failure since this kind of 207 # failure will not raise exceptions as designed in the runtime. The 208 # coordinator has to rely on subsequent operations that raise to catch 209 # function registration failure. 210 211 # Record the function tracing overhead. Note that we pass in the tracing 212 # count of the def_function.Function as a state tracker, so that metrics 213 # will only record the time for actual function tracing (i.e., excluding 214 # function cache lookups). 215 with metric_utils.monitored_timer( 216 "function_tracing", state_tracker=function._get_tracing_count): # pylint: disable=protected-access 217 self._concrete_function = function.get_concrete_function( 218 *nest.map_structure(_maybe_as_type_spec, replica_args), 219 **nest.map_structure(_maybe_as_type_spec, replica_kwargs)) 220 elif isinstance(function, tf_function.ConcreteFunction): 221 self._concrete_function = function 222 223 if hasattr(self, "_concrete_function"): 224 # If we have a concrete function, we get to retrieve the output type spec 225 # via the structured_output. 226 self._output_type_spec = func_graph.convert_structure_to_signature( 227 self._concrete_function.structured_outputs) 228 self._function = cancellation_mgr.get_cancelable_function( 229 self._concrete_function) 230 else: 231 # Otherwise (i.e. what is passed in is a regular python function), we have 232 # no such information. 233 self._output_type_spec = None 234 self._function = function 235 236 self._output_remote_value_ref = None 237 238 def build_output_remote_value(self): 239 if self._output_remote_value_ref is None: 240 ret = RemoteValueImpl(None, self._output_type_spec) 241 self._output_remote_value_ref = weakref.ref(ret) 242 return ret 243 else: 244 raise ValueError( 245 "The output of the Closure cannot be built more than once.") 246 247 def maybe_call_with_output_remote_value(self, method): 248 if self._output_remote_value_ref is None: 249 return None 250 output_remote_value = self._output_remote_value_ref() 251 if output_remote_value is not None: 252 return method(output_remote_value) 253 return None 254 255 def mark_cancelled(self): 256 e = errors.CancelledError( 257 None, None, "The corresponding function is " 258 "cancelled. Please reschedule the function.") 259 self.maybe_call_with_output_remote_value(lambda r: r._set_error(e)) # pylint: disable=protected-access 260 261 def execute_on(self, worker): 262 """Executes the closure on the given worker. 263 264 Args: 265 worker: a `Worker` object. 266 """ 267 replica_args = _select_worker_slice(worker.worker_index, self._args) 268 replica_kwargs = _select_worker_slice(worker.worker_index, self._kwargs) 269 270 e = ( 271 _get_error_from_remote_values(replica_args) or 272 _get_error_from_remote_values(replica_kwargs)) 273 if e: 274 if not isinstance(e, ClosureInputError): 275 e = ClosureInputError(e) 276 raise e 277 278 with ops.device(worker.device_name): 279 with context.executor_scope(worker.executor): 280 with coordinator_context.with_dispatch_context(worker): 281 with metric_utils.monitored_timer("closure_execution"): 282 output_values = self._function( 283 *nest.map_structure(_maybe_get_remote_value, replica_args), 284 **nest.map_structure(_maybe_get_remote_value, replica_kwargs)) 285 self.maybe_call_with_output_remote_value( 286 lambda r: r._set_values(output_values)) # pylint: disable=protected-access 287 288 289class ResourceClosure(Closure): 290 291 def build_output_remote_value(self): 292 if self._output_remote_value_ref is None: 293 # We need to remember the Closure object in the `RemoteValue` here. 294 ret = RemoteValueImpl(self, self._output_type_spec) 295 self._output_remote_value_ref = weakref.ref(ret) 296 return ret 297 else: 298 return self._output_remote_value_ref() 299 300 301class _CoordinatedClosureQueue(object): 302 """Manage a queue of closures, inflight count and errors from execution. 303 304 This class is thread-safe. 305 """ 306 307 def __init__(self): 308 # `self._inflight_closure_count` only tracks the number of inflight closures 309 # that are "in generation". Once an error occurs, error generation is 310 # incremented and all subsequent arriving closures (from inflight) are 311 # considered "out of generation". 312 self._inflight_closure_count = 0 313 314 self._queue_lock = threading.Lock() 315 316 # Condition indicating that all pending closures (either queued or inflight) 317 # have been processed, failed, or cancelled. 318 self._stop_waiting_condition = threading.Condition(self._queue_lock) 319 320 # Condition indicating that an item becomes available in queue (not empty). 321 self._closures_queued_condition = threading.Condition(self._queue_lock) 322 self._should_process_closures = True 323 324 # Condition indicating that a queue slot becomes available (not full). 325 # Note that even with "infinite" queue size, there is still a "practical" 326 # size limit for the queue depending on host memory capacity, and thus the 327 # queue will eventually become full with a lot of enqueued closures. 328 self._queue_free_slot_condition = threading.Condition(self._queue_lock) 329 330 # Condition indicating there is no inflight closures. 331 self._no_inflight_closure_condition = threading.Condition(self._queue_lock) 332 333 # Use to cancel in-flight closures. 334 self._cancellation_mgr = cancellation.CancellationManager() 335 336 if _CLOSURE_QUEUE_MAX_SIZE <= 0: 337 logging.warning( 338 "In a `ClusterCoordinator`, creating an infinite closure queue can " 339 "consume a significant amount of memory and even lead to OOM.") 340 self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE) 341 self._tagged_queue = collections.defaultdict(queue.Queue) 342 self._error = None 343 344 # The following is a lock to make sure when `wait` is called and before it 345 # returns no `put` can be executed during this period. It is because `wait` 346 # won't know what to do with newly put closures. This lock adds an cutoff 347 # for `wait` so that closures put into the queue while waiting would not be 348 # taken responsible by this `wait`. 349 # 350 # We cannot reuse the `self._queue_lock` since when `wait` waits for a 351 # condition, the `self._queue_lock` will be released. 352 # 353 # We don't use a reader/writer's lock on purpose to reduce the complexity 354 # of the code. 355 self._put_wait_lock = threading.Lock() 356 357 self._watchdog = watchdog.WatchDog(on_triggered=self._on_watchdog_timeout) 358 359 def _on_watchdog_timeout(self): 360 logging.info("inflight_closure_count is %d", self._inflight_closure_count) 361 logging.info("current error is %s:%r", self._error, self._error) 362 363 def stop(self): 364 with self._queue_lock: 365 self._should_process_closures = False 366 self._cancellation_mgr.start_cancel() 367 self._closures_queued_condition.notify_all() 368 self._watchdog.stop() 369 370 def _cancel_all_closures(self): 371 """Clears the queue and sets remaining closures cancelled error. 372 373 This method expects self._queue_lock to be held prior to entry. 374 """ 375 self._cancellation_mgr.start_cancel() 376 logging.info("Canceling all closures: waiting for inflight closures to " 377 "finish") 378 while self._inflight_closure_count > 0: 379 self._no_inflight_closure_condition.wait() 380 logging.info("Canceling all closures: canceling remaining closures on the " 381 "queue") 382 while True: 383 try: 384 closure = self._queue.get(block=False) 385 self._queue_free_slot_condition.notify() 386 closure.mark_cancelled() 387 except queue.Empty: 388 break 389 # The cancellation manager cannot be reused once cancelled. After all 390 # closures (queued or inflight) are cleaned up, recreate the cancellation 391 # manager with clean state. 392 # Note on thread-safety: this is triggered when one of theses 393 # ClusterCoordinator APIs are called: `schedule`, `wait`, and `done`. At the 394 # same time, no new closures can be constructed (which reads the 395 # _cancellation_mgr to get cancellable functions). 396 self._cancellation_mgr = cancellation.CancellationManager() 397 398 def _raise_if_error(self): 399 """Raises the error if one exists. 400 401 If an error exists, cancel the closures in queue, raises it, and clear 402 the error. 403 404 This method expects self._queue_lock to be held prior to entry. 405 """ 406 if self._error: 407 logging.error("Start cancelling closures due to error %r: %s", 408 self._error, self._error) 409 self._cancel_all_closures() 410 try: 411 raise self._error # pylint: disable=raising-bad-type 412 finally: 413 self._error = None 414 415 def put(self, closure, tag=None): 416 """Put a closure into the queue for later execution. 417 418 If `mark_failed` was called before `put`, the error from the first 419 invocation of `mark_failed` will be raised. 420 421 Args: 422 closure: The `Closure` to put into the queue. 423 tag: if not None, put into a queue with the given tag. 424 """ 425 closure.tag = tag 426 if tag is not None: 427 with self._queue_lock: 428 self._tagged_queue[tag].put(closure, block=False) 429 self._closures_queued_condition.notifyAll() 430 else: 431 with self._put_wait_lock, self._queue_lock: 432 self._queue_free_slot_condition.wait_for(lambda: not self._queue.full()) 433 self._queue.put(closure, block=False) 434 self._raise_if_error() 435 self._closures_queued_condition.notify() 436 437 def get(self, timeout=None, tag=None): 438 """Return a closure from the queue to be executed. 439 440 It will try to fetch an item from the queue with the given tag. If this 441 queue is empty, it will then check the global queue. 442 443 Args: 444 timeout: timeout when waiting for a closure to be put. 445 tag: optional tag to specify which queue to query first before querying 446 the global queue. 447 448 Returns: 449 a closure or None after timeout. 450 """ 451 with self._queue_lock: 452 while (self._should_process_closures and self._queue.empty() and 453 (tag is None or self._tagged_queue[tag].empty())): 454 if not self._closures_queued_condition.wait(timeout=timeout): 455 return None 456 if not self._should_process_closures: 457 return None 458 if tag is not None and not self._tagged_queue[tag].empty(): 459 closure = self._tagged_queue[tag].get(block=False) 460 return closure 461 closure = self._queue.get(block=False) 462 assert closure.tag is None 463 assert tag is None or self._tagged_queue[tag].empty() 464 self._queue_free_slot_condition.notify() 465 self._inflight_closure_count += 1 466 return closure 467 468 def mark_finished(self): 469 """Let the queue know that a closure has been successfully executed.""" 470 with self._queue_lock: 471 if self._inflight_closure_count < 1: 472 raise AssertionError("There is no inflight closures to mark_finished.") 473 self._inflight_closure_count -= 1 474 if self._inflight_closure_count == 0: 475 self._no_inflight_closure_condition.notify_all() 476 if self._queue.empty() and self._inflight_closure_count == 0: 477 self._stop_waiting_condition.notify_all() 478 self._watchdog.report_closure_done() 479 480 def put_back(self, closure): 481 """Put the closure back into the queue as it was not properly executed.""" 482 assert closure.tag is None 483 with self._queue_lock: 484 if self._inflight_closure_count < 1: 485 raise AssertionError("There is no inflight closures to put_back.") 486 if self._error: 487 closure.mark_cancelled() 488 else: 489 self._queue_free_slot_condition.wait_for(lambda: not self._queue.full()) 490 self._queue.put(closure, block=False) 491 self._closures_queued_condition.notify() 492 self._inflight_closure_count -= 1 493 if self._inflight_closure_count == 0: 494 self._no_inflight_closure_condition.notify_all() 495 496 def wait(self, timeout=None): 497 """Wait for all closures to be finished before returning. 498 499 If `mark_failed` was called before or during `wait`, the error from the 500 first invocation of `mark_failed` will be raised. 501 502 Args: 503 timeout: A float specifying a timeout for the wait in seconds. 504 505 Returns: 506 True unless the given timeout expired, in which case it returns False. 507 """ 508 with self._put_wait_lock, self._queue_lock: 509 logging.info("Waiting for all global closures to be finished.") 510 while (not self._error and 511 (not self._queue.empty() or self._inflight_closure_count > 0)): 512 if not self._stop_waiting_condition.wait(timeout=timeout): 513 return False 514 self._raise_if_error() 515 return True 516 517 def mark_failed(self, e): 518 """Sets error and unblocks any wait() call.""" 519 with self._queue_lock: 520 # TODO(yuefengz): maybe record all failure and give users more 521 # information? 522 if self._inflight_closure_count < 1: 523 raise AssertionError("There is no inflight closures to mark_failed.") 524 if self._error is None: 525 self._error = e 526 self._inflight_closure_count -= 1 527 if self._inflight_closure_count == 0: 528 self._no_inflight_closure_condition.notify_all() 529 self._stop_waiting_condition.notify_all() 530 531 def done(self): 532 """Returns true if the queue is empty and there is no inflight closure. 533 534 If `mark_failed` was called before `done`, the error from the first 535 invocation of `mark_failed` will be raised. 536 """ 537 with self._queue_lock: 538 self._raise_if_error() 539 return self._queue.empty() and self._inflight_closure_count == 0 540 541 def clear_tag_unlocked(self, tag): 542 self._tagged_queue[tag] = queue.Queue() 543 544 545class WorkerPreemptionHandler(object): 546 """Handles worker preemptions.""" 547 548 def __init__(self, server_def, cluster): 549 self._server_def = server_def 550 self._cluster = cluster 551 self._cluster_update_lock = threading.Lock() 552 self._cluster_due_for_update_or_finish = threading.Event() 553 self._worker_up_cond = threading.Condition(self._cluster_update_lock) 554 self._error_from_recovery = None 555 self._should_preemption_thread_run = True 556 self._preemption_handler_thread = threading.Thread( 557 target=self._preemption_handler, 558 name="WorkerPreemptionHandler", 559 daemon=True) 560 self._preemption_handler_thread.start() 561 562 def stop(self): 563 """Ensure the worker preemption thread is closed.""" 564 self._should_preemption_thread_run = False 565 with self._cluster_update_lock: 566 self._cluster_due_for_update_or_finish.set() 567 # TODO(yuefengz): The preemption handler thread shouldn't be terminated 568 # asynchronously since it touches eager context which is a process-wide 569 # singleton. The problem is in OSS unit tests will time out. 570 571 def _validate_preemption_failure(self, e): 572 """Validates that the given exception represents worker preemption.""" 573 574 # Only categorize the failure as a worker preemption if the cancellation 575 # manager did not attempt to cancel the blocking operations. 576 if _is_worker_failure(e) and ( 577 not self._cluster.closure_queue._cancellation_mgr.is_cancelled): # pylint: disable=protected-access 578 return 579 raise e 580 581 @contextlib.contextmanager 582 def wait_on_failure(self, 583 on_failure_fn=None, 584 on_transient_failure_fn=None, 585 on_recovery_fn=None, 586 worker_device_name="(unknown)"): 587 """Catches worker preemption error and wait until failed workers are back. 588 589 Args: 590 on_failure_fn: an optional function to run if preemption happens. 591 on_transient_failure_fn: an optional function to run if transient failure 592 happens. 593 on_recovery_fn: an optional function to run when a worker is recovered 594 from preemption. 595 worker_device_name: the device name of the worker instance that is passing 596 through the failure. 597 598 Yields: 599 None. 600 """ 601 assert self._should_preemption_thread_run 602 try: 603 yield 604 except (errors.OpError, ClosureInputError, 605 ClosureAbortedError) as e: 606 # If the error is due to temporary connectivity issues between worker and 607 # ps, put back closure, ignore error and do not mark worker as failure. 608 if self._cluster._record_and_ignore_transient_ps_failure(e): # pylint: disable=protected-access 609 logging.error( 610 "Remote function on worker %s failed with %r:%s\n" 611 "It is treated as a transient connectivity failure for now.", 612 worker_device_name, e, e) 613 if on_transient_failure_fn: 614 on_transient_failure_fn() 615 return 616 617 # If the error is due to temporary connectivity issues that cause the 618 # server-side RPCs to be cancelled, TF might not abort the step and the 619 # closure might timeout. The coordinator ignores certain amount of such 620 # failures without marking worker as failure. 621 if self._cluster._record_and_ignore_transient_timeouts(e): # pylint: disable=protected-access 622 logging.error( 623 "Remote function on worker %s failed with %r:%s\n" 624 "This derived error is ignored and not reported to users.", 625 worker_device_name, e, e) 626 if on_transient_failure_fn: 627 on_transient_failure_fn() 628 return 629 630 # Ignoring derived CancelledErrors to tolerate transient failures in 631 # PS-worker communication, which initially exposed as an UnavailableError 632 # and then lead to sub-function cancellation, subsequently getting 633 # reported from worker to chief as CancelledError. 634 # We do not mark either worker or PS as failed due to only CancelledError. 635 # If there are real (non-transient) failures, they must also be reported 636 # as other errors (UnavailableError most likely) in closure executions. 637 if isinstance(e, errors.CancelledError) and "/job:" in str(e): 638 logging.error( 639 "Remote function on worker %s failed with %r:%s\n" 640 "This derived error is ignored and not reported to users.", 641 worker_device_name, e, e) 642 if on_transient_failure_fn: 643 on_transient_failure_fn() 644 return 645 646 # This reraises the error, if it's not considered recoverable; otherwise, 647 # the following failure recovery logic run. At this time, only worker 648 # unavailability is recoverable. PS unavailability as well as other 649 # errors in the user function is not recoverable. 650 self._validate_preemption_failure(e) 651 652 logging.error("Worker %s failed with %r:%s", worker_device_name, e, e) 653 if on_failure_fn: 654 on_failure_fn(e) 655 656 with self._cluster_update_lock: 657 self._cluster_due_for_update_or_finish.set() 658 self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC) 659 if self._error_from_recovery: 660 # TODO(yuefengz): there is only one worker that will get this error. 661 # Ideally we shuold let all workers notified by `_worker_up_cond` get 662 # this error. 663 try: 664 raise self._error_from_recovery 665 finally: 666 self._error_from_recovery = None 667 logging.info("Worker %s has been recovered.", worker_device_name) 668 669 if on_recovery_fn: 670 logging.info("Worker %s calling on_recovery_fn", worker_device_name) 671 with self.wait_on_failure( 672 on_recovery_fn=on_recovery_fn, 673 on_transient_failure_fn=on_transient_failure_fn, 674 worker_device_name=worker_device_name): 675 on_recovery_fn() 676 677 def _preemption_handler(self): 678 """A loop that handles preemption. 679 680 This loop waits for signal of worker preemption and upon worker preemption, 681 it waits until all workers are back and updates the cluster about the 682 restarted workers. 683 """ 684 assert self._should_preemption_thread_run 685 while True: 686 self._cluster_due_for_update_or_finish.wait() 687 if not self._should_preemption_thread_run: 688 logging.info("Stopping the failure handing thread.") 689 break 690 691 with self._cluster_update_lock: 692 try: 693 # TODO(haoyuzhang): support partial cluster recovery 694 logging.info("Cluster now being recovered.") 695 context.context().update_server_def(self._server_def) 696 697 # Cluster updated successfully, clear the update signal, and notify 698 # all workers that they are recovered from failure. 699 logging.info("Cluster successfully recovered.") 700 self._worker_up_cond.notify_all() 701 # The check for _should_preemption_thread_run is necessary since the 702 # `stop` may have already set _cluster_due_for_update_or_finish. 703 if self._should_preemption_thread_run: 704 self._cluster_due_for_update_or_finish.clear() 705 except Exception as e: # pylint: disable=broad-except 706 logging.info("Error occurred while updating server def: %s", e) 707 try: 708 self._validate_preemption_failure(e) 709 except Exception as ps_e: # pylint: disable=broad-except 710 logging.info("Error that occurred while updating server def is not " 711 "a worker failure. So set it as _error_from_recovery") 712 # In this case, a parameter server fails. So we raise this error to 713 # the caller of `wait_on_failure`. 714 self._error_from_recovery = ps_e 715 self._worker_up_cond.notify_all() 716 if self._should_preemption_thread_run: 717 self._cluster_due_for_update_or_finish.clear() 718 # NOTE: Since the first RPC (GetStatus) of update_server_def is 719 # currently blocking by default, error should only happen if: 720 # (1) More workers failed while waiting for the previous workers to 721 # come back; 722 # (2) Worker failed when exchanging subsequent RPCs after the first 723 # RPC returns. 724 # Consider adding backoff retry logic if we see the error logged 725 # too frequently. 726 logging.error("Cluster update failed with error: %s. Retrying...", e) 727 728 729class Worker(object): 730 """A worker in a cluster. 731 732 Attributes: 733 worker_index: The index of the worker in the cluster. 734 device_name: The device string of the worker, e.g. "/job:worker/task:1". 735 executor: The worker's executor for remote function execution. 736 failure_handler: The failure handler used to handler worker preemption 737 failure. 738 """ 739 740 def __init__(self, worker_index, device_name, cluster): 741 self.worker_index = worker_index 742 self.device_name = device_name 743 self.executor = executor.new_executor(enable_async=False) 744 self.failure_handler = cluster.failure_handler 745 self._cluster = cluster 746 self._resource_tracking_lock = threading.Lock() 747 self._resource_remote_value_refs = [] 748 self._is_dead_with_error = None 749 self._should_worker_thread_run = True 750 751 # Worker threads need to start after `Worker`'s initialization. 752 threading.Thread(target=self._process_queue, 753 name="WorkerClosureProcessingLoop-%d" % self.worker_index, 754 daemon=True).start() 755 756 def stop(self): 757 """Ensure the worker thread is closed.""" 758 self._should_worker_thread_run = False 759 760 def _schedule_resource(self, closure): 761 self._cluster.closure_queue.put(closure, tag=self.worker_index) 762 763 def _set_resources_aborted(self, e): 764 """Set the resource ABORTED and add an error to it.""" 765 # TODO(yuefengz): maybe we can query whether a tensor is valid or not 766 # instead of marking a tensor aborted? 767 logging.info("[Worker %d] Clearing all resources.", self.worker_index) 768 for weakref_resource in self._resource_remote_value_refs: 769 resource = weakref_resource() 770 if resource: 771 # It is important to set an error on an aborted RemoteValue from a 772 # ResourceClosure because its failure will not trigger the worker thread 773 # to raise error immediately and the worker may continue executing 774 # closures taking it as an input. The error will then be correctly 775 # reported to users. 776 resource._set_aborted(ClosureAbortedError(e)) # pylint: disable=protected-access 777 778 def _on_closure_failure(self, closure, e): 779 logging.info("[Worker %d] Putting back a closure after it failed.", 780 self.worker_index) 781 self._cluster.closure_queue.put_back(closure) 782 783 with self._resource_tracking_lock: 784 self._is_dead_with_error = e 785 self._set_resources_aborted(e) 786 787 def _on_resource_closure_failure(self, e): 788 """Clear tagged queue to ensure resource closures are rebuilt. 789 790 Args: 791 e: The exception arisen from the resource closure. 792 """ 793 logging.info("[Worker %d] Clearing tagged queue after resource closure " 794 "failure.", self.worker_index) 795 with self._resource_tracking_lock: 796 self._is_dead_with_error = e 797 # No locking on queue is needed since 798 # * get will not happen concurrently here. 799 # * put to the specific tagged queue will be guarded by 800 # `self._resource_tracking_lock`. 801 self._cluster.closure_queue.clear_tag_unlocked(self.worker_index) 802 self._set_resources_aborted(e) 803 804 def _on_worker_recovery(self): 805 logging.info("[Worker %d] calling _on_worker_recovery", self.worker_index) 806 with self._resource_tracking_lock: 807 for weakref_resource in self._resource_remote_value_refs: 808 resource = weakref_resource() 809 if resource: 810 self._schedule_resource(resource._closure) # pylint: disable=protected-access 811 self._is_dead_with_error = False 812 813 def _process_closure(self, closure): 814 """Runs a closure with preemption handling.""" 815 try: 816 with self.failure_handler.wait_on_failure( 817 on_failure_fn=lambda e: self._on_closure_failure(closure, e), 818 on_transient_failure_fn=( 819 lambda: self._cluster.closure_queue.put_back(closure)), 820 on_recovery_fn=self._on_worker_recovery, 821 worker_device_name=self.device_name): 822 closure.execute_on(self) 823 with metric_utils.monitored_timer("remote_value_fetch"): 824 # Copy the remote tensor to local (the coordinator) in case worker 825 # becomes unavailable at a later time. 826 closure.maybe_call_with_output_remote_value(lambda r: r.get()) 827 self._cluster.closure_queue.mark_finished() 828 except Exception as e: # pylint: disable=broad-except 829 # Avoid logging the derived cancellation error 830 if not isinstance(e, errors.CancelledError): 831 logging.error( 832 " /job:worker/task:%d encountered the following error when " 833 "processing closure: %r:%s", self.worker_index, e, e) 834 closure.maybe_call_with_output_remote_value(lambda r: r._set_error(e)) # pylint: disable=protected-access 835 self._cluster.closure_queue.mark_failed(e) 836 837 def _process_resource_closure(self, closure): 838 """Run the given resource closure with preemption handling.""" 839 assert closure.tag == self.worker_index 840 try: 841 with self.failure_handler.wait_on_failure( 842 on_failure_fn=self._on_resource_closure_failure, 843 on_transient_failure_fn=( 844 lambda: self._process_resource_closure(closure)), 845 on_recovery_fn=self._on_worker_recovery, 846 worker_device_name=self.device_name): 847 closure.execute_on(self) 848 except Exception as e: # pylint: disable=broad-except 849 # Avoid logging the derived cancellation error 850 logging.info("[Worker %d] got an exception when processing resource " 851 "closure", self.worker_index) 852 if not isinstance(e, errors.CancelledError): 853 logging.error( 854 " /job:worker/task:%d encountered the following error when " 855 "processing resource closure: %r:%s", self.worker_index, e, e) 856 closure.maybe_call_with_output_remote_value(lambda r: r._set_error(e)) # pylint: disable=protected-access 857 858 def _maybe_delay(self): 859 """Delay if corresponding env vars are set.""" 860 # If the following two env vars variables are set. Scheduling for workers 861 # will start in a staggered manner. Worker i will wait for 862 # `TF_COORDINATOR_SCHEDULE_START_DELAY` * i seconds, not exceeding 863 # `TF_COORDINATOR_SCHEDULE_START_DELAY_MAX`. 864 delay_secs = int(os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY", "0")) 865 delay_secs *= self.worker_index 866 delay_cap = int( 867 os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY_MAX", "0")) 868 if delay_cap: 869 delay_secs = min(delay_secs, delay_cap) 870 if delay_secs > 0: 871 logging.info(" Worker %d sleeping for %d seconds before running function", 872 self.worker_index, delay_secs) 873 time.sleep(delay_secs) 874 875 def _process_queue(self): 876 """Function running in a worker thread to process closure queues.""" 877 self._maybe_delay() 878 while self._should_worker_thread_run: 879 closure = self._cluster.closure_queue.get(tag=self.worker_index) 880 if not self._should_worker_thread_run or closure is None: 881 if closure is not None: 882 closure.mark_cancelled() 883 return 884 if isinstance(closure, ResourceClosure): 885 self._process_resource_closure(closure) 886 else: 887 self._process_closure(closure) 888 # To properly stop the worker and preemption threads, it is important that 889 # `ClusterCoordinator` object is not held onto so its `__del__` can be 890 # called. By removing the reference to the `closure` that has already been 891 # processed, we ensure that the `closure` object is released, while 892 # getting the next `closure` at above `self._cluster.closure_queue.get()` 893 # call. 894 del closure 895 896 def create_resource(self, function, args=None, kwargs=None): 897 """Synchronously creates a per-worker resource represented by a `RemoteValue`. 898 899 Args: 900 function: the resource function to be run remotely. It should be a 901 `tf.function`, a concrete function or a Python function. 902 args: positional arguments to be passed to the function. 903 kwargs: keyword arguments to be passed to the function. 904 905 Returns: 906 one or several RemoteValue objects depending on the function return 907 values. 908 """ 909 # Some notes about the concurrency: currently all the activities related to 910 # the same worker such as creating resources, setting resources' aborted 911 # status, and executing closures happen on the same thread. This allows us 912 # to have simpler logic of concurrency. 913 914 closure = ResourceClosure( 915 function, 916 self._cluster.resource_cancellation_mgr, 917 args=args, 918 kwargs=kwargs) 919 resource_remote_value = closure.build_output_remote_value() 920 with self._resource_tracking_lock: 921 self._register_resource(resource_remote_value) 922 if self._is_dead_with_error: 923 resource_remote_value._set_aborted( # pylint: disable=protected-access 924 ClosureAbortedError(self._is_dead_with_error)) 925 else: 926 self._schedule_resource(closure) 927 return resource_remote_value 928 929 def _register_resource(self, resource_remote_value): 930 if not isinstance(resource_remote_value, RemoteValue): 931 raise ValueError("Resource being registered is not of type " 932 "`tf.distribute.experimental.coordinator.RemoteValue`.") 933 self._resource_remote_value_refs.append(weakref.ref(resource_remote_value)) 934 935 936class Cluster(object): 937 """A cluster with workers. 938 939 We assume all function errors are fatal and based on this assumption our 940 error reporting logic is: 941 1) Both `schedule` and `join` can raise a non-retryable error which is the 942 first error seen by the coordinator from any previously scheduled functions. 943 2) When an error is raised, there is no guarantee on how many previously 944 scheduled functions have been executed; functions that have not been executed 945 will be thrown away and marked as cancelled. 946 3) After an error is raised, the internal state of error will be cleared. 947 I.e. functions can continue to be scheduled and subsequent calls of `schedule` 948 or `join` will not raise the same error again. 949 950 Attributes: 951 failure_handler: The failure handler used to handler worker preemption 952 failure. 953 workers: a list of `Worker` objects in the cluster. 954 closure_queue: the global Closure queue. 955 resource_cancellation_mgr: the cancellation manager used to cancel resource 956 closures. 957 """ 958 959 def __init__(self, strategy): 960 """Initializes the cluster instance.""" 961 962 self._num_workers = strategy._num_workers 963 self._num_ps = strategy._num_ps 964 965 # Ignore PS failures reported by workers due to transient connection errors. 966 # Transient connectivity issues between workers and PS are relayed by the 967 # workers to the coordinator, leading the coordinator to believe that there 968 # are PS failures. The difference between transient vs. permanent PS failure 969 # is the number of reports from the workers. When this env var is set to a 970 # positive integer K, the coordinator ignores up to K reports of a failed PS 971 # task, i.e., only when there are more than K trials of executing closures 972 # fail due to errors from the same PS instance do we consider the PS 973 # instance encounters a failure. 974 # TODO(b/164279603): Remove this workaround when the underlying connectivity 975 # issue in gRPC server is resolved. 976 self._transient_ps_failures_threshold = int( 977 os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3)) 978 self._potential_ps_failures_lock = threading.Lock() 979 self._potential_ps_failures_count = [0] * self._num_ps 980 981 # Ignore worker timeouts due to transient connection errors. 982 # Transient connectivity issues might cause the server side to unexpectedly 983 # cancel RPC handling logic, leading to closure execution timeouts. When 984 # the _transient_timeout_threshold is set to a positive number, the cluster 985 # coordinator ignores DeadlineExceeded errors from workers for the specified 986 # times before raising the error to users. 987 self._transient_timeouts_threshold = int( 988 os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_TIMEOUTS", 989 self._num_workers // 10)) 990 self._transient_timeouts_lock = threading.Lock() 991 self._transient_timeouts_count = 0 992 993 self.closure_queue = _CoordinatedClosureQueue() 994 self.failure_handler = WorkerPreemptionHandler(context.get_server_def(), 995 self) 996 worker_device_strings = [ 997 "/job:worker/replica:0/task:%d" % i for i in range(self._num_workers) 998 ] 999 self.workers = [ 1000 Worker(i, w, self) for i, w in enumerate(worker_device_strings) 1001 ] 1002 1003 # Cancellation manager for all resource closures. 1004 self.resource_cancellation_mgr = cancellation.CancellationManager() 1005 1006 def stop(self): 1007 """Stop worker, worker preemption threads, and the closure queue.""" 1008 logging.info("Stopping cluster, starting with failure handler") 1009 self.failure_handler.stop() 1010 1011 logging.info("Stopping workers") 1012 for worker in self.workers: 1013 worker.stop() 1014 logging.info("Stopping queue") 1015 self.closure_queue.stop() 1016 logging.info("Start cancelling remote resource-building functions") 1017 self.resource_cancellation_mgr.start_cancel() 1018 1019 def _record_and_ignore_transient_ps_failure(self, e): 1020 """Records potential PS failures and return if failure should be ignored.""" 1021 if self._transient_ps_failures_threshold <= 0 or not _is_ps_failure(e): 1022 return False 1023 1024 ps_tasks = _extract_failed_ps_instances(str(e)) 1025 with self._potential_ps_failures_lock: 1026 for t in ps_tasks: 1027 self._potential_ps_failures_count[t] += 1 1028 # The number of UnavailableError encountered on this PS task exceeds the 1029 # maximum number of ignored error 1030 if (self._potential_ps_failures_count[t] >= 1031 self._transient_ps_failures_threshold): 1032 return False 1033 return True 1034 1035 def _record_and_ignore_transient_timeouts(self, e): 1036 """Records observed timeout error and return if it should be ignored.""" 1037 if self._transient_timeouts_threshold <= 0: 1038 return False 1039 if not isinstance(e, errors.DeadlineExceededError): 1040 return False 1041 with self._transient_timeouts_lock: 1042 self._transient_timeouts_count += 1 1043 if self._transient_timeouts_count >= self._transient_timeouts_threshold: 1044 return False 1045 return True 1046 1047 def schedule(self, function, args, kwargs): 1048 """Schedules `function` to be dispatched to a worker for execution. 1049 1050 Args: 1051 function: The function to be dispatched to a worker for execution 1052 asynchronously. 1053 args: Positional arguments for `fn`. 1054 kwargs: Keyword arguments for `fn`. 1055 1056 Returns: 1057 A `RemoteValue` object. 1058 """ 1059 closure = Closure( 1060 function, 1061 self.closure_queue._cancellation_mgr, # pylint: disable=protected-access 1062 args=args, 1063 kwargs=kwargs) 1064 ret = closure.build_output_remote_value() 1065 self.closure_queue.put(closure) 1066 return ret 1067 1068 def join(self): 1069 """Blocks until all scheduled functions are executed.""" 1070 self.closure_queue.wait() 1071 1072 def done(self): 1073 """Returns true if all scheduled functions are executed.""" 1074 return self.closure_queue.done() 1075 1076 1077@tf_export("distribute.experimental.coordinator.ClusterCoordinator", 1078 "distribute.coordinator.ClusterCoordinator", v1=[]) 1079class ClusterCoordinator(object): 1080 """An object to schedule and coordinate remote function execution. 1081 1082 This class is used to create fault-tolerant resources and dispatch functions 1083 to remote TensorFlow servers. 1084 1085 Currently, this class is not supported to be used in a standalone manner. It 1086 should be used in conjunction with a `tf.distribute` strategy that is designed 1087 to work with it. The `ClusterCoordinator` class currently only works 1088 `tf.distribute.experimental.ParameterServerStrategy`. 1089 1090 __The `schedule`/`join` APIs__ 1091 1092 The most important APIs provided by this class is the `schedule`/`join` pair. 1093 The `schedule` API is non-blocking in that it queues a `tf.function` and 1094 returns a `RemoteValue` immediately. The queued functions will be dispatched 1095 to remote workers in background threads and their `RemoteValue`s will be 1096 filled asynchronously. Since `schedule` doesn’t require worker assignment, the 1097 `tf.function` passed in can be executed on any available worker. If the worker 1098 it is executed on becomes unavailable before its completion, it will be 1099 migrated to another worker. Because of this fact and function execution is not 1100 atomic, a function may be executed more than once. 1101 1102 __Handling Task Failure__ 1103 1104 This class when used with 1105 `tf.distribute.experimental.ParameterServerStrategy`, comes with built-in 1106 fault tolerance for worker failures. That is, when some workers are not 1107 available for any reason to be reached from the coordinator, the training 1108 progress continues to be made with the remaining workers. Upon recovery of a 1109 failed worker, it will be added for function execution after datasets created 1110 by `create_per_worker_dataset` are re-built on it. 1111 1112 When a parameter server fails, a `tf.errors.UnavailableError` is raised by 1113 `schedule`, `join` or `done`. In this case, in addition to bringing back the 1114 failed parameter server, users should restart the coordinator so that it 1115 reconnects to workers and parameter servers, re-creates the variables, and 1116 loads checkpoints. If the coordinator fails, after the user brings it back, 1117 the program will automatically connect to workers and parameter servers, and 1118 continue the progress from a checkpoint. 1119 1120 It is thus essential that in user's program, a checkpoint file is periodically 1121 saved, and restored at the start of the program. If an 1122 `tf.keras.optimizers.Optimizer` is checkpointed, after restoring from a 1123 checkpoiont, its `iterations` property roughly indicates the number of steps 1124 that have been made. This can be used to decide how many epochs and steps are 1125 needed before the training completion. 1126 1127 See `tf.distribute.experimental.ParameterServerStrategy` docstring for an 1128 example usage of this API. 1129 1130 This is currently under development, and the API as well as implementation 1131 are subject to changes. 1132 """ 1133 1134 def __new__(cls, strategy): 1135 # `ClusterCoordinator` is kept as a single instance to a given `Strategy`. 1136 # TODO(rchao): Needs a lock for thread-safety 1137 if strategy._cluster_coordinator is None: 1138 strategy._cluster_coordinator = super( 1139 ClusterCoordinator, cls).__new__(cls) 1140 return strategy._cluster_coordinator 1141 1142 def __init__(self, strategy): 1143 """Initialization of a `ClusterCoordinator` instance. 1144 1145 Args: 1146 strategy: a supported `tf.distribute.Strategy` object. Currently, only 1147 `tf.distribute.experimental.ParameterServerStrategy` is supported. 1148 1149 Raises: 1150 ValueError: if the strategy being used is not supported. 1151 """ 1152 if not getattr(self, "_has_initialized", False): 1153 if not isinstance(strategy, 1154 parameter_server_strategy_v2.ParameterServerStrategyV2): 1155 raise ValueError( 1156 "Only `tf.distribute.experimental.ParameterServerStrategy` " 1157 "is supported to work with " 1158 "`tf.distribute.experimental.coordinator.ClusterCoordinator` " 1159 "currently.") 1160 self._strategy = strategy 1161 self.strategy.extended._used_with_coordinator = True 1162 self._cluster = Cluster(strategy) 1163 self._has_initialized = True 1164 1165 def __del__(self): 1166 logging.info("ClusterCoordinator destructor: stopping cluster") 1167 self._cluster.stop() 1168 1169 @property 1170 def strategy(self): 1171 """Returns the `Strategy` associated with the `ClusterCoordinator`.""" 1172 return self._strategy 1173 1174 def schedule(self, fn, args=None, kwargs=None): 1175 """Schedules `fn` to be dispatched to a worker for asynchronous execution. 1176 1177 This method is non-blocking in that it queues the `fn` which will be 1178 executed later and returns a 1179 `tf.distribute.experimental.coordinator.RemoteValue` object immediately. 1180 `fetch` can be called on it to wait for the function execution to finish 1181 and retrieve its output from a remote worker. On the other hand, call 1182 `tf.distribute.experimental.coordinator.ClusterCoordinator.join` to wait for 1183 all scheduled functions to finish. 1184 1185 `schedule` guarantees that `fn` will be executed on a worker at least once; 1186 it could be more than once if its corresponding worker fails in the middle 1187 of its execution. Note that since worker can fail at any point when 1188 executing the function, it is possible that the function is partially 1189 executed, but `tf.distribute.experimental.coordinator.ClusterCoordinator` 1190 guarantees that in those events, the function will eventually be executed on 1191 any worker that is available. 1192 1193 If any previously scheduled function raises an error, `schedule` will raise 1194 any one of those errors, and clear the errors collected so far. What happens 1195 here, some of the previously scheduled functions may have not been executed. 1196 User can call `fetch` on the returned 1197 `tf.distribute.experimental.coordinator.RemoteValue` to inspect if they have 1198 executed, failed, or cancelled, and reschedule the corresponding function if 1199 needed. 1200 1201 When `schedule` raises, it guarantees that there is no function that is 1202 still being executed. 1203 1204 At this time, there is no support of worker assignment for function 1205 execution, or priority of the workers. 1206 1207 `args` and `kwargs` are the arguments passed into `fn`, when `fn` is 1208 executed on a worker. They can be 1209 `tf.distribute.experimental.coordinator.PerWorkerValues` and in this case, 1210 the argument will be substituted with the corresponding component on the 1211 target worker. Arguments that are not 1212 `tf.distribute.experimental.coordinator.PerWorkerValues` will be passed into 1213 `fn` as-is. Currently, `tf.distribute.experimental.coordinator.RemoteValue` 1214 is not supported to be input `args` or `kwargs`. 1215 1216 Args: 1217 fn: A `tf.function`; the function to be dispatched to a worker for 1218 execution asynchronously. Regular python function is not supported to be 1219 scheduled. 1220 args: Positional arguments for `fn`. 1221 kwargs: Keyword arguments for `fn`. 1222 1223 Returns: 1224 A `tf.distribute.experimental.coordinator.RemoteValue` object that 1225 represents the output of the function scheduled. 1226 1227 Raises: 1228 Exception: one of the exceptions caught by the coordinator from any 1229 previously scheduled function, since the last time an error was thrown 1230 or since the beginning of the program. 1231 """ 1232 if not isinstance(fn, 1233 (def_function.Function, tf_function.ConcreteFunction)): 1234 raise TypeError( 1235 "`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`" 1236 " only accepts a `tf.function` or a concrete function.") 1237 # Slot variables are usually created during function tracing time; thus 1238 # `schedule` needs to be called within the `strategy.scope()`. 1239 with self.strategy.scope(): 1240 self.strategy.extended._being_scheduled = True # pylint: disable=protected-access 1241 remote_value = self._cluster.schedule(fn, args=args, kwargs=kwargs) 1242 self.strategy.extended._being_scheduled = False # pylint: disable=protected-access 1243 return remote_value 1244 1245 def join(self): 1246 """Blocks until all the scheduled functions have finished execution. 1247 1248 If any previously scheduled function raises an error, `join` will fail by 1249 raising any one of those errors, and clear the errors collected so far. If 1250 this happens, some of the previously scheduled functions may have not been 1251 executed. Users can call `fetch` on the returned 1252 `tf.distribute.experimental.coordinator.RemoteValue` to inspect if they have 1253 executed, failed, or cancelled. If some that have been cancelled need to be 1254 rescheduled, users should call `schedule` with the function again. 1255 1256 When `join` returns or raises, it guarantees that there is no function that 1257 is still being executed. 1258 1259 Raises: 1260 Exception: one of the exceptions caught by the coordinator by any 1261 previously scheduled function since the last time an error was thrown or 1262 since the beginning of the program. 1263 """ 1264 self._cluster.join() 1265 1266 def done(self): 1267 """Returns whether all the scheduled functions have finished execution. 1268 1269 If any previously scheduled function raises an error, `done` will fail by 1270 raising any one of those errors. 1271 1272 When `done` returns True or raises, it guarantees that there is no function 1273 that is still being executed. 1274 1275 Returns: 1276 Whether all the scheduled functions have finished execution. 1277 Raises: 1278 Exception: one of the exceptions caught by the coordinator by any 1279 previously scheduled function since the last time an error was thrown or 1280 since the beginning of the program. 1281 """ 1282 return self._cluster.done() 1283 1284 def create_per_worker_dataset(self, dataset_fn): 1285 """Create dataset on each worker. 1286 1287 This creates dataset on workers from the input which can be either a 1288 `tf.data.Dataset`, a `tf.distribute.DistributedDataset` or a function which 1289 returns a dataset, and returns an object that represents the collection of 1290 those individual datasets. Calling `iter` on such collection of datasets 1291 returns a `tf.distribute.experimental.coordinator.PerWorkerValues`, which is 1292 a collection of iterators, where the iterators have been placed on 1293 respective workers. 1294 1295 Calling `next` on a `PerWorkerValues` of iterator is unsupported. The 1296 iterator is meant to be passed as an argument into 1297 `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`. When 1298 the scheduled function is about to be executed by a worker, the 1299 function will receive the individual iterator that corresponds to the 1300 worker. The `next` method can be called on an iterator inside a 1301 scheduled function when the iterator is an input of the function. 1302 1303 Currently the `schedule` method assumes workers are all the same and thus 1304 assumes the datasets on different workers are the same, except they may be 1305 shuffled differently if they contain a `dataset.shuffle` operation and a 1306 random seed is not set. Because of this, we also recommend the datasets to 1307 be repeated indefinitely and schedule a finite number of steps instead of 1308 relying on the `OutOfRangeError` from a dataset. 1309 1310 1311 Example: 1312 1313 ```python 1314 strategy = tf.distribute.experimental.ParameterServerStrategy( 1315 cluster_resolver=...) 1316 coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator( 1317 strategy=strategy) 1318 1319 @tf.function 1320 def worker_fn(iterator): 1321 return next(iterator) 1322 1323 def per_worker_dataset_fn(): 1324 return strategy.distribute_datasets_from_function( 1325 lambda x: tf.data.Dataset.from_tensor_slices([3] * 3)) 1326 1327 per_worker_dataset = coordinator.create_per_worker_dataset( 1328 per_worker_dataset_fn) 1329 per_worker_iter = iter(per_worker_dataset) 1330 remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,)) 1331 assert remote_value.fetch() == 3 1332 ``` 1333 1334 Args: 1335 dataset_fn: The dataset function that returns a dataset. This is to be 1336 executed on the workers. 1337 1338 Returns: 1339 An object that represents the collection of those individual 1340 datasets. `iter` is expected to be called on this object that returns 1341 a `tf.distribute.experimental.coordinator.PerWorkerValues` of the 1342 iterators (that are on the workers). 1343 """ 1344 return values_lib.get_per_worker_dataset(dataset_fn, self) 1345 1346 def _create_per_worker_resources(self, fn, args=None, kwargs=None): 1347 """Synchronously create resources on the workers. 1348 1349 The resources are represented by 1350 `tf.distribute.experimental.coordinator.RemoteValue`s. 1351 1352 Args: 1353 fn: The function to be dispatched to all workers for execution 1354 asynchronously. 1355 args: Positional arguments for `fn`. 1356 kwargs: Keyword arguments for `fn`. 1357 1358 Returns: 1359 A `tf.distribute.experimental.coordinator.PerWorkerValues` object, which 1360 wraps a tuple of `tf.distribute.experimental.coordinator.RemoteValue` 1361 objects. 1362 """ 1363 results = [] 1364 for w in self._cluster.workers: 1365 results.append(w.create_resource(fn, args=args, kwargs=kwargs)) 1366 return PerWorkerValues(tuple(results)) 1367 1368 def fetch(self, val): 1369 """Blocking call to fetch results from the remote values. 1370 1371 This is a wrapper around 1372 `tf.distribute.experimental.coordinator.RemoteValue.fetch` for a 1373 `RemoteValue` structure; it returns the execution results of 1374 `RemoteValue`s. If not ready, wait for them while blocking the caller. 1375 1376 Example: 1377 ```python 1378 strategy = ... 1379 coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator( 1380 strategy) 1381 1382 def dataset_fn(): 1383 return tf.data.Dataset.from_tensor_slices([1, 1, 1]) 1384 1385 with strategy.scope(): 1386 v = tf.Variable(initial_value=0) 1387 1388 @tf.function 1389 def worker_fn(iterator): 1390 def replica_fn(x): 1391 v.assign_add(x) 1392 return v.read_value() 1393 return strategy.run(replica_fn, args=(next(iterator),)) 1394 1395 distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn) 1396 distributed_iterator = iter(distributed_dataset) 1397 result = coordinator.schedule(worker_fn, args=(distributed_iterator,)) 1398 assert coordinator.fetch(result) == 1 1399 ``` 1400 1401 Args: 1402 val: The value to fetch the results from. If this is structure of 1403 `tf.distribute.experimental.coordinator.RemoteValue`, `fetch()` will be 1404 called on the individual 1405 `tf.distribute.experimental.coordinator.RemoteValue` to get the result. 1406 1407 Returns: 1408 If `val` is a `tf.distribute.experimental.coordinator.RemoteValue` or a 1409 structure of `tf.distribute.experimental.coordinator.RemoteValue`s, 1410 return the fetched `tf.distribute.experimental.coordinator.RemoteValue` 1411 values immediately if they are available, or block the call until they are 1412 available, and return the fetched 1413 `tf.distribute.experimental.coordinator.RemoteValue` values with the same 1414 structure. If `val` is other types, return it as-is. 1415 """ 1416 1417 def _maybe_fetch(val): 1418 if isinstance(val, RemoteValue): 1419 return val.fetch() 1420 else: 1421 return val 1422 1423 # TODO(yuefengz): we should fetch values in a batch. 1424 return nest.map_structure(_maybe_fetch, val) 1425 1426 1427def _extract_failed_ps_instances(err_msg): 1428 """Return a set of potentially failing ps instances from error message.""" 1429 tasks = re.findall("/job:ps/replica:0/task:[0-9]+", err_msg) 1430 return set(int(t.split(":")[-1]) for t in tasks) 1431 1432 1433def _is_ps_failure(error): 1434 """Whether the error is considered a parameter server failure.""" 1435 1436 # For an `ClosureInputError` or `ClosureAbortedError`, extract 1437 # the original error and assess it accordingly. 1438 if isinstance(error, (ClosureInputError, ClosureAbortedError)): 1439 error = error.original_exception 1440 1441 if _RPC_ERROR_FROM_PS not in str(error): 1442 return False 1443 1444 if isinstance(error, (errors.UnavailableError, errors.AbortedError)): 1445 return True 1446 1447 # The following error could happen when the remote task fails and restarts 1448 # in a very short interval during which no RPCs were exchanged to detect the 1449 # failure. In that case, gRPC allows channel (which is different from a 1450 # connection) to be reused for a replaced server listening to same address. 1451 if isinstance(error, errors.InvalidArgumentError): 1452 if ("unknown device" in str(error).lower() or 1453 "Unable to find the relevant tensor remote_handle" in str(error)): 1454 return True 1455 1456 return False 1457 1458 1459def _handle_graph_execution_error_as_worker_failure(): 1460 return int(os.environ.get("TF_PS_HANDLE_UNKNOWN_ERROR", "0")) > 0 1461 1462 1463def _is_worker_failure(error): 1464 """Whether the error is considered a worker failure.""" 1465 1466 # TODO(b/216666282): Understand why worker failure can manifest as a 1467 # "Graph execution error" `UnknownError`. 1468 if (_handle_graph_execution_error_as_worker_failure() and 1469 isinstance(error, errors.UnknownError) and 1470 "Graph execution error" in str(error)): 1471 logging.info(f"Handling {type(error)}: {str(error)} as worker failure.") 1472 return True 1473 1474 # For an `ClosureInputError` or `ClosureAbortedError`, extract 1475 # the original error and assess it accordingly. 1476 if isinstance(error, (ClosureInputError, ClosureAbortedError)): 1477 error = error.original_exception 1478 1479 if _JOB_WORKER_STRING_IDENTIFIER not in str(error): 1480 return False 1481 if _RPC_ERROR_FROM_PS in str(error): 1482 return False 1483 1484 # TODO(haoyuzhang): Consider using special status code if error from a 1485 # remote is derived from RPC errors originated from other hosts. 1486 if isinstance(error, (errors.UnavailableError, errors.AbortedError)): 1487 return True 1488 1489 # The following error could happen when the remote task fails and restarts 1490 # in a very short interval during which no RPCs were exchanged to detect the 1491 # failure. In that case, gRPC allows channel (which is different from a 1492 # connection) to be reused for a replaced server listening to same address. 1493 if isinstance(error, errors.InvalidArgumentError): 1494 if ("unknown device" in str(error).lower() or 1495 "Primary device is not remote" in str(error) or 1496 "Unable to find the relevant tensor remote_handle" in str(error)): 1497 return True 1498 1499 # TODO(b/162541228): The following 2 types of errors are very rare and only 1500 # observed in large-scale testing. The types of errors should be reduced. 1501 # This could happen when the function registration fails. In the observed 1502 # cases this only happens to the dataset related functions. 1503 if isinstance(error, errors.NotFoundError): 1504 if ("is neither a type of a primitive operation nor a name of a function " 1505 "registered" in str(error)): 1506 return True 1507 1508 # NOTE(b/179061495): During worker preemptions, if multiple functions are 1509 # running concurrently (especially with subfunctions spanning chief/PS), 1510 # CancelledError can be returned due to chief/PS cancelling outstanding RPCs 1511 # to the failing workers. 1512 if isinstance(error, errors.CancelledError): 1513 return True 1514 1515 return False 1516