1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Module for `ClusterCoordinator` and relevant cluster-worker related library.
16
17This is currently under development and the API is subject to change.
18"""
19
20import collections
21import contextlib
22import os
23import re
24import threading
25import time
26import weakref
27
28from six.moves import queue
29
30from tensorflow.python.distribute import parameter_server_strategy_v2
31from tensorflow.python.distribute.coordinator import coordinator_context
32from tensorflow.python.distribute.coordinator import metric_utils
33from tensorflow.python.distribute.coordinator import values as values_lib
34from tensorflow.python.distribute.coordinator import watchdog
35from tensorflow.python.eager import cancellation
36from tensorflow.python.eager import context
37from tensorflow.python.eager import def_function
38from tensorflow.python.eager import executor
39from tensorflow.python.eager import function as tf_function
40from tensorflow.python.framework import errors
41from tensorflow.python.framework import func_graph
42from tensorflow.python.framework import ops
43from tensorflow.python.platform import tf_logging as logging
44from tensorflow.python.util import nest
45from tensorflow.python.util.tf_export import tf_export
46
# Maximum time for a failed worker to come back is 1 hour.
_WORKER_MAXIMUM_RECOVERY_SEC = 3600
49
# Maximum size for queued closures, "infinite" if set to 0.
# When the maximum queue size is reached, further schedule calls will become
# blocking until some previously queued closures are executed on workers.
# Note that using an "infinite" queue size can take a non-trivial portion of
# memory, and even lead to coordinator OOM. Modify the size to a smaller value
# for a coordinator with constrained memory resources (only recommended for
# advanced users). Also used in unit tests to ensure the correctness when the
# queue is full.
_CLOSURE_QUEUE_MAX_SIZE = 256 * 1024
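# For example (an illustrative, advanced-users-only sketch, in line with the
# note above): a memory-constrained coordinator could lower this module-level
# constant before the coordinator (and hence its closure queue) is created:
#
#   from tensorflow.python.distribute.coordinator import cluster_coordinator
#   cluster_coordinator._CLOSURE_QUEUE_MAX_SIZE = 1024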
59
60# RPC error message from PS
61_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps"
62
63# InvalidArgumentError (unknown device) will not have "GRPC error..." string.
64_JOB_WORKER_STRING_IDENTIFIER = "/job:worker"
65
66
67RemoteValueStatus = values_lib.RemoteValueStatus
68RemoteValue = values_lib.RemoteValue
69RemoteValueImpl = values_lib.RemoteValueImpl
70PerWorkerValues = values_lib.PerWorkerValues
71
72
73class ClosureInputError(Exception):
74  """Wrapper for errors from resource building.
75
76  When a closure starts, it first checks for errors in any of its inputs, which
77  are RemoteValues from resource closures. If there were any errors, it wraps
78  the exception in this class and raises so it can be handled by the worker
79  failure handler.
80
81  Attributes:
82    original_exception:
83  """
84
85  def __init__(self, original_exception):
86    # Avoid doubly-nested errors
87    if isinstance(original_exception,
88                  (ClosureInputError, ClosureAbortedError)):
89      self.original_exception = original_exception.original_exception
90    else:
91      self.original_exception = original_exception
92    message = ("Input has an error, the original exception is %r, "
93               "error message is %s." %
94               (self.original_exception, str(self.original_exception)))
95    super().__init__(message)
96    self.with_traceback(original_exception.__traceback__)
97
98
99class ClosureAbortedError(Exception):
100  """Wrapper for errors from training closures, to attach to resource closures.
101
102  This wrapper is used when a dependent training closure fails to set errors on
103  its required resource closures.
104
105  Attributes:
106    original_exception: The Exception to wrap
107  """
108
109  def __init__(self, original_exception):
110    # Avoid doubly-nested errors
111    if isinstance(original_exception,
112                  (ClosureInputError, ClosureAbortedError)):
113      self.original_exception = original_exception.original_exception
114    else:
115      self.original_exception = original_exception
116    message = ("Other function has an execution error, as a result, the "
117               "current value is not available. The original exception is %r, "
118               "error message is %s." %
119               (self.original_exception, str(self.original_exception)))
120    super().__init__(message)
121    self.with_traceback(original_exception.__traceback__)
122
123
124def _get_error_from_remote_values(structure):
125  """Attempts to return errors from `RemoteValue`s. Rebuilds them if needed."""
126  errors_in_structure = []
127
128  def _get_error(val):
129    if isinstance(val, RemoteValue):
130      error = val._get_error()  # pylint: disable=protected-access
131      if error:
132        errors_in_structure.append(error)
133
134  nest.map_structure(_get_error, structure)
135  if errors_in_structure:
136    return errors_in_structure[0]
137  else:
138    return None
139
140
141def _maybe_get_remote_value(val):
142  """Gets the value of `val` if it is a `RemoteValue`."""
143  if isinstance(val, RemoteValue):
144    error = val._get_error()  # pylint: disable=protected-access
145    if error:
146      raise AssertionError(
147          "RemoteValue doesn't have a value because it has error %r:%s" %
148          (error, error))
149    elif val._status is not RemoteValueStatus.READY:  # pylint: disable=protected-access
150      raise AssertionError("The input RemoteValue has not been executed.")
151    else:
152      return val._get_values()  # pylint: disable=protected-access
153  else:
154    return val
155
156
157def _maybe_as_type_spec(val):
158  if isinstance(val, (RemoteValue, PerWorkerValues)):
159    if val._type_spec is None:  # pylint: disable=protected-access
160      raise ValueError("Output of a scheduled function that is not "
161                       "tf.function cannot be the input of another function.")
162    return val._type_spec  # pylint: disable=protected-access
163  else:
164    return val
165
166
167def _select_worker_slice(worker_id, structured):
168  """Selects the worker slice of each of the items in `structured`."""
169
170  def _get(x):
171    return x._values[worker_id] if isinstance(x, PerWorkerValues) else x  # pylint: disable=protected-access
172
173  return nest.map_structure(_get, structured)
174
175
176def _disallow_remote_value_as_input(structured):
177  """Raises if any element of `structured` is a RemoteValue."""
178
179  def _raise_if_remote_value(x):
180    if isinstance(x, RemoteValue):
181      raise ValueError(
182          "`tf.distribute.experimental.coordinator.RemoteValue` used "
183          "as an input to scheduled function is not yet "
184          "supported.")
185
186  nest.map_structure(_raise_if_remote_value, structured)
187
188
189class Closure(object):
190  """Hold a function to be scheduled and its arguments."""
191
192  def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
193    if not callable(function):
194      raise ValueError("Function passed to `ClusterCoordinator.schedule` must "
195                       "be a callable object.")
196    self._args = args or ()
197    self._kwargs = kwargs or {}
198
199    _disallow_remote_value_as_input(self._args)
200    _disallow_remote_value_as_input(self._kwargs)
201
202    if isinstance(function, def_function.Function):
203      replica_args = _select_worker_slice(0, self._args)
204      replica_kwargs = _select_worker_slice(0, self._kwargs)
205
      # Note: no need to handle function registration failure since, by design,
      # this kind of failure does not raise an exception in the runtime. The
      # coordinator has to rely on subsequent operations that raise to catch
      # function registration failure.
210
211      # Record the function tracing overhead. Note that we pass in the tracing
212      # count of the def_function.Function as a state tracker, so that metrics
213      # will only record the time for actual function tracing (i.e., excluding
214      # function cache lookups).
215      with metric_utils.monitored_timer(
216          "function_tracing", state_tracker=function._get_tracing_count):  # pylint: disable=protected-access
217        self._concrete_function = function.get_concrete_function(
218            *nest.map_structure(_maybe_as_type_spec, replica_args),
219            **nest.map_structure(_maybe_as_type_spec, replica_kwargs))
220    elif isinstance(function, tf_function.ConcreteFunction):
221      self._concrete_function = function
222
223    if hasattr(self, "_concrete_function"):
224      # If we have a concrete function, we get to retrieve the output type spec
225      # via the structured_output.
226      self._output_type_spec = func_graph.convert_structure_to_signature(
227          self._concrete_function.structured_outputs)
228      self._function = cancellation_mgr.get_cancelable_function(
229          self._concrete_function)
230    else:
231      # Otherwise (i.e. what is passed in is a regular python function), we have
232      # no such information.
233      self._output_type_spec = None
234      self._function = function
235
236    self._output_remote_value_ref = None
237
238  def build_output_remote_value(self):
239    if self._output_remote_value_ref is None:
240      ret = RemoteValueImpl(None, self._output_type_spec)
241      self._output_remote_value_ref = weakref.ref(ret)
242      return ret
243    else:
244      raise ValueError(
245          "The output of the Closure cannot be built more than once.")
246
247  def maybe_call_with_output_remote_value(self, method):
248    if self._output_remote_value_ref is None:
249      return None
250    output_remote_value = self._output_remote_value_ref()
251    if output_remote_value is not None:
252      return method(output_remote_value)
253    return None
254
255  def mark_cancelled(self):
256    e = errors.CancelledError(
257        None, None, "The corresponding function is "
258        "cancelled. Please reschedule the function.")
259    self.maybe_call_with_output_remote_value(lambda r: r._set_error(e))  # pylint: disable=protected-access
260
261  def execute_on(self, worker):
262    """Executes the closure on the given worker.
263
264    Args:
265      worker: a `Worker` object.
266    """
267    replica_args = _select_worker_slice(worker.worker_index, self._args)
268    replica_kwargs = _select_worker_slice(worker.worker_index, self._kwargs)
269
270    e = (
271        _get_error_from_remote_values(replica_args) or
272        _get_error_from_remote_values(replica_kwargs))
273    if e:
274      if not isinstance(e, ClosureInputError):
275        e = ClosureInputError(e)
276      raise e
277
278    with ops.device(worker.device_name):
279      with context.executor_scope(worker.executor):
280        with coordinator_context.with_dispatch_context(worker):
281          with metric_utils.monitored_timer("closure_execution"):
282            output_values = self._function(
283                *nest.map_structure(_maybe_get_remote_value, replica_args),
284                **nest.map_structure(_maybe_get_remote_value, replica_kwargs))
285    self.maybe_call_with_output_remote_value(
286        lambda r: r._set_values(output_values))  # pylint: disable=protected-access
287
288
289class ResourceClosure(Closure):
290
291  def build_output_remote_value(self):
292    if self._output_remote_value_ref is None:
293      # We need to remember the Closure object in the `RemoteValue` here.
294      ret = RemoteValueImpl(self, self._output_type_spec)
295      self._output_remote_value_ref = weakref.ref(ret)
296      return ret
297    else:
298      return self._output_remote_value_ref()
299
300
301class _CoordinatedClosureQueue(object):
302  """Manage a queue of closures, inflight count and errors from execution.
303
304  This class is thread-safe.
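
  A typical lifecycle, as a simplified sketch (method names refer to the
  methods defined below; `closure` stands for an illustrative `Closure`
  instance):

    closure_queue = _CoordinatedClosureQueue()
    closure_queue.put(closure)     # Coordinator thread enqueues work.
    closure = closure_queue.get()  # Worker thread dequeues; marks it inflight.
    # ... the worker executes the closure ...
    closure_queue.mark_finished()  # Or mark_failed(e) / put_back(closure).
    closure_queue.wait()           # Coordinator blocks until all are done.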
305  """
306
307  def __init__(self):
    # `self._inflight_closure_count` only tracks the number of inflight
    # closures that are "in generation". Once an error occurs, the error
    # generation is incremented and all closures arriving afterwards (from
    # inflight) are considered "out of generation".
312    self._inflight_closure_count = 0
313
314    self._queue_lock = threading.Lock()
315
316    # Condition indicating that all pending closures (either queued or inflight)
317    # have been processed, failed, or cancelled.
318    self._stop_waiting_condition = threading.Condition(self._queue_lock)
319
320    # Condition indicating that an item becomes available in queue (not empty).
321    self._closures_queued_condition = threading.Condition(self._queue_lock)
322    self._should_process_closures = True
323
324    # Condition indicating that a queue slot becomes available (not full).
325    # Note that even with "infinite" queue size, there is still a "practical"
326    # size limit for the queue depending on host memory capacity, and thus the
327    # queue will eventually become full with a lot of enqueued closures.
328    self._queue_free_slot_condition = threading.Condition(self._queue_lock)
329
    # Condition indicating that there are no inflight closures.
331    self._no_inflight_closure_condition = threading.Condition(self._queue_lock)
332
    # Used to cancel in-flight closures.
334    self._cancellation_mgr = cancellation.CancellationManager()
335
336    if _CLOSURE_QUEUE_MAX_SIZE <= 0:
337      logging.warning(
338          "In a `ClusterCoordinator`, creating an infinite closure queue can "
339          "consume a significant amount of memory and even lead to OOM.")
340    self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
341    self._tagged_queue = collections.defaultdict(queue.Queue)
342    self._error = None
343
    # The following is a lock to make sure that while `wait` is called and
    # before it returns, no `put` can be executed, since `wait` won't know
    # what to do with newly put closures. This lock adds a cutoff for `wait`
    # so that closures put into the queue while waiting are not the
    # responsibility of this `wait`.
    #
    # We cannot reuse the `self._queue_lock` since when `wait` waits for a
    # condition, the `self._queue_lock` will be released.
    #
    # We deliberately don't use a reader/writer lock, to reduce the complexity
    # of the code.
355    self._put_wait_lock = threading.Lock()
356
357    self._watchdog = watchdog.WatchDog(on_triggered=self._on_watchdog_timeout)
358
359  def _on_watchdog_timeout(self):
360    logging.info("inflight_closure_count is %d", self._inflight_closure_count)
361    logging.info("current error is %s:%r", self._error, self._error)
362
363  def stop(self):
364    with self._queue_lock:
365      self._should_process_closures = False
366      self._cancellation_mgr.start_cancel()
367      self._closures_queued_condition.notify_all()
368    self._watchdog.stop()
369
370  def _cancel_all_closures(self):
371    """Clears the queue and sets remaining closures cancelled error.
372
373    This method expects self._queue_lock to be held prior to entry.
374    """
375    self._cancellation_mgr.start_cancel()
376    logging.info("Canceling all closures: waiting for inflight closures to "
377                 "finish")
378    while self._inflight_closure_count > 0:
379      self._no_inflight_closure_condition.wait()
380    logging.info("Canceling all closures: canceling remaining closures on the "
381                 "queue")
382    while True:
383      try:
384        closure = self._queue.get(block=False)
385        self._queue_free_slot_condition.notify()
386        closure.mark_cancelled()
387      except queue.Empty:
388        break
389    # The cancellation manager cannot be reused once cancelled. After all
390    # closures (queued or inflight) are cleaned up, recreate the cancellation
391    # manager with clean state.
    # Note on thread-safety: this is triggered when one of these
393    # ClusterCoordinator APIs are called: `schedule`, `wait`, and `done`. At the
394    # same time, no new closures can be constructed (which reads the
395    # _cancellation_mgr to get cancellable functions).
396    self._cancellation_mgr = cancellation.CancellationManager()
397
398  def _raise_if_error(self):
399    """Raises the error if one exists.
400
    If an error exists, cancels the closures in the queue, raises the error,
    and clears it.
403
404    This method expects self._queue_lock to be held prior to entry.
405    """
406    if self._error:
407      logging.error("Start cancelling closures due to error %r: %s",
408                    self._error, self._error)
409      self._cancel_all_closures()
410      try:
411        raise self._error  # pylint: disable=raising-bad-type
412      finally:
413        self._error = None
414
415  def put(self, closure, tag=None):
416    """Put a closure into the queue for later execution.
417
418    If `mark_failed` was called before `put`, the error from the first
419    invocation of `mark_failed` will be raised.
420
421    Args:
422      closure: The `Closure` to put into the queue.
423      tag: if not None, put into a queue with the given tag.
424    """
425    closure.tag = tag
426    if tag is not None:
427      with self._queue_lock:
428        self._tagged_queue[tag].put(closure, block=False)
        self._closures_queued_condition.notify_all()
430    else:
431      with self._put_wait_lock, self._queue_lock:
432        self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
433        self._queue.put(closure, block=False)
434        self._raise_if_error()
435        self._closures_queued_condition.notify()
436
437  def get(self, timeout=None, tag=None):
438    """Return a closure from the queue to be executed.
439
440    It will try to fetch an item from the queue with the given tag. If this
441    queue is empty, it will then check the global queue.
442
443    Args:
444      timeout: timeout when waiting for a closure to be put.
445      tag: optional tag to specify which queue to query first before querying
446        the global queue.
447
448    Returns:
449      a closure or None after timeout.
450    """
451    with self._queue_lock:
452      while (self._should_process_closures and self._queue.empty() and
453             (tag is None or self._tagged_queue[tag].empty())):
454        if not self._closures_queued_condition.wait(timeout=timeout):
455          return None
456      if not self._should_process_closures:
457        return None
458      if tag is not None and not self._tagged_queue[tag].empty():
459        closure = self._tagged_queue[tag].get(block=False)
460        return closure
461      closure = self._queue.get(block=False)
462      assert closure.tag is None
463      assert tag is None or self._tagged_queue[tag].empty()
464      self._queue_free_slot_condition.notify()
465      self._inflight_closure_count += 1
466      return closure
467
468  def mark_finished(self):
469    """Let the queue know that a closure has been successfully executed."""
470    with self._queue_lock:
471      if self._inflight_closure_count < 1:
472        raise AssertionError("There is no inflight closures to mark_finished.")
473      self._inflight_closure_count -= 1
474      if self._inflight_closure_count == 0:
475        self._no_inflight_closure_condition.notify_all()
476      if self._queue.empty() and self._inflight_closure_count == 0:
477        self._stop_waiting_condition.notify_all()
478      self._watchdog.report_closure_done()
479
480  def put_back(self, closure):
481    """Put the closure back into the queue as it was not properly executed."""
482    assert closure.tag is None
483    with self._queue_lock:
484      if self._inflight_closure_count < 1:
485        raise AssertionError("There is no inflight closures to put_back.")
486      if self._error:
487        closure.mark_cancelled()
488      else:
489        self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
490        self._queue.put(closure, block=False)
491        self._closures_queued_condition.notify()
492      self._inflight_closure_count -= 1
493      if self._inflight_closure_count == 0:
494        self._no_inflight_closure_condition.notify_all()
495
496  def wait(self, timeout=None):
497    """Wait for all closures to be finished before returning.
498
499    If `mark_failed` was called before or during `wait`, the error from the
500    first invocation of `mark_failed` will be raised.
501
502    Args:
503      timeout: A float specifying a timeout for the wait in seconds.
504
505    Returns:
506      True unless the given timeout expired, in which case it returns False.
507    """
508    with self._put_wait_lock, self._queue_lock:
509      logging.info("Waiting for all global closures to be finished.")
510      while (not self._error and
511             (not self._queue.empty() or self._inflight_closure_count > 0)):
512        if not self._stop_waiting_condition.wait(timeout=timeout):
513          return False
514      self._raise_if_error()
515      return True
516
517  def mark_failed(self, e):
518    """Sets error and unblocks any wait() call."""
519    with self._queue_lock:
520      # TODO(yuefengz): maybe record all failure and give users more
521      # information?
522      if self._inflight_closure_count < 1:
523        raise AssertionError("There is no inflight closures to mark_failed.")
524      if self._error is None:
525        self._error = e
526      self._inflight_closure_count -= 1
527      if self._inflight_closure_count == 0:
528        self._no_inflight_closure_condition.notify_all()
529      self._stop_waiting_condition.notify_all()
530
531  def done(self):
532    """Returns true if the queue is empty and there is no inflight closure.
533
534    If `mark_failed` was called before `done`, the error from the first
535    invocation of `mark_failed` will be raised.
536    """
537    with self._queue_lock:
538      self._raise_if_error()
539      return self._queue.empty() and self._inflight_closure_count == 0
540
541  def clear_tag_unlocked(self, tag):
542    self._tagged_queue[tag] = queue.Queue()
543
544
545class WorkerPreemptionHandler(object):
546  """Handles worker preemptions."""
547
548  def __init__(self, server_def, cluster):
549    self._server_def = server_def
550    self._cluster = cluster
551    self._cluster_update_lock = threading.Lock()
552    self._cluster_due_for_update_or_finish = threading.Event()
553    self._worker_up_cond = threading.Condition(self._cluster_update_lock)
554    self._error_from_recovery = None
555    self._should_preemption_thread_run = True
556    self._preemption_handler_thread = threading.Thread(
557        target=self._preemption_handler,
558        name="WorkerPreemptionHandler",
559        daemon=True)
560    self._preemption_handler_thread.start()
561
562  def stop(self):
563    """Ensure the worker preemption thread is closed."""
564    self._should_preemption_thread_run = False
565    with self._cluster_update_lock:
566      self._cluster_due_for_update_or_finish.set()
    # TODO(yuefengz): The preemption handler thread shouldn't be terminated
    # asynchronously since it touches the eager context, which is a
    # process-wide singleton. The problem is that OSS unit tests will time
    # out.
570
571  def _validate_preemption_failure(self, e):
572    """Validates that the given exception represents worker preemption."""
573
574    # Only categorize the failure as a worker preemption if the cancellation
575    # manager did not attempt to cancel the blocking operations.
576    if _is_worker_failure(e) and (
577        not self._cluster.closure_queue._cancellation_mgr.is_cancelled):  # pylint: disable=protected-access
578      return
579    raise e
580
581  @contextlib.contextmanager
582  def wait_on_failure(self,
583                      on_failure_fn=None,
584                      on_transient_failure_fn=None,
585                      on_recovery_fn=None,
586                      worker_device_name="(unknown)"):
587    """Catches worker preemption error and wait until failed workers are back.
588
589    Args:
590      on_failure_fn: an optional function to run if preemption happens.
591      on_transient_failure_fn: an optional function to run if transient failure
592        happens.
593      on_recovery_fn: an optional function to run when a worker is recovered
594        from preemption.
595      worker_device_name: the device name of the worker instance that is passing
596        through the failure.
597
598    Yields:
599      None.
600    """
601    assert self._should_preemption_thread_run
602    try:
603      yield
604    except (errors.OpError, ClosureInputError,
605            ClosureAbortedError) as e:
      # If the error is due to temporary connectivity issues between worker and
      # ps, put the closure back, ignore the error, and do not mark the worker
      # as failed.
608      if self._cluster._record_and_ignore_transient_ps_failure(e):  # pylint: disable=protected-access
609        logging.error(
610            "Remote function on worker %s failed with %r:%s\n"
611            "It is treated as a transient connectivity failure for now.",
612            worker_device_name, e, e)
613        if on_transient_failure_fn:
614          on_transient_failure_fn()
615        return
616
      # If the error is due to temporary connectivity issues that cause the
      # server-side RPCs to be cancelled, TF might not abort the step and the
      # closure might time out. The coordinator ignores a certain number of
      # such failures without marking the worker as failed.
621      if self._cluster._record_and_ignore_transient_timeouts(e):  # pylint: disable=protected-access
622        logging.error(
623            "Remote function on worker %s failed with %r:%s\n"
624            "This derived error is ignored and not reported to users.",
625            worker_device_name, e, e)
626        if on_transient_failure_fn:
627          on_transient_failure_fn()
628        return
629
      # Ignore derived CancelledErrors to tolerate transient failures in
      # PS-worker communication, which initially surface as an
      # UnavailableError, then lead to sub-function cancellation, and are
      # subsequently reported from the worker to the chief as a CancelledError.
      # We do not mark either the worker or the PS as failed due to a
      # CancelledError alone. If there are real (non-transient) failures, they
      # must also be reported as other errors (UnavailableError most likely) in
      # closure executions.
637      if isinstance(e, errors.CancelledError) and "/job:" in str(e):
638        logging.error(
639            "Remote function on worker %s failed with %r:%s\n"
640            "This derived error is ignored and not reported to users.",
641            worker_device_name, e, e)
642        if on_transient_failure_fn:
643          on_transient_failure_fn()
644        return
645
      # This reraises the error if it's not considered recoverable; otherwise,
      # the following failure recovery logic runs. At this time, only worker
      # unavailability is recoverable. PS unavailability, as well as other
      # errors in the user function, is not recoverable.
650      self._validate_preemption_failure(e)
651
652      logging.error("Worker %s failed with %r:%s", worker_device_name, e, e)
653      if on_failure_fn:
654        on_failure_fn(e)
655
656      with self._cluster_update_lock:
657        self._cluster_due_for_update_or_finish.set()
658        self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC)
659        if self._error_from_recovery:
660          # TODO(yuefengz): there is only one worker that will get this error.
          # Ideally we should let all workers notified by `_worker_up_cond` get
662          # this error.
663          try:
664            raise self._error_from_recovery
665          finally:
666            self._error_from_recovery = None
667        logging.info("Worker %s has been recovered.", worker_device_name)
668
669      if on_recovery_fn:
670        logging.info("Worker %s calling on_recovery_fn", worker_device_name)
671        with self.wait_on_failure(
672            on_recovery_fn=on_recovery_fn,
673            on_transient_failure_fn=on_transient_failure_fn,
674            worker_device_name=worker_device_name):
675          on_recovery_fn()
676
677  def _preemption_handler(self):
678    """A loop that handles preemption.
679
680    This loop waits for signal of worker preemption and upon worker preemption,
681    it waits until all workers are back and updates the cluster about the
682    restarted workers.
683    """
684    assert self._should_preemption_thread_run
685    while True:
686      self._cluster_due_for_update_or_finish.wait()
687      if not self._should_preemption_thread_run:
688        logging.info("Stopping the failure handing thread.")
689        break
690
691      with self._cluster_update_lock:
692        try:
693          # TODO(haoyuzhang): support partial cluster recovery
694          logging.info("Cluster now being recovered.")
695          context.context().update_server_def(self._server_def)
696
697          # Cluster updated successfully, clear the update signal, and notify
698          # all workers that they are recovered from failure.
699          logging.info("Cluster successfully recovered.")
700          self._worker_up_cond.notify_all()
701          # The check for _should_preemption_thread_run is necessary since the
702          # `stop` may have already set _cluster_due_for_update_or_finish.
703          if self._should_preemption_thread_run:
704            self._cluster_due_for_update_or_finish.clear()
705        except Exception as e:  # pylint: disable=broad-except
706          logging.info("Error occurred while updating server def: %s", e)
707          try:
708            self._validate_preemption_failure(e)
709          except Exception as ps_e:  # pylint: disable=broad-except
710            logging.info("Error that occurred while updating server def is not "
711                         "a worker failure. So set it as _error_from_recovery")
712            # In this case, a parameter server fails. So we raise this error to
713            # the caller of `wait_on_failure`.
714            self._error_from_recovery = ps_e
715            self._worker_up_cond.notify_all()
716            if self._should_preemption_thread_run:
717              self._cluster_due_for_update_or_finish.clear()
718          # NOTE: Since the first RPC (GetStatus) of update_server_def is
719          # currently blocking by default, error should only happen if:
720          # (1) More workers failed while waiting for the previous workers to
721          #     come back;
722          # (2) Worker failed when exchanging subsequent RPCs after the first
723          #     RPC returns.
724          # Consider adding backoff retry logic if we see the error logged
725          # too frequently.
726          logging.error("Cluster update failed with error: %s. Retrying...", e)
727
728
729class Worker(object):
730  """A worker in a cluster.
731
732  Attributes:
733    worker_index: The index of the worker in the cluster.
734    device_name: The device string of the worker, e.g. "/job:worker/task:1".
735    executor: The worker's executor for remote function execution.
    failure_handler: The failure handler used to handle worker preemption
      failures.
738  """
739
740  def __init__(self, worker_index, device_name, cluster):
741    self.worker_index = worker_index
742    self.device_name = device_name
743    self.executor = executor.new_executor(enable_async=False)
744    self.failure_handler = cluster.failure_handler
745    self._cluster = cluster
746    self._resource_tracking_lock = threading.Lock()
747    self._resource_remote_value_refs = []
748    self._is_dead_with_error = None
749    self._should_worker_thread_run = True
750
751    # Worker threads need to start after `Worker`'s initialization.
752    threading.Thread(target=self._process_queue,
753                     name="WorkerClosureProcessingLoop-%d" % self.worker_index,
754                     daemon=True).start()
755
756  def stop(self):
757    """Ensure the worker thread is closed."""
758    self._should_worker_thread_run = False
759
760  def _schedule_resource(self, closure):
761    self._cluster.closure_queue.put(closure, tag=self.worker_index)
762
763  def _set_resources_aborted(self, e):
764    """Set the resource ABORTED and add an error to it."""
765    # TODO(yuefengz): maybe we can query whether a tensor is valid or not
766    # instead of marking a tensor aborted?
767    logging.info("[Worker %d] Clearing all resources.", self.worker_index)
768    for weakref_resource in self._resource_remote_value_refs:
769      resource = weakref_resource()
770      if resource:
        # It is important to set an error on an aborted RemoteValue from a
        # ResourceClosure because its failure will not trigger the worker
        # thread to raise an error immediately, and the worker may continue
        # executing closures that take it as an input. The error will then be
        # correctly reported to users.
776        resource._set_aborted(ClosureAbortedError(e))  # pylint: disable=protected-access
777
778  def _on_closure_failure(self, closure, e):
779    logging.info("[Worker %d] Putting back a closure after it failed.",
780                 self.worker_index)
781    self._cluster.closure_queue.put_back(closure)
782
783    with self._resource_tracking_lock:
784      self._is_dead_with_error = e
785      self._set_resources_aborted(e)
786
787  def _on_resource_closure_failure(self, e):
788    """Clear tagged queue to ensure resource closures are rebuilt.
789
790    Args:
791      e: The exception arisen from the resource closure.
792    """
793    logging.info("[Worker %d] Clearing tagged queue after resource closure "
794                 "failure.", self.worker_index)
795    with self._resource_tracking_lock:
796      self._is_dead_with_error = e
797      # No locking on queue is needed since
798      #  * get will not happen concurrently here.
799      #  * put to the specific tagged queue will be guarded by
800      #    `self._resource_tracking_lock`.
801      self._cluster.closure_queue.clear_tag_unlocked(self.worker_index)
802      self._set_resources_aborted(e)
803
804  def _on_worker_recovery(self):
805    logging.info("[Worker %d] calling _on_worker_recovery", self.worker_index)
806    with self._resource_tracking_lock:
807      for weakref_resource in self._resource_remote_value_refs:
808        resource = weakref_resource()
809        if resource:
810          self._schedule_resource(resource._closure)  # pylint: disable=protected-access
811      self._is_dead_with_error = False
812
813  def _process_closure(self, closure):
814    """Runs a closure with preemption handling."""
815    try:
816      with self.failure_handler.wait_on_failure(
817          on_failure_fn=lambda e: self._on_closure_failure(closure, e),
818          on_transient_failure_fn=(
819              lambda: self._cluster.closure_queue.put_back(closure)),
820          on_recovery_fn=self._on_worker_recovery,
821          worker_device_name=self.device_name):
822        closure.execute_on(self)
823        with metric_utils.monitored_timer("remote_value_fetch"):
824          # Copy the remote tensor to local (the coordinator) in case worker
825          # becomes unavailable at a later time.
826          closure.maybe_call_with_output_remote_value(lambda r: r.get())
827        self._cluster.closure_queue.mark_finished()
828    except Exception as e:  # pylint: disable=broad-except
829      # Avoid logging the derived cancellation error
830      if not isinstance(e, errors.CancelledError):
831        logging.error(
832            " /job:worker/task:%d encountered the following error when "
833            "processing closure: %r:%s", self.worker_index, e, e)
834      closure.maybe_call_with_output_remote_value(lambda r: r._set_error(e))  # pylint: disable=protected-access
835      self._cluster.closure_queue.mark_failed(e)
836
837  def _process_resource_closure(self, closure):
838    """Run the given resource closure with preemption handling."""
839    assert closure.tag == self.worker_index
840    try:
841      with self.failure_handler.wait_on_failure(
842          on_failure_fn=self._on_resource_closure_failure,
843          on_transient_failure_fn=(
844              lambda: self._process_resource_closure(closure)),
845          on_recovery_fn=self._on_worker_recovery,
846          worker_device_name=self.device_name):
847        closure.execute_on(self)
848    except Exception as e:  # pylint: disable=broad-except
849      # Avoid logging the derived cancellation error
850      logging.info("[Worker %d] got an exception when processing resource "
851                   "closure", self.worker_index)
852      if not isinstance(e, errors.CancelledError):
853        logging.error(
854            " /job:worker/task:%d encountered the following error when "
855            "processing resource closure: %r:%s", self.worker_index, e, e)
856      closure.maybe_call_with_output_remote_value(lambda r: r._set_error(e))  # pylint: disable=protected-access
857
858  def _maybe_delay(self):
859    """Delay if corresponding env vars are set."""
    # If the following two environment variables are set, scheduling for
    # workers will start in a staggered manner. Worker i will wait for
    # `TF_COORDINATOR_SCHEDULE_START_DELAY` * i seconds, not exceeding
    # `TF_COORDINATOR_SCHEDULE_START_DELAY_MAX`.
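    # For example (illustrative values): with
    # TF_COORDINATOR_SCHEDULE_START_DELAY=2 and
    # TF_COORDINATOR_SCHEDULE_START_DELAY_MAX=10, worker 3 waits
    # min(2 * 3, 10) = 6 seconds and worker 7 waits min(2 * 7, 10) = 10
    # seconds before processing its first closure.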
864    delay_secs = int(os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY", "0"))
865    delay_secs *= self.worker_index
866    delay_cap = int(
867        os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY_MAX", "0"))
868    if delay_cap:
869      delay_secs = min(delay_secs, delay_cap)
870    if delay_secs > 0:
871      logging.info(" Worker %d sleeping for %d seconds before running function",
872                   self.worker_index, delay_secs)
873    time.sleep(delay_secs)
874
875  def _process_queue(self):
876    """Function running in a worker thread to process closure queues."""
877    self._maybe_delay()
878    while self._should_worker_thread_run:
879      closure = self._cluster.closure_queue.get(tag=self.worker_index)
880      if not self._should_worker_thread_run or closure is None:
881        if closure is not None:
882          closure.mark_cancelled()
883        return
884      if isinstance(closure, ResourceClosure):
885        self._process_resource_closure(closure)
886      else:
887        self._process_closure(closure)
      # To properly stop the worker and preemption threads, it is important
      # that the `ClusterCoordinator` object is not held onto, so that its
      # `__del__` can be called. By removing the reference to the `closure`
      # that has already been processed, we ensure that the `closure` object
      # is released while waiting for the next `closure` at the
      # `self._cluster.closure_queue.get()` call above.
894      del closure
895
896  def create_resource(self, function, args=None, kwargs=None):
897    """Synchronously creates a per-worker resource represented by a `RemoteValue`.
898
899    Args:
900      function: the resource function to be run remotely. It should be a
901        `tf.function`, a concrete function or a Python function.
902      args: positional arguments to be passed to the function.
903      kwargs: keyword arguments to be passed to the function.
904
905    Returns:
906      one or several RemoteValue objects depending on the function return
907      values.
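
    For example (an illustrative sketch; `build_lookup_table` stands in for a
    resource-building `tf.function` and `worker` for a `Worker` instance):

      table_remote_value = worker.create_resource(build_lookup_table)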
908    """
909    # Some notes about the concurrency: currently all the activities related to
910    # the same worker such as creating resources, setting resources' aborted
911    # status, and executing closures happen on the same thread. This allows us
912    # to have simpler logic of concurrency.
913
914    closure = ResourceClosure(
915        function,
916        self._cluster.resource_cancellation_mgr,
917        args=args,
918        kwargs=kwargs)
919    resource_remote_value = closure.build_output_remote_value()
920    with self._resource_tracking_lock:
921      self._register_resource(resource_remote_value)
922      if self._is_dead_with_error:
923        resource_remote_value._set_aborted(  # pylint: disable=protected-access
924            ClosureAbortedError(self._is_dead_with_error))
925      else:
926        self._schedule_resource(closure)
927    return resource_remote_value
928
929  def _register_resource(self, resource_remote_value):
930    if not isinstance(resource_remote_value, RemoteValue):
931      raise ValueError("Resource being registered is not of type "
932                       "`tf.distribute.experimental.coordinator.RemoteValue`.")
933    self._resource_remote_value_refs.append(weakref.ref(resource_remote_value))
934
935
936class Cluster(object):
937  """A cluster with workers.
938
939  We assume all function errors are fatal and based on this assumption our
940  error reporting logic is:
941  1) Both `schedule` and `join` can raise a non-retryable error which is the
942  first error seen by the coordinator from any previously scheduled functions.
943  2) When an error is raised, there is no guarantee on how many previously
944  scheduled functions have been executed; functions that have not been executed
945  will be thrown away and marked as cancelled.
946  3) After an error is raised, the internal state of error will be cleared.
947  I.e. functions can continue to be scheduled and subsequent calls of `schedule`
948  or `join` will not raise the same error again.
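
  For example, a rough sketch of these semantics (`fn` is assumed to be a
  `tf.function` defined elsewhere):

    cluster.schedule(fn, args=(), kwargs={})
    try:
      cluster.join()
    except Exception:  # First error from any previously scheduled function.
      # The stored error has been cleared; `schedule` and `join` can be used
      # again without re-raising it.
      ...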
949
950  Attributes:
    failure_handler: The failure handler used to handle worker preemption
      failures.
953    workers: a list of `Worker` objects in the cluster.
954    closure_queue: the global Closure queue.
955    resource_cancellation_mgr: the cancellation manager used to cancel resource
956      closures.
957  """
958
959  def __init__(self, strategy):
960    """Initializes the cluster instance."""
961
962    self._num_workers = strategy._num_workers
963    self._num_ps = strategy._num_ps
964
    # Ignore PS failures reported by workers due to transient connection
    # errors. Transient connectivity issues between workers and PS are relayed
    # by the workers to the coordinator, leading the coordinator to believe
    # that there are PS failures. The difference between a transient and a
    # permanent PS failure is the number of reports from the workers. When
    # this env var is set to a positive integer K, the coordinator ignores
    # reports of a failed PS task until the number of closure executions
    # failing due to errors from the same PS instance reaches K.
    # TODO(b/164279603): Remove this workaround when the underlying
    # connectivity issue in gRPC server is resolved.
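    # For example (an illustrative sketch): with
    # TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES=3, the first couple of
    # UnavailableErrors attributed to the same PS task are ignored; once the
    # per-task report count reaches the threshold, the failure is no longer
    # ignored and is surfaced to the user.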
976    self._transient_ps_failures_threshold = int(
977        os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3))
978    self._potential_ps_failures_lock = threading.Lock()
979    self._potential_ps_failures_count = [0] * self._num_ps
980
    # Ignore worker timeouts due to transient connection errors.
    # Transient connectivity issues might cause the server side to unexpectedly
    # cancel RPC handling logic, leading to closure execution timeouts. When
    # the _transient_timeouts_threshold is set to a positive number, the
    # cluster coordinator ignores DeadlineExceeded errors from workers a
    # specified number of times before raising the error to users.
987    self._transient_timeouts_threshold = int(
988        os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_TIMEOUTS",
989                       self._num_workers // 10))
990    self._transient_timeouts_lock = threading.Lock()
991    self._transient_timeouts_count = 0
992
993    self.closure_queue = _CoordinatedClosureQueue()
994    self.failure_handler = WorkerPreemptionHandler(context.get_server_def(),
995                                                   self)
996    worker_device_strings = [
997        "/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
998    ]
999    self.workers = [
1000        Worker(i, w, self) for i, w in enumerate(worker_device_strings)
1001    ]
1002
1003    # Cancellation manager for all resource closures.
1004    self.resource_cancellation_mgr = cancellation.CancellationManager()
1005
1006  def stop(self):
1007    """Stop worker, worker preemption threads, and the closure queue."""
1008    logging.info("Stopping cluster, starting with failure handler")
1009    self.failure_handler.stop()
1010
1011    logging.info("Stopping workers")
1012    for worker in self.workers:
1013      worker.stop()
1014    logging.info("Stopping queue")
1015    self.closure_queue.stop()
1016    logging.info("Start cancelling remote resource-building functions")
1017    self.resource_cancellation_mgr.start_cancel()
1018
1019  def _record_and_ignore_transient_ps_failure(self, e):
1020    """Records potential PS failures and return if failure should be ignored."""
1021    if self._transient_ps_failures_threshold <= 0 or not _is_ps_failure(e):
1022      return False
1023
1024    ps_tasks = _extract_failed_ps_instances(str(e))
1025    with self._potential_ps_failures_lock:
1026      for t in ps_tasks:
1027        self._potential_ps_failures_count[t] += 1
        # The number of UnavailableErrors encountered on this PS task exceeds
        # the maximum number of ignored errors.
1030        if (self._potential_ps_failures_count[t] >=
1031            self._transient_ps_failures_threshold):
1032          return False
1033    return True
1034
1035  def _record_and_ignore_transient_timeouts(self, e):
1036    """Records observed timeout error and return if it should be ignored."""
1037    if self._transient_timeouts_threshold <= 0:
1038      return False
1039    if not isinstance(e, errors.DeadlineExceededError):
1040      return False
1041    with self._transient_timeouts_lock:
1042      self._transient_timeouts_count += 1
1043      if self._transient_timeouts_count >= self._transient_timeouts_threshold:
1044        return False
1045    return True
1046
1047  def schedule(self, function, args, kwargs):
1048    """Schedules `function` to be dispatched to a worker for execution.
1049
1050    Args:
      function: The function to be dispatched to a worker for execution
        asynchronously.
      args: Positional arguments for `function`.
      kwargs: Keyword arguments for `function`.
1055
1056    Returns:
1057      A `RemoteValue` object.
1058    """
1059    closure = Closure(
1060        function,
1061        self.closure_queue._cancellation_mgr,  # pylint: disable=protected-access
1062        args=args,
1063        kwargs=kwargs)
1064    ret = closure.build_output_remote_value()
1065    self.closure_queue.put(closure)
1066    return ret
1067
1068  def join(self):
1069    """Blocks until all scheduled functions are executed."""
1070    self.closure_queue.wait()
1071
1072  def done(self):
1073    """Returns true if all scheduled functions are executed."""
1074    return self.closure_queue.done()
1075
1076
1077@tf_export("distribute.experimental.coordinator.ClusterCoordinator",
1078           "distribute.coordinator.ClusterCoordinator", v1=[])
1079class ClusterCoordinator(object):
1080  """An object to schedule and coordinate remote function execution.
1081
1082  This class is used to create fault-tolerant resources and dispatch functions
1083  to remote TensorFlow servers.
1084
  Currently, this class is not supported to be used in a standalone manner. It
  should be used in conjunction with a `tf.distribute` strategy that is
  designed to work with it. The `ClusterCoordinator` class currently only works
  with `tf.distribute.experimental.ParameterServerStrategy`.
1089
1090  __The `schedule`/`join` APIs__
1091
  The most important APIs provided by this class are the `schedule`/`join`
  pair. The `schedule` API is non-blocking in that it queues a `tf.function`
  and returns a `RemoteValue` immediately. The queued functions will be
  dispatched to remote workers in background threads and their `RemoteValue`s
  will be filled asynchronously. Since `schedule` doesn't require worker
  assignment, the `tf.function` passed in can be executed on any available
  worker. If the worker it is executed on becomes unavailable before its
  completion, it will be migrated to another worker. Because of this fact, and
  because function execution is not atomic, a function may be executed more
  than once.
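
  For example, a minimal sketch of this workflow (assuming a
  `tf.distribute.experimental.ParameterServerStrategy` instance named
  `strategy` has already been created):

    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy)

    @tf.function
    def worker_fn():
      return tf.constant(1.0)

    remote_value = coordinator.schedule(worker_fn)
    coordinator.join()  # Block until all scheduled functions finish.
    result = remote_value.fetch()  # Retrieve the function's output.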
1101
1102  __Handling Task Failure__
1103
  This class, when used with
  `tf.distribute.experimental.ParameterServerStrategy`, comes with built-in
  fault tolerance for worker failures. That is, when some workers are not
  reachable from the coordinator for any reason, training progress continues
  to be made with the remaining workers. Upon recovery of a failed worker, it
  will be added back for function execution after datasets created by
  `create_per_worker_dataset` are re-built on it.
1111
1112  When a parameter server fails, a `tf.errors.UnavailableError` is raised by
1113  `schedule`, `join` or `done`. In this case, in addition to bringing back the
1114  failed parameter server, users should restart the coordinator so that it
1115  reconnects to workers and parameter servers, re-creates the variables, and
1116  loads checkpoints. If the coordinator fails, after the user brings it back,
1117  the program will automatically connect to workers and parameter servers, and
1118  continue the progress from a checkpoint.
1119
  It is thus essential that in the user's program, a checkpoint file is
  periodically saved, and restored at the start of the program. If a
  `tf.keras.optimizers.Optimizer` is checkpointed, after restoring from a
  checkpoint, its `iterations` property roughly indicates the number of steps
  that have been made. This can be used to decide how many epochs and steps are
  needed before training completes.
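
  For example, a rough sketch of periodic checkpointing (assuming `model` and
  `optimizer` are the objects being trained, and using an illustrative
  checkpoint directory):

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    manager = tf.train.CheckpointManager(
        checkpoint, directory="/tmp/ckpt", max_to_keep=3)
    checkpoint.restore(manager.latest_checkpoint)  # Resumes if one exists.
    ...
    manager.save()  # Called periodically during the training loop.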
1126
1127  See `tf.distribute.experimental.ParameterServerStrategy` docstring for an
1128  example usage of this API.
1129
1130  This is currently under development, and the API as well as implementation
1131  are subject to changes.
1132  """
1133
1134  def __new__(cls, strategy):
1135    # `ClusterCoordinator` is kept as a single instance to a given `Strategy`.
1136    # TODO(rchao): Needs a lock for thread-safety
1137    if strategy._cluster_coordinator is None:
1138      strategy._cluster_coordinator = super(
1139          ClusterCoordinator, cls).__new__(cls)
1140    return strategy._cluster_coordinator
1141
1142  def __init__(self, strategy):
1143    """Initialization of a `ClusterCoordinator` instance.
1144
1145    Args:
1146      strategy: a supported `tf.distribute.Strategy` object. Currently, only
1147        `tf.distribute.experimental.ParameterServerStrategy` is supported.
1148
1149    Raises:
1150      ValueError: if the strategy being used is not supported.
1151    """
1152    if not getattr(self, "_has_initialized", False):
1153      if not isinstance(strategy,
1154                        parameter_server_strategy_v2.ParameterServerStrategyV2):
1155        raise ValueError(
1156            "Only `tf.distribute.experimental.ParameterServerStrategy` "
1157            "is supported to work with "
1158            "`tf.distribute.experimental.coordinator.ClusterCoordinator` "
1159            "currently.")
1160      self._strategy = strategy
1161      self.strategy.extended._used_with_coordinator = True
1162      self._cluster = Cluster(strategy)
1163      self._has_initialized = True
1164
1165  def __del__(self):
1166    logging.info("ClusterCoordinator destructor: stopping cluster")
1167    self._cluster.stop()
1168
1169  @property
1170  def strategy(self):
1171    """Returns the `Strategy` associated with the `ClusterCoordinator`."""
1172    return self._strategy
1173
1174  def schedule(self, fn, args=None, kwargs=None):
1175    """Schedules `fn` to be dispatched to a worker for asynchronous execution.
1176
1177    This method is non-blocking in that it queues the `fn` which will be
1178    executed later and returns a
1179    `tf.distribute.experimental.coordinator.RemoteValue` object immediately.
1180    `fetch` can be called on it to wait for the function execution to finish
1181    and retrieve its output from a remote worker. On the other hand, call
1182    `tf.distribute.experimental.coordinator.ClusterCoordinator.join` to wait for
1183    all scheduled functions to finish.
1184
1185    `schedule` guarantees that `fn` will be executed on a worker at least once;
1186    it could be more than once if its corresponding worker fails in the middle
    of its execution. Note that since a worker can fail at any point when
1188    executing the function, it is possible that the function is partially
1189    executed, but `tf.distribute.experimental.coordinator.ClusterCoordinator`
1190    guarantees that in those events, the function will eventually be executed on
1191    any worker that is available.
1192
    If any previously scheduled function raises an error, `schedule` will raise
    any one of those errors, and clear the errors collected so far. In this
    case, some of the previously scheduled functions may not have been
    executed. Users can call `fetch` on the returned
    `tf.distribute.experimental.coordinator.RemoteValue` to inspect if they
    have executed, failed, or been cancelled, and reschedule the corresponding
    functions if needed.
1200
1201    When `schedule` raises, it guarantees that there is no function that is
1202    still being executed.
1203
    At this time, there is no support for worker assignment for function
    execution, or for prioritizing certain workers.
1206
1207    `args` and `kwargs` are the arguments passed into `fn`, when `fn` is
1208    executed on a worker. They can be
1209    `tf.distribute.experimental.coordinator.PerWorkerValues` and in this case,
1210    the argument will be substituted with the corresponding component on the
1211    target worker. Arguments that are not
1212    `tf.distribute.experimental.coordinator.PerWorkerValues` will be passed into
    `fn` as-is. Currently, `tf.distribute.experimental.coordinator.RemoteValue`
    is not supported as input `args` or `kwargs`.
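
    For example, a short sketch of passing a per-worker iterator (assuming
    `per_worker_dataset` was returned by `create_per_worker_dataset` and
    `coordinator` is this `ClusterCoordinator`):

      @tf.function
      def step_fn(iterator):
        return next(iterator)

      per_worker_iter = iter(per_worker_dataset)
      remote_value = coordinator.schedule(step_fn, args=(per_worker_iter,))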
1215
1216    Args:
      fn: A `tf.function`; the function to be dispatched to a worker for
        asynchronous execution. Regular Python functions are not supported for
        scheduling.
1220      args: Positional arguments for `fn`.
1221      kwargs: Keyword arguments for `fn`.
1222
1223    Returns:
1224      A `tf.distribute.experimental.coordinator.RemoteValue` object that
1225      represents the output of the function scheduled.
1226
1227    Raises:
1228      Exception: one of the exceptions caught by the coordinator from any
1229        previously scheduled function, since the last time an error was thrown
1230        or since the beginning of the program.
1231    """
1232    if not isinstance(fn,
1233                      (def_function.Function, tf_function.ConcreteFunction)):
1234      raise TypeError(
1235          "`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`"
1236          " only accepts a `tf.function` or a concrete function.")
1237    # Slot variables are usually created during function tracing time; thus
1238    # `schedule` needs to be called within the `strategy.scope()`.
1239    with self.strategy.scope():
1240      self.strategy.extended._being_scheduled = True  # pylint: disable=protected-access
1241      remote_value = self._cluster.schedule(fn, args=args, kwargs=kwargs)
1242      self.strategy.extended._being_scheduled = False  # pylint: disable=protected-access
1243      return remote_value
1244
1245  def join(self):
1246    """Blocks until all the scheduled functions have finished execution.
1247
1248    If any previously scheduled function raises an error, `join` will fail by
1249    raising any one of those errors, and clear the errors collected so far. If
1250    this happens, some of the previously scheduled functions may not have been
1251    executed. Users can call `fetch` on the returned
1252    `tf.distribute.experimental.coordinator.RemoteValue` to inspect whether they
1253    have executed, failed, or been cancelled. If any cancelled ones need to be
1254    rescheduled, users should call `schedule` with the function again.
1255
1256    When `join` returns or raises, it guarantees that there is no function that
1257    is still being executed.
1258
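    For example, a common pattern is to schedule a number of steps and call
    `join` at the end of each epoch (an illustrative sketch; `coordinator`,
    `worker_fn` and `per_worker_iter` are assumed to be set up as in the
    `create_per_worker_dataset` example below):

    ```python
    for epoch in range(3):
      for step in range(100):
        coordinator.schedule(worker_fn, args=(per_worker_iter,))
      # Block until every function scheduled in this epoch has finished, and
      # surface any error raised by a previously scheduled function.
      coordinator.join()
    ```
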
1259    Raises:
1260      Exception: one of the exceptions caught by the coordinator from any
1261        previously scheduled function since the last time an error was thrown or
1262        since the beginning of the program.
1263    """
1264    self._cluster.join()
1265
1266  def done(self):
1267    """Returns whether all the scheduled functions have finished execution.
1268
1269    If any previously scheduled function raises an error, `done` will fail by
1270    raising any one of those errors.
1271
1272    When `done` returns True or raises, it guarantees that there is no function
1273    that is still being executed.
1274
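    For example, `done` can be polled while the coordinator does other work
    (an illustrative sketch; `coordinator`, `worker_fn` and `per_worker_iter`
    are assumed to be set up as in the `create_per_worker_dataset` example
    below):

    ```python
    import time

    for _ in range(100):
      coordinator.schedule(worker_fn, args=(per_worker_iter,))
    while not coordinator.done():
      # The coordinator is free to do other work here, e.g. logging or saving
      # a checkpoint, while the workers execute the scheduled functions.
      time.sleep(10)
    ```
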
1275    Returns:
1276      Whether all the scheduled functions have finished execution.
1277    Raises:
1278      Exception: one of the exceptions caught by the coordinator from any
1279        previously scheduled function since the last time an error was thrown or
1280        since the beginning of the program.
1281    """
1282    return self._cluster.done()
1283
1284  def create_per_worker_dataset(self, dataset_fn):
1285    """Create dataset on each worker.
1286
1287    This creates a dataset on the workers from the input, which can be either a
1288    `tf.data.Dataset`, a `tf.distribute.DistributedDataset` or a function which
1289    returns a dataset, and returns an object that represents the collection of
1290    those individual datasets. Calling `iter` on such a collection of datasets
1291    returns a `tf.distribute.experimental.coordinator.PerWorkerValues`, which is
1292    a collection of iterators, where the iterators have been placed on the
1293    respective workers.
1294
1295    Calling `next` on a `PerWorkerValues` of iterators is unsupported. The
1296    iterator is meant to be passed as an argument into
1297    `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`. When
1298    the scheduled function is about to be executed by a worker, the
1299    function will receive the individual iterator that corresponds to the
1300    worker. The `next` method can be called on an iterator inside a
1301    scheduled function when the iterator is an input of the function.
1302
1303    Currently the `schedule` method assumes workers are all the same and thus
1304    assumes the datasets on different workers are the same, except they may be
1305    shuffled differently if they contain a `dataset.shuffle` operation and a
1306    random seed is not set. Because of this, we also recommend repeating the
1307    datasets indefinitely and scheduling a finite number of steps instead of
1308    relying on the `OutOfRangeError` from a dataset.
1309
1310
1311    Example:
1312
1313    ```python
1314    strategy = tf.distribute.experimental.ParameterServerStrategy(
1315        cluster_resolver=...)
1316    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
1317        strategy=strategy)
1318
1319    @tf.function
1320    def worker_fn(iterator):
1321      return next(iterator)
1322
1323    def per_worker_dataset_fn():
1324      return strategy.distribute_datasets_from_function(
1325          lambda x: tf.data.Dataset.from_tensor_slices([3] * 3))
1326
1327    per_worker_dataset = coordinator.create_per_worker_dataset(
1328        per_worker_dataset_fn)
1329    per_worker_iter = iter(per_worker_dataset)
1330    remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,))
1331    assert remote_value.fetch() == 3
1332    ```
1333
1334    Args:
1335      dataset_fn: The dataset function that returns a dataset. This is to be
1336        executed on the workers.
1337
1338    Returns:
1339      An object that represents the collection of those individual
1340      datasets. `iter` is expected to be called on this object, which returns
1341      a `tf.distribute.experimental.coordinator.PerWorkerValues` of the
1342      iterators (that are placed on the workers).
1343    """
1344    return values_lib.get_per_worker_dataset(dataset_fn, self)
1345
1346  def _create_per_worker_resources(self, fn, args=None, kwargs=None):
1347    """Synchronously create resources on the workers.
1348
1349    The resources are represented by
1350    `tf.distribute.experimental.coordinator.RemoteValue`s.
1351
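    For illustration only (a minimal sketch; `coordinator` is an existing
    `ClusterCoordinator` and `build_dataset` is a hypothetical helper), a
    resource, here a dataset, is built once per worker. Users would normally
    rely on the public `create_per_worker_dataset` API instead:

    ```python
    def build_dataset():
      # Runs once per worker; every worker gets its own dataset instance.
      return tf.data.Dataset.from_tensor_slices([0, 1, 2]).repeat()

    per_worker_datasets = coordinator._create_per_worker_resources(
        build_dataset)
    ```
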
1352    Args:
1353      fn: The function to be dispatched to all workers for execution
1354        asynchronously.
1355      args: Positional arguments for `fn`.
1356      kwargs: Keyword arguments for `fn`.
1357
1358    Returns:
1359      A `tf.distribute.experimental.coordinator.PerWorkerValues` object, which
1360      wraps a tuple of `tf.distribute.experimental.coordinator.RemoteValue`
1361      objects.
1362    """
1363    results = []
1364    for w in self._cluster.workers:
1365      results.append(w.create_resource(fn, args=args, kwargs=kwargs))
1366    return PerWorkerValues(tuple(results))
1367
1368  def fetch(self, val):
1369    """Blocking call to fetch results from the remote values.
1370
1371    This is a wrapper around
1372    `tf.distribute.experimental.coordinator.RemoteValue.fetch` for a
1373    `RemoteValue` structure; it returns the execution results of
1374    `RemoteValue`s. If they are not ready, the call blocks until they are.
1375
1376    Example:
1377    ```python
1378    strategy = ...
1379    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
1380        strategy)
1381
1382    def dataset_fn():
1383      return tf.data.Dataset.from_tensor_slices([1, 1, 1])
1384
1385    with strategy.scope():
1386      v = tf.Variable(initial_value=0)
1387
1388    @tf.function
1389    def worker_fn(iterator):
1390      def replica_fn(x):
1391        v.assign_add(x)
1392        return v.read_value()
1393      return strategy.run(replica_fn, args=(next(iterator),))
1394
1395    distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
1396    distributed_iterator = iter(distributed_dataset)
1397    result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
1398    assert coordinator.fetch(result) == 1
1399    ```
1400
1401    Args:
1402      val: The value to fetch the results from. If this is a structure of
1403        `tf.distribute.experimental.coordinator.RemoteValue`, `fetch()` will be
1404        called on the individual
1405        `tf.distribute.experimental.coordinator.RemoteValue` to get the result.
1406
1407    Returns:
1408      If `val` is a `tf.distribute.experimental.coordinator.RemoteValue` or a
1409      structure of `tf.distribute.experimental.coordinator.RemoteValue`s, this
1410      returns the fetched `tf.distribute.experimental.coordinator.RemoteValue`
1411      values immediately if they are available, or blocks the call until they
1412      become available, and then returns the fetched values with the same
1413      structure as `val`. If `val` is of any other type, it is returned
1414      as-is.
1415    """
1416
1417    def _maybe_fetch(val):
1418      if isinstance(val, RemoteValue):
1419        return val.fetch()
1420      else:
1421        return val
1422
1423    # TODO(yuefengz): we should fetch values in a batch.
1424    return nest.map_structure(_maybe_fetch, val)
1425
1426
1427def _extract_failed_ps_instances(err_msg):
1428  """Return a set of potentially failing ps instances from error message."""
1429  tasks = re.findall("/job:ps/replica:0/task:[0-9]+", err_msg)
1430  return set(int(t.split(":")[-1]) for t in tasks)
1431
1432
1433def _is_ps_failure(error):
1434  """Whether the error is considered a parameter server failure."""
1435
1436  # For a `ClosureInputError` or `ClosureAbortedError`, extract
1437  # the original error and assess it accordingly.
1438  if isinstance(error, (ClosureInputError, ClosureAbortedError)):
1439    error = error.original_exception
1440
1441  if _RPC_ERROR_FROM_PS not in str(error):
1442    return False
1443
1444  if isinstance(error, (errors.UnavailableError, errors.AbortedError)):
1445    return True
1446
1447  # The following error could happen when the remote task fails and restarts
1448  # in a very short interval during which no RPCs were exchanged to detect the
1449  # failure. In that case, gRPC allows a channel (which is different from a
1450  # connection) to be reused for a replaced server at the same address.
1451  if isinstance(error, errors.InvalidArgumentError):
1452    if ("unknown device" in str(error).lower() or
1453        "Unable to find the relevant tensor remote_handle" in str(error)):
1454      return True
1455
1456  return False
1457
1458
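# This behavior is enabled by setting the environment variable
# `TF_PS_HANDLE_UNKNOWN_ERROR` to a positive integer in the coordinator's
# environment, e.g. (illustrative) `TF_PS_HANDLE_UNKNOWN_ERROR=1`.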
1459def _handle_graph_execution_error_as_worker_failure():
1460  return int(os.environ.get("TF_PS_HANDLE_UNKNOWN_ERROR", "0")) > 0
1461
1462
1463def _is_worker_failure(error):
1464  """Whether the error is considered a worker failure."""
1465
1466  # TODO(b/216666282): Understand why worker failure can manifest as a
1467  # "Graph execution error" `UnknownError`.
1468  if (_handle_graph_execution_error_as_worker_failure() and
1469      isinstance(error, errors.UnknownError) and
1470      "Graph execution error" in str(error)):
1471    logging.info(f"Handling {type(error)}: {str(error)} as worker failure.")
1472    return True
1473
1474  # For a `ClosureInputError` or `ClosureAbortedError`, extract
1475  # the original error and assess it accordingly.
1476  if isinstance(error, (ClosureInputError, ClosureAbortedError)):
1477    error = error.original_exception
1478
1479  if _JOB_WORKER_STRING_IDENTIFIER not in str(error):
1480    return False
1481  if _RPC_ERROR_FROM_PS in str(error):
1482    return False
1483
1484  # TODO(haoyuzhang): Consider using special status code if error from a
1485  # remote is derived from RPC errors originated from other hosts.
1486  if isinstance(error, (errors.UnavailableError, errors.AbortedError)):
1487    return True
1488
1489  # The following error could happen when the remote task fails and restarts
1490  # in a very short interval during which no RPCs were exchanged to detect the
1491  # failure. In that case, gRPC allows a channel (which is different from a
1492  # connection) to be reused for a replaced server at the same address.
1493  if isinstance(error, errors.InvalidArgumentError):
1494    if ("unknown device" in str(error).lower() or
1495        "Primary device is not remote" in str(error) or
1496        "Unable to find the relevant tensor remote_handle" in str(error)):
1497      return True
1498
1499  # TODO(b/162541228): The following 2 types of errors are very rare and only
1500  # observed in large-scale testing. The types of errors should be reduced.
1501  # This could happen when the function registration fails. In the observed
1502  # cases this only happens to the dataset-related functions.
1503  if isinstance(error, errors.NotFoundError):
1504    if ("is neither a type of a primitive operation nor a name of a function "
1505        "registered" in str(error)):
1506      return True
1507
1508  # NOTE(b/179061495): During worker preemptions, if multiple functions are
1509  # running concurrently (especially with subfunctions spanning chief/PS),
1510  # CancelledError can be returned due to chief/PS cancelling outstanding RPCs
1511  # to the failing workers.
1512  if isinstance(error, errors.CancelledError):
1513    return True
1514
1515  return False
1516