xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/collective_all_reduce_strategy.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Class CollectiveAllReduceStrategy implementing DistributionStrategy."""
16
17import copy
18import threading
19import time
20import weakref
21
22from tensorflow.core.protobuf import rewriter_config_pb2
23from tensorflow.core.protobuf import tensorflow_server_pb2
24from tensorflow.python.distribute import collective_util
25from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
26from tensorflow.python.distribute import cross_device_utils
27from tensorflow.python.distribute import device_util
28from tensorflow.python.distribute import distribute_lib
29from tensorflow.python.distribute import distribute_utils
30from tensorflow.python.distribute import distribution_strategy_context as ds_context
31from tensorflow.python.distribute import input_lib
32from tensorflow.python.distribute import input_util
33from tensorflow.python.distribute import mirrored_strategy
34from tensorflow.python.distribute import multi_worker_util
35from tensorflow.python.distribute import numpy_dataset
36from tensorflow.python.distribute import reduce_util
37from tensorflow.python.distribute import values
38from tensorflow.python.distribute.cluster_resolver import ClusterResolver
39from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
40from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
41from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
42from tensorflow.python.eager import context
43from tensorflow.python.framework import device as tf_device
44from tensorflow.python.framework import errors
45from tensorflow.python.framework import ops
46from tensorflow.python.ops import array_ops
47from tensorflow.python.ops import collective_ops
48from tensorflow.python.ops import control_flow_util
49from tensorflow.python.platform import tf_logging as logging
50from tensorflow.python.tpu import tpu_strategy_util
51from tensorflow.python.trackable import base
52from tensorflow.python.util import deprecation
53from tensorflow.python.util.tf_export import tf_export
54
55
56# pylint: disable=line-too-long
57@tf_export("distribute.MultiWorkerMirroredStrategy", v1=[])
58class CollectiveAllReduceStrategy(distribute_lib.Strategy):
59  """A distribution strategy for synchronous training on multiple workers.
60
61  This strategy implements synchronous distributed training across multiple
62  workers, each with potentially multiple GPUs. Similar to
63  `tf.distribute.MirroredStrategy`, it replicates all variables and computations
64  to each local device. The difference is that it uses a distributed collective
65  implementation (e.g. all-reduce), so that multiple workers can work together.
66
67  You need to launch your program on each worker and configure
68  `cluster_resolver` correctly. For example, if you are using
69  `tf.distribute.cluster_resolver.TFConfigClusterResolver`, each worker needs to
70  have its corresponding `task_type` and `task_id` set in the `TF_CONFIG`
71  environment variable. An example TF_CONFIG on worker-0 of a two worker cluster
72  is:
73
74  ```
75  TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'
76  ```
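
  `TF_CONFIG` can also be set programmatically before the strategy is created.
  A minimal sketch for a local two-worker test (the port numbers are
  illustrative):

  ```
  import json, os

  os.environ["TF_CONFIG"] = json.dumps({
      "cluster": {"worker": ["localhost:12345", "localhost:23456"]},
      "task": {"type": "worker", "index": 0}  # "index": 1 on the other worker
  })
  ```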
77
78  Your program runs on each worker as-is. Note that collectives require each
79  worker to participate. All `tf.distribute` and non-`tf.distribute` APIs may
80  use collectives internally, e.g. checkpointing and saving, since reading a
81  `tf.Variable` with `tf.VariableSynchronization.ON_READ` all-reduces the value.
82  Therefore it's recommended to run exactly the same program on each worker.
83  Dispatching based on the worker's `task_type` or `task_id` is error-prone.
84
85  `cluster_resolver.num_accelerators()` determines the number of GPUs the
86  strategy uses. If it's zero, the strategy uses the CPU. All workers need to
87  use the same number of devices, otherwise the behavior is undefined.
88
89  This strategy is not intended for TPU. Use `tf.distribute.TPUStrategy`
90  instead.
91
92  After setting up TF_CONFIG, using this strategy is similar to using
93  `tf.distribute.MirroredStrategy` and `tf.distribute.TPUStrategy`.
94
95  ```
96  strategy = tf.distribute.MultiWorkerMirroredStrategy()
97
98  with strategy.scope():
99    model = tf.keras.Sequential([
100      tf.keras.layers.Dense(2, input_shape=(5,)),
101    ])
102    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
103
104  def dataset_fn(ctx):
105    x = np.random.random((2, 5)).astype(np.float32)
106    y = np.random.randint(2, size=(2, 1))
107    dataset = tf.data.Dataset.from_tensor_slices((x, y))
108    return dataset.repeat().batch(1, drop_remainder=True)
109  dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
110
111  model.compile(optimizer=optimizer, loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
112  model.fit(dist_dataset)
113  ```
114
115  You can also write your own training loop:
116
117  ```
118  @tf.function
119  def train_step(iterator):
120
121    def step_fn(inputs):
122      features, labels = inputs
123      with tf.GradientTape() as tape:
124        logits = model(features, training=True)
125        loss = tf.keras.losses.sparse_categorical_crossentropy(
126            labels, logits)
127
128      grads = tape.gradient(loss, model.trainable_variables)
129      optimizer.apply_gradients(zip(grads, model.trainable_variables))
130
131    strategy.run(step_fn, args=(next(iterator),))
132
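  # Note: this assumes, e.g., iterator = iter(dist_dataset) and a NUM_STEP
  # constant defined outside this snippet.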
133  for _ in range(NUM_STEP):
134    train_step(iterator)
135  ```
136
137  See
138  [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
139  for a detailed tutorial.
140
141  __Saving__
142
143  You need to save and checkpoint on all workers instead of just one. This is
144  because variables with synchronization=ON_READ trigger aggregation during
145  saving. It's recommended to save to a different path on each worker to avoid
146  race conditions. Each worker saves the same thing. See
147  [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#model_saving_and_loading)
148  tutorial for examples.
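
  A possible sketch of per-worker checkpointing (the base directory below is
  arbitrary; every worker writes its own copy):

  ```
  resolver = strategy.cluster_resolver
  path = "/tmp/ckpt/worker-{}".format(resolver.task_id)
  checkpoint = tf.train.Checkpoint(model=model)
  checkpoint.save(path)
  ```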
149
150  __Known Issues__
151
152  * `tf.distribute.cluster_resolver.TFConfigClusterResolver` does not return the
153  correct number of accelerators. The strategy uses all available GPUs if
154  `cluster_resolver` is `tf.distribute.cluster_resolver.TFConfigClusterResolver`
155  or `None`.
156  * In eager mode, the strategy needs to be created before calling any other
157  TensorFlow API.
158
159  """
160  # pylint: enable=line-too-long
161
162  # TODO(anjalisridhar): Update our guides with examples showing how we can use
163  # the cluster_resolver argument.
164
165  # The starting number for collective keys. This should only be set in tests.
166  _collective_key_base = 0
167
168  def __init__(self,
169               cluster_resolver=None,
170               communication_options=None):
171    """Creates the strategy.
172
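    A sketch of passing explicit options (NCCL is only a hint and requires
    GPUs):

    ```
    options = tf.distribute.experimental.CommunicationOptions(
        implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)
    strategy = tf.distribute.MultiWorkerMirroredStrategy(
        communication_options=options)
    ```
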
173    Args:
174      cluster_resolver: optional
175        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
176        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
177      communication_options: optional
178        `tf.distribute.experimental.CommunicationOptions`. This configures the
179        default options for cross device communications. It can be overridden by
180        options provided to the communication APIs like
181        `tf.distribute.ReplicaContext.all_reduce`. See
182        `tf.distribute.experimental.CommunicationOptions` for details.
183    """
184    if communication_options is None:
185      communication_options = collective_util.Options()
186    super(CollectiveAllReduceStrategy, self).__init__(
187        CollectiveAllReduceExtended(
188            self,
189            cluster_resolver=cluster_resolver,
190            communication_options=communication_options))
191
192    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
193        "MultiWorkerMirroredStrategy")
194    # pylint: disable=protected-access
195    distribute_lib.distribution_strategy_replica_gauge.get_cell(
196        "num_workers").set(self.extended._num_workers)
197    distribute_lib.distribution_strategy_replica_gauge.get_cell(
198        "num_replicas_per_worker").set(self.extended._num_devices_per_worker)
199
200  @classmethod
201  def _from_local_devices(cls, devices, communication_options=None):
202    """A convenience method to create an object with a list of devices."""
203    obj = cls(communication_options=communication_options)
204    obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
205    return obj
206
207  @property
208  def cluster_resolver(self):
209    """Returns the cluster resolver associated with this strategy.
210
211    As a multi-worker strategy, `tf.distribute.MultiWorkerMirroredStrategy`
212    provides the associated `tf.distribute.cluster_resolver.ClusterResolver`. If
213    the user provides one in `__init__`, that instance is returned; if the user
214    does not, a default `TFConfigClusterResolver` is provided.
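
    For example (a sketch; `task_type` and `task_id` are only meaningful in
    multi-worker mode):

    ```
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    resolver = strategy.cluster_resolver
    print(resolver.task_type, resolver.task_id)
    ```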
215    """
216    return self.extended._cluster_resolver  # pylint: disable=protected-access
217
218
219class _CollectiveAllReduceStrategyExperimentalMeta(type):
220
221  @classmethod
222  def __instancecheck__(cls, instance):
223    # This makes isinstance(tf.distribute.MultiWorkerMirroredStrategy(),
224    # tf.distribute.experimental.MultiWorkerMirroredStrategy) return True,
225    # since some libraries perform such checks.
226    return isinstance(instance, CollectiveAllReduceStrategy)
227
228
229@tf_export("distribute.experimental.MultiWorkerMirroredStrategy", v1=[])
230class _CollectiveAllReduceStrategyExperimental(
231    CollectiveAllReduceStrategy,
232    metaclass=_CollectiveAllReduceStrategyExperimentalMeta):
233
234  __doc__ = CollectiveAllReduceStrategy.__doc__
235
236  @deprecation.deprecated(
237      None, "use distribute.MultiWorkerMirroredStrategy instead")
238  def __init__(self,
239               communication=collective_util.CommunicationImplementation.AUTO,
240               cluster_resolver=None):
241    """Creates the strategy.
242
243    Args:
244      communication: optional
245        `tf.distribute.experimental.CommunicationImplementation`. This is a hint
246        on the preferred collective communication implementation. Possible
247        values include `AUTO`, `RING`, and `NCCL`.
248      cluster_resolver: optional
249        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
250        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
251    """
252    communication_options = collective_util.Options(
253        implementation=communication)
254    super(_CollectiveAllReduceStrategyExperimental,
255          self).__init__(cluster_resolver, communication_options)
256
257  @classmethod
258  def _from_local_devices(
259      cls,
260      devices,
261      communication=collective_util.CommunicationImplementation.AUTO):
262    """A convenience method to create an object with a list of devices."""
263    obj = cls(communication)
264    obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
265    return obj
266
267
268_CollectiveAllReduceStrategyExperimental.__name__ = CollectiveAllReduceStrategy.__name__
269
270
271@tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"])  # pylint: disable=missing-docstring
272class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):
273
274  __doc__ = CollectiveAllReduceStrategy.__doc__
275
276  # The starting number for collective keys. This should only be set in tests.
277  _collective_key_base = 0
278
279  def __init__(self,
280               communication=collective_util.CommunicationImplementation.AUTO,
281               cluster_resolver=None):
282    """Initializes the object."""
283    communication_options = collective_util.Options(
284        implementation=communication)
285    super(CollectiveAllReduceStrategyV1, self).__init__(
286        CollectiveAllReduceExtended(
287            self,
288            cluster_resolver=cluster_resolver,
289            communication_options=communication_options))
290    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
291        "MultiWorkerMirroredStrategy")
292    # pylint: disable=protected-access
293    distribute_lib.distribution_strategy_replica_gauge.get_cell(
294        "num_workers").set(self.extended._num_workers)
295    distribute_lib.distribution_strategy_replica_gauge.get_cell(
296        "num_gpu_per_worker").set(
297            self.extended._num_devices_per_worker
298            if self.extended._local_device_type == "GPU"
299            else 0)
300
301
302def _is_gpu_device(device):
303  return tf_device.DeviceSpec.from_string(device).device_type == "GPU"
304
305
306class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
307  """Implementation of CollectiveAllReduceStrategy."""
308
309  # Whether to periodically check the health of the cluster. If any worker is
310  # not reachable, collectives are aborted and the user program should get a
311  # tf.errors.UnavailableError. Restarting the program is required to recover.
312  _enable_check_health = True
313  # Check health interval in seconds.
314  _check_health_interval = 30
315  # Timeout in seconds for the first health check. The first health check needs
316  # to wait for the cluster to start up, which may take longer.
317  _check_health_initial_timeout = 0
318  # Number of times to retry before considering the peer down.
319  _check_health_retry_limit = 3
320  # Timeout in seconds for each health check.
321  _check_health_timeout = 10
322
323  def __init__(self, container_strategy, cluster_resolver,
324               communication_options, devices=None):
325    if not isinstance(communication_options, collective_util.Options):
326      raise ValueError("communication_options must be an instance of "
327                       "tf.distribute.experimental.CommunicationOptions")
328    if cluster_resolver and devices:
329      raise ValueError(
330          "cluster_resolver and devices cannot be set at the same time")
331
332    self._cluster_resolver = cluster_resolver or TFConfigClusterResolver()
333    if not isinstance(self._cluster_resolver, ClusterResolver):
334      raise ValueError("cluster_resolver must be an instance of "
335                       "tf.distribute.cluster_resolver.ClusterResolver")
336    distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
337    self._communication_options = communication_options
338    self._collective_key_base = container_strategy._collective_key_base  # pylint: disable=protected-access
339    self._initialize_strategy(self._cluster_resolver, devices=devices)
340    self._cfer_fn_cache = weakref.WeakKeyDictionary()
341    self.experimental_enable_get_next_as_optional = True
342    assert isinstance(self._cross_device_ops,
343                      cross_device_ops_lib.CollectiveAllReduce)
344
345  def _use_merge_call(self):
346    # We currently only disable merge_call when XLA is used to compile the `fn`
347    # passed to `strategy.run` and all devices are GPUs.
348    return not control_flow_util.GraphOrParentsInXlaContext(
349        ops.get_default_graph()) or not all(
350            [_is_gpu_device(d) for d in self._devices])
351
352  def _initialize_strategy(self, cluster_resolver, devices):
353    # If devices are provided or cluster_spec is not specified, initialize
354    # a single worker. Otherwise initialize for multiple workers.
355    if devices or not cluster_resolver.cluster_spec().as_dict():
356      self._initialize_local(cluster_resolver, devices=devices)
357    else:
358      self._initialize_multi_worker(cluster_resolver)
359
360  def _initialize_local_devices(self, cluster_resolver, worker_device):
361    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
362    # some cases.
363    if isinstance(cluster_resolver, TFConfigClusterResolver):
364      num_gpus = context.num_gpus()
365      num_tpus = 0
366    else:
367      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)
368      num_tpus = cluster_resolver.num_accelerators().get("TPU", 0)
369
370    if num_gpus:
371      local_device_type = "GPU"
372      num_local_devices = num_gpus
373    elif num_tpus:
374      local_device_type = "TPU"
375      num_local_devices = num_tpus
376    else:
377      local_device_type = "CPU"
378      num_local_devices = 1
379    local_devices = tuple(
380        f"{worker_device}/device:{local_device_type}:{i}"
381        for i in range(num_local_devices))
382    return local_devices, local_device_type
383
384  def _initialize_local(self, cluster_resolver, devices=None):
385    """Initializes the object for local training."""
386    self._is_chief = True
387    self._num_workers = 1
388
389    if ops.executing_eagerly_outside_functions():
390      try:
391        context.context().configure_collective_ops(
392            scoped_allocator_enabled_ops=("CollectiveReduce",))
393      except RuntimeError:
394        logging.warning("Collective ops are not configured at program startup. "
395                        "Some performance features may not be enabled.")
396      self._collective_ops_configured = True
397
398    if devices:
399      local_devices = devices
400      if "GPU" in devices[0]:
401        local_device_type = "GPU"
402      elif "TPU" in devices[0]:
403        local_device_type = "TPU"
404      else:
405        local_device_type = "CPU"
406    else:
407      local_devices, local_device_type = self._initialize_local_devices(
408          cluster_resolver, worker_device="")
409
410    self._worker_device = device_util.canonicalize("/device:CPU:0")
411    self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)
412
413    self._collective_keys = cross_device_utils.CollectiveKeys(
414        group_key_start=1 + self._collective_key_base)
415    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
416        devices=local_devices,
417        group_size=len(local_devices),
418        options=self._communication_options,
419        collective_keys=self._collective_keys)
420    # CrossDeviceOps for per host tensors.
421    self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
422        devices=[self._worker_device],
423        group_size=self._num_workers,
424        options=self._communication_options,
425        collective_keys=self._collective_keys)
426    super(CollectiveAllReduceExtended, self)._initialize_single_worker(
427        local_devices)
428
429    self._cluster_spec = None
430    self._task_type = None
431    self._task_id = None
432    self._id_in_cluster = 0
433
434    # This is a marker that tells whether we are running with a standalone
435    # client or as an independent worker. Right now with a standalone client,
436    # the strategy object is created as a local strategy and then turned into
437    # a multi-worker strategy via the configure call.
438    self._local_or_standalone_client_mode = True
439
440    # Save the num_devices_per_worker and rpc_layer for configure method.
441    self._num_devices_per_worker = len(local_devices)
442    self._local_device_type = local_device_type
443    self._rpc_layer = cluster_resolver.rpc_layer
444    self._warn_nccl_no_gpu()
445
446    logging.info(
447        "Single-worker MultiWorkerMirroredStrategy with local_devices "
448        "= %r, communication = %s", local_devices,
449        self._communication_options.implementation)
450
451  def _initialize_multi_worker(self, cluster_resolver):
452    """Initializes the object for multi-worker training."""
453    cluster_spec = multi_worker_util.normalize_cluster_spec(
454        cluster_resolver.cluster_spec())
455    task_type = cluster_resolver.task_type
456    task_id = cluster_resolver.task_id
457    if task_type is None or task_id is None:
458      raise ValueError("When `cluster_spec` is given, you must also specify "
459                       "`task_type` and `task_id`.")
460    self._cluster_spec = cluster_spec
461    self._task_type = task_type
462    self._task_id = task_id
463    self._id_in_cluster = multi_worker_util.id_in_cluster(
464        self._cluster_spec, self._task_type, self._task_id)
465
466    self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
467    if not self._num_workers:
468      raise ValueError("No `worker`, `chief` or `evaluator` tasks can be found "
469                       "in `cluster_spec`.")
470
471    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
472                                                task_id)
473
474    self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
475    self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)
476
477    if (ops.executing_eagerly_outside_functions() and
478        not getattr(self, "_local_or_standalone_client_mode", False)):
479      context.context().configure_collective_ops(
480          collective_leader=multi_worker_util.collective_leader(
481              cluster_spec, task_type, task_id),
482          scoped_allocator_enabled_ops=("CollectiveReduce",),
483          device_filters=("/job:%s/task:%d" % (task_type, task_id),))
484      self._collective_ops_configured = True
485      if context.context().coordination_service is None:
486        coordinated_jobs = ["chief", "worker"]
487        if task_type in coordinated_jobs:
488          context.context().configure_coordination_service(
489              service_type="standalone",
490              service_leader=multi_worker_util.coordination_leader(
491                  cluster_spec),
492              coordinated_jobs=coordinated_jobs)
493
494    # Starting a std server in eager mode and in independent worker mode.
495    if (context.executing_eagerly() and
496        not getattr(self, "_std_server_started", False) and
497        not getattr(self, "_local_or_standalone_client_mode", False)):
498      # Checking _local_or_standalone_client_mode as well because we should not
499      # create the std server in standalone client mode.
500      config_proto = copy.deepcopy(context.context().config)
501      config_proto = self._update_config_proto(config_proto)
502
503      # If coordination service is enabled, use its internal heartbeat to detect
504      # peer failures instead of the Python-level health check.
505      if config_proto.experimental.coordination_config.service_type:
506        self._enable_check_health = False
507
508      if hasattr(cluster_resolver, "port"):
509        port = cluster_resolver.port
510      else:
511        port = 0
512      server_def = tensorflow_server_pb2.ServerDef(
513          cluster=cluster_spec.as_cluster_def(),
514          default_session_config=config_proto,
515          job_name=task_type,
516          task_index=task_id,
517          protocol=cluster_resolver.rpc_layer or "grpc",
518          port=port)
519      context.context().enable_collective_ops(server_def)
520      self._std_server_started = True
521      # The `ensure_initialized` is needed before calling
522      # `context.context().devices()`.
523      context.context().ensure_initialized()
524      logging.info(
525          "Enabled multi-worker collective ops with available devices: %r",
526          context.context().devices())
527
528    # TODO(yuefengz): The `num_gpus` is only for this particular task. It
529    # assumes all workers have the same number of GPUs. We should remove this
530    # assumption by querying all tasks for their numbers of GPUs.
531    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
532    # some cases.
533    local_devices, local_device_type = self._initialize_local_devices(
534        cluster_resolver, self._worker_device)
535    if local_device_type == "TPU":
536      tpu_strategy_util.initialize_tpu_system()
537
538    self._collective_keys = cross_device_utils.CollectiveKeys(
539        group_key_start=1 + self._collective_key_base)
540    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
541        devices=local_devices,
542        group_size=len(local_devices) * self._num_workers,
543        options=self._communication_options,
544        collective_keys=self._collective_keys)
545    # CrossDeviceOps for per host tensors.
546    self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
547        devices=[self._worker_device],
548        group_size=self._num_workers,
549        options=self._communication_options,
550        collective_keys=self._collective_keys)
551    super(CollectiveAllReduceExtended, self)._initialize_single_worker(
552        local_devices)
553
554    # Add a default device so that ops without specified devices will not end up
555    # on other workers.
556    self._default_device = "/job:%s/task:%d" % (task_type, task_id)
557
558    # Save the num_devices_per_worker and rpc_layer for configure method.
559    self._num_devices_per_worker = len(local_devices)
560    self._local_device_type = local_device_type
561    self._rpc_layer = cluster_resolver.rpc_layer
562    self._warn_nccl_no_gpu()
563
564    if self._enable_check_health and context.executing_eagerly():
565      self._start_check_health_thread()
566    else:
567      logging.info("Check health not enabled.")
568
569    logging.info(
570        "MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, "
571        "task_id = %r, num_workers = %r, local_devices = %r, "
572        "communication = %s", cluster_spec.as_dict(), task_type, task_id,
573        self._num_workers, local_devices,
574        self._communication_options.implementation)
575
576  def __del__(self):
577    self._stop_check_health_thread()
578
579  def _input_workers_with_options(self, options=None):
580    host_device = device_util.get_host_for_device(self._worker_device)
581    if not options or options.experimental_fetch_to_device:
582      return input_lib.InputWorkers([(host_device, self.worker_devices)])
583    else:
584      return input_lib.InputWorkers([(
585          host_device,
586          [device_util.get_host_for_device(worker) for worker in
587           self.worker_devices])])
588
589  @property
590  def _input_workers(self):
591    return self._input_workers_with_options()
592
593  def _get_variable_creator_initial_value(self,
594                                          replica_id,
595                                          device,
596                                          primary_var,
597                                          **kwargs):
598    if replica_id == 0:  # First replica on each worker.
599      assert device is not None
600      assert primary_var is None
601
602      def initial_value_fn():  # pylint: disable=g-missing-docstring
603        # Only the first device participates in the broadcast of initial values.
604        group_key = self._collective_keys.get_group_key([device])
605        group_size = self._num_workers
606        collective_instance_key = (
607            self._collective_keys.get_instance_key(group_key, device))
608
609        with ops.device(device):
610          initial_value = kwargs["initial_value"]
611          if callable(initial_value):
612            initial_value = initial_value()
613          if isinstance(initial_value, base.CheckpointInitialValue):
614            initial_value = initial_value.wrapped_value
615          assert not callable(initial_value)
616          initial_value = ops.convert_to_tensor(
617              initial_value, dtype=kwargs.get("dtype", None))
618
619          if self._num_workers > 1:
620            if self._is_chief:
621              bcast_send = collective_ops.broadcast_send(
622                  initial_value, initial_value.shape, initial_value.dtype,
623                  group_size, group_key, collective_instance_key)
624              with ops.control_dependencies([bcast_send]):
625                return array_ops.identity(initial_value)
626            else:
627              return collective_ops.broadcast_recv(initial_value.shape,
628                                                   initial_value.dtype,
629                                                   group_size, group_key,
630                                                   collective_instance_key)
631          return initial_value
632
633      return initial_value_fn
634    else:
635      return super(CollectiveAllReduceExtended,
636                   self)._get_variable_creator_initial_value(
637                       replica_id=replica_id,
638                       device=device,
639                       primary_var=primary_var,
640                       **kwargs)
641
642  def _make_input_context(self):
643    input_context = distribute_lib.InputContext(
644        num_input_pipelines=self._num_workers,
645        input_pipeline_id=self._id_in_cluster,
646        num_replicas_in_sync=self._num_replicas_in_sync)
647    return input_context
648
649  def _experimental_distribute_dataset(self, dataset, options):
650    if (options and options.experimental_replication_mode ==
651        distribute_lib.InputReplicationMode.PER_REPLICA):
652      raise NotImplementedError(
653          "InputReplicationMode.PER_REPLICA "
654          "is only supported in "
655          "`distribute_datasets_from_function` "
656          "of tf.distribute.MirroredStrategy"
657      )
658    input_context = self._make_input_context()
659    return input_util.get_distributed_dataset(
660        dataset,
661        self._input_workers_with_options(options),
662        self._container_strategy(),
663        num_replicas_in_sync=self._num_replicas_in_sync,
664        input_context=input_context,
665        options=options)
666
667  def _distribute_datasets_from_function(self, dataset_fn, options):
668    if (options and options.experimental_replication_mode ==
669        distribute_lib.InputReplicationMode.PER_REPLICA):
670      raise NotImplementedError(
671          "InputReplicationMode.PER_REPLICA "
672          "is only supported in "
673          "`distribute_datasets_from_function` "
674          "of tf.distribute.MirroredStrategy")
675    input_context = self._make_input_context()
676    return input_util.get_distributed_datasets_from_function(
677        dataset_fn=dataset_fn,
678        input_workers=self._input_workers_with_options(options),
679        input_contexts=[input_context],
680        strategy=self._container_strategy(),
681        options=options)
682
683  def _experimental_distribute_values_from_function(self, value_fn):
684    per_replica_values = []
685    num_local_replicas = len(self.worker_devices)
686    for local_replica_id in range(num_local_replicas):
687      replica_id = (self._id_in_cluster * num_local_replicas +
688                    local_replica_id)
689      value_context = distribute_lib.ValueContext(
690          replica_id, self._num_replicas_in_sync)
691      per_replica_values.append(value_fn(value_context))
692    return distribute_utils.regroup(per_replica_values, always_wrap=True)
693
694  def _make_dataset_iterator(self, dataset):
695    """Distributes the dataset to each local GPU."""
696    input_context = self._make_input_context()
697    return input_lib_v1.DatasetIterator(
698        dataset,
699        self._input_workers,
700        self._container_strategy(),
701        num_replicas_in_sync=self._num_replicas_in_sync,
702        input_context=input_context)
703
704  def _make_input_fn_iterator(
705      self,
706      input_fn,
707      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
708    """Distributes the input function to each local GPU."""
709    input_context = self._make_input_context()
710    return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers,
711                                              [input_context],
712                                              self._container_strategy())
713
714  def _configure(self,
715                 session_config=None,
716                 cluster_spec=None,
717                 task_type=None,
718                 task_id=None):
719    """Configures the object.
720
721    Args:
722      session_config: a `tf.compat.v1.ConfigProto`
723      cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
724        cluster configurations.
725      task_type: the current task type, such as "worker".
726      task_id: the current task id.
727
728    Raises:
729      ValueError: if `task_type` is not in the `cluster_spec`.
730    """
731    if cluster_spec:
732      cluster_resolver = SimpleClusterResolver(
733          cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
734          task_type=task_type,
735          task_id=task_id,
736          num_accelerators={
737              self._local_device_type: self._num_devices_per_worker},
738          rpc_layer=self._rpc_layer)
739      self._initialize_multi_worker(cluster_resolver)
740      assert isinstance(self._cross_device_ops,
741                        cross_device_ops_lib.CollectiveAllReduce)
742
743    if session_config:
744      session_config.CopyFrom(self._update_config_proto(session_config))
745
746  def _update_config_proto(self, config_proto):
747    updated_config = copy.deepcopy(config_proto)
748    # Enable the scoped allocator optimization for CollectiveOps.  This
749    # optimization converts many small all-reduces into fewer larger
750    # all-reduces.
751    rewrite_options = updated_config.graph_options.rewrite_options
752    rewrite_options.scoped_allocator_optimization = (
753        rewriter_config_pb2.RewriterConfig.ON)
754    # We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op =
755    # ["CollectiveReduce"].  Since we can't assign to a repeated proto field, we
756    # clear and then append.
757    del rewrite_options.scoped_allocator_opts.enable_op[:]
758    rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce")
759
760    if (not ops.executing_eagerly_outside_functions() and
761        self._communication_options.implementation ==
762        collective_util.CommunicationImplementation.NCCL):
763      updated_config.experimental.collective_nccl = True
764
765    if not self._cluster_spec:
766      return updated_config
767
768    assert self._task_type
769    assert self._task_id is not None
770
771    # Collective group leader is needed for collective ops to coordinate
772    # workers.
773    updated_config.experimental.collective_group_leader = (
774        multi_worker_util.collective_leader(self._cluster_spec, self._task_type,
775                                            self._task_id))
776
777    # The device filters prevent communication between workers.
778    del updated_config.device_filters[:]
779    updated_config.device_filters.append(
780        "/job:%s/task:%d" % (self._task_type, self._task_id))
781
782    return updated_config
783
784  def _get_cross_device_ops(self, value):
785    # CollectiveAllReduce works on a predefined set of devices. In most cases
786    # they should be the compute devices, but certain use cases may reduce host
787    # tensors as well (e.g. early stopping). We infer the cross_device_ops to
788    # use based on the number of devices, since inputs don't always have device
789    # annotations. The compute devices one is preferred since we can potentially
790    # leverage NCCL.
791    if isinstance(value, values.DistributedValues):
792      num_devices = len(value._values)  # pylint: disable=protected-access
793    else:
794      num_devices = 1
795    if num_devices == len(self.worker_devices):
796      return self._cross_device_ops
797    else:
798      return self._host_cross_device_ops
799
800  def _gather_to_implementation(self, value, destinations, axis, options):
801    return self._get_cross_device_ops(value)._gather(  # pylint: disable=protected-access
802        value,
803        destinations=destinations,
804        axis=axis,
805        options=options)
806
807  def _reduce_to(self, reduce_op, value, destinations, options):
808    if (isinstance(value, values.Mirrored) and
809        reduce_op == reduce_util.ReduceOp.MEAN):
810      return value
811    assert not isinstance(value, values.Mirrored)
812
813    if (isinstance(value, values.DistributedValues) and
814        len(self.worker_devices) == 1):
815      value = value.values[0]
816
817    # When there are multiple workers, we need to reduce across workers using
818    # collective ops.
819    if (not isinstance(value, values.DistributedValues) and
820        self._num_workers == 1):
821      # This function handles reducing values that are not PerReplica or
822      # Mirrored values. For example, the same value could be present on all
823      # replicas, in which case `value` would be a single value, or `value`
824      # could be 0.
825      return cross_device_ops_lib.reduce_non_distributed_value(
826          reduce_op, value, destinations, len(self.worker_devices))
827    return self._get_cross_device_ops(value).reduce(
828        reduce_op,
829        value,
830        destinations=destinations,
831        options=self._communication_options.merge(options))
832
833  def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
834    """Implements `StrategyExtendedV2._replica_ctx_all_reduce`."""
835    # This implementation avoids using `merge_call` and just launches collective
836    # ops in one replica.
837    if options is None:
838      options = collective_util.Options()
839
840    if context.executing_eagerly():
841      # In eager mode, this falls back to the default implementation that uses
842      # `merge_call`. Replica functions run sequentially in eager mode,
843      # and due to the blocking nature of collective ops, execution will hang if
844      # collective ops are to be launched sequentially.
845      return super()._replica_ctx_all_reduce(reduce_op, value, options)
846
847    replica_context = ds_context.get_replica_context()
848    assert replica_context, (
849        "`StrategyExtended._replica_ctx_all_reduce` must be called in a "
850        "replica context")
851    return self._cross_device_ops._all_reduce(  # pylint: disable=protected-access
852        reduce_op,
853        value,
854        replica_context._replica_id,  # pylint: disable=protected-access
855        options)
856
857  def _check_health(self):
858    while True:
859      if self._check_health_thread_should_stop.is_set():
860        return
861      for job in self._cluster_spec.jobs:
862        for task_id in range(self._cluster_spec.num_tasks(job)):
863          peer = "/job:{}/replica:0/task:{}".format(job, task_id)
864          attempts = 0
865          while True:
866            attempts += 1
867            try:
868              context.context().check_collective_ops_peer_health(
869                  peer, timeout_in_ms=self._check_health_timeout * 1000)
870              # If check_collective_ops_peer_health doesn't raise an Exception,
871              # the peer is healthy.
872              break
873            except (errors.UnavailableError, errors.FailedPreconditionError,
874                    errors.DeadlineExceededError) as e:
875              # TODO(b/151232436): Always raise UnavailableError when a peer
876              # fails. Now there could be many kinds of errors:
877              # - Unavailable: when the peer is not reachable, e.g. it's down.
878              # - FailedPrecondition: when the peer has restarted.
879              if attempts < self._check_health_retry_limit:
880                logging.warning("%s seems down, retrying %d/%d", peer, attempts,
881                                self._check_health_retry_limit)
882                continue
883              logging.error(
884                  "Cluster check alive failed, %s is down, "
885                  "aborting collectives: %s", peer, e)
886              context.context().abort_collective_ops(
887                  errors.UNAVAILABLE,
888                  "cluster check alive failed, {} is down".format(peer))
889              return
890            except Exception as e:  # pylint: disable=broad-except
891              logging.error("Unexpected exception in check alive: %s", e)
892              context.context().abort_collective_ops(
893                  errors.INTERNAL,
894                  "unexpected exception in check alive: %s" % e)
895              return
896      time.sleep(self._check_health_interval)
897
898  def _start_check_health_thread(self):
899    # Use a dummy all-reduce as a barrier to wait for all workers to be up,
900    # otherwise the health check may fail immediately.
901
902    # Use array_ops.identity to create the dummy tensor so that we have a new
903    # Tensor. If we used a constant it might be cached on a /job:localhost
904    # device, which would cause code that relies on tensor.device to fail.
905    #
906    # TODO(b/151232436): change to an explicit barrier if we have it.
907    dummy_value = array_ops.identity([])
908    logging.info("Waiting for the cluster, timeout = %s",
909                 self._check_health_initial_timeout or "inf")
910    try:
911      self._host_cross_device_ops.reduce(
912          reduce_util.ReduceOp.SUM,
913          dummy_value,
914          dummy_value,
915          options=collective_util.Options(
916              timeout_seconds=self._check_health_initial_timeout,
917              implementation=collective_util.CommunicationImplementation.RING))
918      if context.is_async():
919        context.async_wait()
920    except errors.DeadlineExceededError:
921      raise RuntimeError(
922          "Timeout waiting for the cluster, timeout is %d seconds" %
923          self._check_health_initial_timeout)
924    logging.info("Cluster is ready.")
925    self._check_health_thread_should_stop = threading.Event()
926    # Start the thread as daemon to avoid it blocking the program from exiting.
927    # We try our best to shut down the thread but __del__ is not guaranteed to
928    # be called when the program exits.
929    self._check_health_thread = threading.Thread(
930        target=self._check_health,
931        daemon=True)
932    self._check_health_thread.start()
933
934  def _stop_check_health_thread(self):
935    if getattr(self, "_check_health_thread", None):
936      logging.info("stopping check health thread")
937      self._check_health_thread_should_stop.set()
938      self._check_health_thread.join()
939      self._check_health_thread = None
940      logging.info("check health thread stopped")
941
942  def _warn_nccl_no_gpu(self):
943    if ((self._communication_options.implementation ==
944         collective_util.CommunicationImplementation.NCCL) and
945        self._local_device_type != "GPU"):
946      logging.warning("Enabled NCCL communication but no GPUs detected/"
947                      "specified.")
948
949  def _in_multi_worker_mode(self):
950    """Whether this strategy indicates working in multi-worker settings."""
951    return self._num_workers > 1
952
953  @property
954  def experimental_between_graph(self):
955    return True
956
957  @property
958  def experimental_should_init(self):
959    return True
960
961  @property
962  def should_checkpoint(self):
963    return self._is_chief
964
965  @property
966  def should_save_summary(self):
967    return self._is_chief
968
969  @property
970  def _num_replicas_in_sync(self):
971    return len(self.worker_devices) * self._num_workers
972
973  # TODO(priyag): Delete this once all strategies use global batch size.
974  @property
975  def _global_batch_size(self):
976    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.
977
978    `make_input_fn_iterator` assumes per-replica batching.
979
980    Returns:
981      Boolean.
982    """
983    return True
984
985  def _get_replica_id_in_sync_group(self, replica_id):
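    # For example, with 2 local devices per worker, local replica 0 on worker 1
    # maps to global replica 1 * 2 + 0 = 2.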
986    return self._id_in_cluster * len(self.worker_devices) + replica_id
987
988  def _get_local_replica_id(self, replica_id_in_sync_group):
989    return (replica_id_in_sync_group -
990            self._id_in_cluster * len(self.worker_devices))
991
992  def __deepcopy__(self, memo):
993    # We check for the health check thread instead of whether we are in eager
994    # mode to limit backward incompatibility.
995    if hasattr(self, "_check_health_thread"):
996      raise ValueError(
997          "MultiWorkerMirroredStrategy cannot be deep copied in eager mode. "
998          "If you're using Estimator and see this error message, call "
999          "tf.compat.v1.disable_eager_execution() at the beginning of your "
1000          "program")
1001    # Otherwise, do a regular deepcopy.
1002    cls = self.__class__
1003    result = cls.__new__(cls)
1004    memo[id(self)] = result
1005    for k, v in self.__dict__.items():
1006      setattr(result, k, copy.deepcopy(v, memo))
1007    return result
1008