# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class MirroredStrategy implementing tf.distribute.Strategy."""

import copy

from tensorflow.python import tf2
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import input_util
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute import values_util
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

# TODO(josh11b): Replace asserts in this file with if ...: raise ...


def _is_device_list_single_worker(devices):
  """Checks whether the devices list is for single or multi-worker.

  Args:
    devices: a list of device strings or `tf.config.LogicalDevice` objects,
      for either local or remote devices.

  Returns:
    a boolean indicating whether the devices all belong to a single worker.

  Raises:
    ValueError: if device strings are not consistent.
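
  For example (the device strings below are for illustration only):
    `["/gpu:0", "/gpu:1"]` resolves to a single worker, so this returns True,
    while `["/job:worker/task:0/device:GPU:0",
    "/job:worker/task:1/device:GPU:0"]` spans two tasks and returns False.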
67  """
68  specs = []
69  for d in devices:
70    name = d.name if isinstance(d, context.LogicalDevice) else d
71    specs.append(tf_device.DeviceSpec.from_string(name))
72  num_workers = len({(d.job, d.task, d.replica) for d in specs})
73  all_local = all(d.job in (None, "localhost") for d in specs)
74  any_local = any(d.job in (None, "localhost") for d in specs)
75
  if any_local and not all_local:
    raise ValueError("Local device should have only 'localhost' in the job "
                     "field in device string. "
                     "E.g. 'job:localhost' in "
                     "/job:localhost/replica:0/task:0/device:CPU:0. "
                     "Devices cannot be a mixed list of device strings "
                     "containing both localhost and other job types such as "
                     "worker, ps etc.")

  if num_workers == 1 and not all_local:
    if any(d.task is None for d in specs):
      raise ValueError("Remote device string must have the task specified. "
                       "E.g. 'task:0' in "
                       "/job:worker/replica:0/task:0/device:CPU:0")

  return num_workers == 1


def _cluster_spec_to_device_list(cluster_spec, num_gpus_per_worker):
  """Returns a device list given a cluster spec."""
  cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
  devices = []
  for task_type in ("chief", "worker"):
    for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
      if num_gpus_per_worker == 0:
        devices.append("/job:%s/task:%d/device:CPU:0" % (task_type, task_id))
      else:
        devices.extend([
            "/job:%s/task:%d/device:GPU:%i" % (task_type, task_id, gpu_id)
            for gpu_id in range(num_gpus_per_worker)
        ])
  return devices


def _group_device_list(devices):
  """Groups the devices list by task_type and task_id.

  Args:
    devices: a list of device strings for remote devices.

  Returns:
    a dict mapping each task_type to a list of per-task device lists, ordered
    by ascending task_id.
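
  For example (device strings are for illustration only), the input
    ["/job:worker/task:0/device:GPU:0", "/job:worker/task:0/device:GPU:1",
     "/job:worker/task:1/device:GPU:0"]
  is grouped into
    {"worker": [["/job:worker/task:0/device:GPU:0",
                 "/job:worker/task:0/device:GPU:1"],
                ["/job:worker/task:1/device:GPU:0"]]}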
119  """
120  assert not _is_device_list_single_worker(devices)
121  device_dict = {}
122
123  for d in devices:
124    d_spec = tf_device.DeviceSpec.from_string(d)
125
126    # Create an entry for the task_type.
127    if d_spec.job not in device_dict:
128      device_dict[d_spec.job] = []
129
130    # Fill the device list for task_type until it covers the task_id.
131    while len(device_dict[d_spec.job]) <= d_spec.task:
132      device_dict[d_spec.job].append([])
133
134    device_dict[d_spec.job][d_spec.task].append(d)
135
136  return device_dict
137
138
139def _is_gpu_device(device):
140  return tf_device.DeviceSpec.from_string(device).device_type == "GPU"
141
142
143def _infer_num_gpus_per_worker(devices):
144  """Infers the number of GPUs on each worker.
145
146  Currently to make multi-worker cross device ops work, we need all workers to
147  have the same number of GPUs.
148
149  Args:
150    devices: a list of device strings, can be either local devices or remote
151      devices.
152
153  Returns:
154    number of GPUs per worker.
155
156  Raises:
    ValueError: if workers have different numbers of GPUs, or if GPU indices
      are not consecutive and starting from 0.
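
  For example (device strings are for illustration only), two workers with
    ["/job:worker/task:0/device:GPU:0", "/job:worker/task:0/device:GPU:1",
     "/job:worker/task:1/device:GPU:0", "/job:worker/task:1/device:GPU:1"]
  have two GPUs each, so this returns 2.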
159  """
160  if _is_device_list_single_worker(devices):
161    return sum(1 for d in devices if _is_gpu_device(d))
162  else:
163    device_dict = _group_device_list(devices)
164    num_gpus = None
165    for _, devices_in_task in device_dict.items():
166      for device_in_task in devices_in_task:
167        if num_gpus is None:
168          num_gpus = sum(1 for d in device_in_task if _is_gpu_device(d))
169
170        # Verify other workers have the same number of GPUs.
171        elif num_gpus != sum(1 for d in device_in_task if _is_gpu_device(d)):
172          raise ValueError("All workers should have the same number of GPUs.")
173
174        for d in device_in_task:
175          d_spec = tf_device.DeviceSpec.from_string(d)
176          if (d_spec.device_type == "GPU" and
177              d_spec.device_index >= num_gpus):
178            raise ValueError("GPU `device_index` on a worker should be "
179                             "consecutive and start from 0.")
180    return num_gpus
181
182
183def all_local_devices(num_gpus=None):
184  devices = config.list_logical_devices("GPU")
185  if num_gpus is not None:
186    devices = devices[:num_gpus]
187  return devices or config.list_logical_devices("CPU")
188
189
190def all_devices():
191  devices = []
192  tfconfig = TFConfigClusterResolver()
193  if tfconfig.cluster_spec().as_dict():
194    devices = _cluster_spec_to_device_list(tfconfig.cluster_spec(),
195                                           context.num_gpus())
196  return devices if devices else all_local_devices()
197
198
199@tf_export("distribute.MirroredStrategy", v1=[])  # pylint: disable=g-classes-have-attributes
200class MirroredStrategy(distribute_lib.Strategy):
201  """Synchronous training across multiple replicas on one machine.
202
203  This strategy is typically used for training on one
204  machine with multiple GPUs. For TPUs, use
205  `tf.distribute.TPUStrategy`. To use `MirroredStrategy` with multiple workers,
206  please refer to `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
207
208  For example, a variable created under a `MirroredStrategy` is a
209  `MirroredVariable`. If no devices are specified in the constructor argument of
210  the strategy then it will use all the available GPUs. If no GPUs are found, it
211  will use the available CPUs. Note that TensorFlow treats all CPUs on a
212  machine as a single device, and uses threads internally for parallelism.
213
214  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
215  >>> with strategy.scope():
216  ...   x = tf.Variable(1.)
217  >>> x
218  MirroredVariable:{
219    0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>,
220    1: <tf.Variable ... shape=() dtype=float32, numpy=1.0>
221  }
222
223  While using distribution strategies, all the variable creation should be done
224  within the strategy's scope. This will replicate the variables across all the
225  replicas and keep them in sync using an all-reduce algorithm.
226
227  Variables created inside a `MirroredStrategy` which is wrapped with a
228  `tf.function` are still `MirroredVariables`.
229
230  >>> x = []
231  >>> @tf.function  # Wrap the function with tf.function.
232  ... def create_variable():
233  ...   if not x:
234  ...     x.append(tf.Variable(1.))
235  ...   return x[0]
236  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
237  >>> with strategy.scope():
238  ...   _ = create_variable()
239  ...   print(x[0])
240  MirroredVariable:{
241    0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>,
242    1: <tf.Variable ... shape=() dtype=float32, numpy=1.0>
243  }
244
245  `experimental_distribute_dataset` can be used to distribute the dataset across
246  the replicas when writing your own training loop. If you are using `.fit` and
247  `.compile` methods available in `tf.keras`, then `tf.keras` will handle the
248  distribution for you.
249
250  For example:
251
252  ```python
253  my_strategy = tf.distribute.MirroredStrategy()
254  with my_strategy.scope():
255    @tf.function
256    def distribute_train_epoch(dataset):
257      def replica_fn(input):
258        # process input and return result
259        return result
260
261      total_result = 0
262      for x in dataset:
263        per_replica_result = my_strategy.run(replica_fn, args=(x,))
264        total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
265                                           per_replica_result, axis=None)
266      return total_result
267
268    dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
269    for _ in range(EPOCHS):
270      train_result = distribute_train_epoch(dist_dataset)
271  ```
272
273  Args:
274    devices: a list of device strings such as `['/gpu:0', '/gpu:1']`.  If
275      `None`, all available GPUs are used. If no GPUs are found, CPU is used.
    cross_device_ops: optional, a descendant of `CrossDeviceOps`. If this is
      not set, `NcclAllReduce()` will be used by default. One would customize
      this if NCCL isn't available or if a special implementation that
      exploits the particular hardware is available.
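
  To customize the cross-device communication, you can pass one of the other
  built-in `tf.distribute.CrossDeviceOps` implementations (a minimal sketch):

  ```python
  strategy = tf.distribute.MirroredStrategy(
      devices=["/gpu:0", "/gpu:1"],
      cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
  ```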
280  """
281
282  # Only set this in tests.
283  _collective_key_base = 0
284
285  def __init__(self, devices=None, cross_device_ops=None):
286    extended = MirroredExtended(
287        self, devices=devices, cross_device_ops=cross_device_ops)
288    super(MirroredStrategy, self).__init__(extended)
289    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
290        "MirroredStrategy")
291
292
293@tf_export(v1=["distribute.MirroredStrategy"])
294class MirroredStrategyV1(distribute_lib.StrategyV1):  # pylint: disable=g-missing-docstring
295
296  __doc__ = MirroredStrategy.__doc__
297
298  # Only set this in tests.
299  _collective_key_base = 0
300
301  def __init__(self, devices=None, cross_device_ops=None):
302    extended = MirroredExtended(
303        self, devices=devices, cross_device_ops=cross_device_ops)
304    super(MirroredStrategyV1, self).__init__(extended)
305    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
306        "MirroredStrategy")
307
308
309# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
310class MirroredExtended(distribute_lib.StrategyExtendedV1):
311  """Implementation of MirroredStrategy."""
312
313  # If this is set to True, use NCCL collective ops instead of NCCL cross device
314  # ops.
315  _prefer_collective_ops = False
316
317  def __init__(self, container_strategy, devices=None, cross_device_ops=None):
318    super(MirroredExtended, self).__init__(container_strategy)
319    if context.executing_eagerly():
320      if devices and not _is_device_list_single_worker(devices):
321        raise RuntimeError("In-graph multi-worker training with "
322                           "`MirroredStrategy` is not supported in eager mode.")
323      else:
324        if TFConfigClusterResolver().cluster_spec().as_dict():
          # If you are executing in eager mode, only the single-machine code
          # path is supported.
          logging.info("Initializing local devices since in-graph multi-worker "
                       "training with `MirroredStrategy` is not supported in "
                       "eager mode. TF_CONFIG will be ignored when "
                       "initializing `MirroredStrategy`.")
        devices = devices or all_local_devices()
    else:
      devices = devices or all_devices()

    assert devices, ("Got an empty `devices` list and unable to recognize "
                     "any local devices.")
    self._cross_device_ops = cross_device_ops
    self._collective_ops_in_use = False
    self._collective_key_base = container_strategy._collective_key_base
    self._communication_options = collective_util.Options(
        implementation=collective_util.CommunicationImplementation.NCCL)
    self._initialize_strategy(devices)

    # TODO(b/128995245): Enable last partial batch support in graph mode.
    if ops.executing_eagerly_outside_functions():
      self.experimental_enable_get_next_as_optional = True

    # Flag to turn on VariablePolicy.
    self._use_var_policy = False

  def _use_merge_call(self):
    # We currently only disable merge_call when XLA is used to compile the `fn`
    # passed to `strategy.run` and all devices are GPU.
    return not control_flow_util.GraphOrParentsInXlaContext(
        ops.get_default_graph()) or not all(
            [_is_gpu_device(d) for d in self._devices])

  def _initialize_strategy(self, devices):
    # The _initialize_strategy method is intended to be used by distribute
    # coordinator as well.
    assert devices, "Must specify at least one device."
    devices = tuple(device_util.resolve(d) for d in devices)
    assert len(set(devices)) == len(devices), (
        "No duplicates allowed in `devices` argument: %s" % (devices,))
    if _is_device_list_single_worker(devices):
      self._initialize_single_worker(devices)
      self._collective_ops = self._make_collective_ops(devices)
      if self._prefer_collective_ops and (
          isinstance(self._cross_device_ops, cross_device_ops_lib.NcclAllReduce)
          or isinstance(self._inferred_cross_device_ops,
                        cross_device_ops_lib.NcclAllReduce)):
        self._collective_ops_in_use = True
        self._inferred_cross_device_ops = None
      logging.info("Using MirroredStrategy with devices %r", devices)
    else:
      self._initialize_multi_worker(devices)

  def _make_collective_ops(self, devices):
    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + self._collective_key_base)
    return cross_device_ops_lib.CollectiveAllReduce(
        devices=self._devices,
        group_size=len(self._devices),
        options=self._communication_options,
        collective_keys=self._collective_keys)

  def _initialize_single_worker(self, devices):
    """Initializes the object for single-worker training."""
    self._devices = tuple(device_util.canonicalize(d) for d in devices)
    self._input_workers_devices = (
        (device_util.canonicalize("/device:CPU:0", devices[0]), devices),)

    self._inferred_cross_device_ops = None if self._cross_device_ops else (
        cross_device_ops_lib.select_cross_device_ops(devices))
    self._host_input_device = numpy_dataset.SingleDevice(
        self._input_workers_devices[0][0])
    self._is_multi_worker_training = False
    device_spec = tf_device.DeviceSpec.from_string(
        self._input_workers_devices[0][0])
    # Ensures when we enter strategy.scope() we use the correct default device
    if device_spec.job is not None and device_spec.job != "localhost":
      self._default_device = "/job:%s/replica:%d/task:%d" % (
          device_spec.job, device_spec.replica, device_spec.task)

  def _initialize_multi_worker(self, devices):
    """Initializes the object for multi-worker training."""
    device_dict = _group_device_list(devices)
    workers = []
    worker_devices = []
    for job in ("chief", "worker"):
      for task in range(len(device_dict.get(job, []))):
        worker = "/job:%s/task:%d" % (job, task)
        workers.append(worker)
        worker_devices.append((worker, device_dict[job][task]))

    # Setting `_default_device` will add a device scope in the
    # distribution.scope. We set the default device to the first worker. When
    # users specify a device under distribution.scope by
    #   with tf.device("/cpu:0"):
    #     ...
    # their ops will end up on the CPU device of the first worker, e.g.
    # "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode.
    self._default_device = workers[0]
    self._host_input_device = numpy_dataset.SingleDevice(workers[0])

    self._devices = tuple(devices)
    self._input_workers_devices = worker_devices
    self._is_multi_worker_training = True

    if len(workers) > 1:
      # Grandfather usage in the legacy tests if they're configured properly.
      if (not isinstance(self._cross_device_ops,
                         cross_device_ops_lib.ReductionToOneDevice) or
          self._cross_device_ops._num_between_graph_workers > 1):  # pylint: disable=protected-access
        raise ValueError(
            "In-graph multi-worker training with `MirroredStrategy` is not "
            "supported.")
      self._inferred_cross_device_ops = self._cross_device_ops
    else:
      # TODO(yuefengz): make `select_cross_device_ops` work with device strings
      # containing job names.
      self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce()

    logging.info("Using MirroredStrategy with remote devices %r", devices)

  def _input_workers_with_options(self, options=None):
    if not options:
      return input_lib.InputWorkers(self._input_workers_devices)
    if (options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      if options.experimental_place_dataset_on_device:
        self._input_workers_devices = (
            tuple(
                (device_util.canonicalize(d, d), (d,)) for d in self._devices))
      else:
        self._input_workers_devices = (
            tuple((device_util.canonicalize("/device:CPU:0", d), (d,))
                  for d in self._devices))
      return input_lib.InputWorkers(self._input_workers_devices)
    else:
      if not options.experimental_fetch_to_device:
        return input_lib.InputWorkers([
            (host_device, (host_device,) * len(compute_devices))
            for host_device, compute_devices in self._input_workers_devices
        ])
      else:
        return input_lib.InputWorkers(self._input_workers_devices)

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _get_variable_creator_initial_value(self,
                                          replica_id,
                                          device,
                                          primary_var,
                                          **kwargs):
    """Return the initial value for variables on a replica."""
    if replica_id == 0:
      return kwargs["initial_value"]
    else:
      assert primary_var is not None
      assert device is not None
      assert kwargs is not None

      def initial_value_fn():
        if context.executing_eagerly() or ops.inside_function():
          init_value = primary_var.value()
          return array_ops.identity(init_value)
        else:
          with ops.device(device):
            init_value = primary_var.initial_value
            return array_ops.identity(init_value)

      return initial_value_fn

  def _create_variable(self, next_creator, **kwargs):
    """Create a mirrored variable. See `DistributionStrategy.scope`."""
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      devices = self._devices
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
        return next_creator(**kwargs)
    else:
      devices = colocate_with._devices  # pylint: disable=protected-access

    def _real_mirrored_creator(**kwargs):  # pylint: disable=g-missing-docstring
      value_list = []
      for i, d in enumerate(devices):
        with ops.device(d):
          kwargs["initial_value"] = self._get_variable_creator_initial_value(
              replica_id=i,
              device=d,
              primary_var=value_list[0] if value_list else None,
              **kwargs)
          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = value_list[0].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0 to
            # ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
            # Don't record operations (e.g. other variable reads) during
            # variable creation.
            with tape.stop_recording():
              v = next_creator(**kwargs)
          assert not isinstance(v, values.DistributedVariable)
          value_list.append(v)
      return value_list

    return distribute_utils.create_mirrored_variable(
        self._container_strategy(), _real_mirrored_creator,
        distribute_utils.VARIABLE_CLASS_MAPPING,
        distribute_utils.VARIABLE_POLICY_MAPPING, **kwargs)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate_distributed_variable(
        colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    return input_lib_v1.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    input_contexts = []
    num_workers = self._input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))
    return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers,
                                              input_contexts,
                                              self._container_strategy())

  def _experimental_distribute_dataset(self, dataset, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`distribute_datasets_from_function`."
      )
    return input_util.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        options=options)

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, self._host_input_device, session)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    input_workers = self._input_workers_with_options(options)
    input_contexts = []
    num_workers = input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))

    return input_util.get_distributed_datasets_from_function(
        dataset_fn, input_workers, input_contexts, self._container_strategy(),
        options)

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    for replica_id in range(self._num_replicas_in_sync):
      per_replica_values.append(value_fn(
          distribute_lib.ValueContext(replica_id,
                                      self._num_replicas_in_sync)))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = input_lib.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_result = fn(ctx, iterator.get_next())
      for (name, output) in ctx.last_step_outputs.items():
        # Convert all outputs to tensors, potentially from `DistributedValues`.
        ctx.last_step_outputs[name] = self._local_results(output)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things, e.g.
    # to create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)
    del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have already been reduced, wrap them in a Mirrored
      # container, else in a PerReplica container.
      if reduce_op is None:
        last_step_tensor_outputs_dict[name] = distribute_utils.regroup(output)
      else:
        assert len(output) == 1
        last_step_tensor_outputs_dict[name] = output[0]

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx

  def _broadcast_to(self, tensor, destinations):
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    # TODO(josh11b): In eager mode, use one thread per device, or async mode.
    if not destinations:
      # TODO(josh11b): Use current logical device instead of 0 here.
      destinations = self._devices
    return self._get_cross_device_ops(tensor).broadcast(tensor, destinations)

  def _call_for_each_replica(self, fn, args, kwargs):
    return mirrored_run.call_for_each_replica(
        self._container_strategy(), fn, args, kwargs)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    del task_type, task_id

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

    if cluster_spec:
      # TODO(yuefengz): remove the following code once cluster_resolver is
      # added.
      num_gpus_per_worker = _infer_num_gpus_per_worker(self._devices)
      multi_worker_devices = _cluster_spec_to_device_list(
          cluster_spec, num_gpus_per_worker)
      self._initialize_multi_worker(multi_worker_devices)

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    updated_config.isolate_session_state = True
    return updated_config

  def _get_cross_device_ops(self, value):
    if not self._use_merge_call():
      return self._collective_ops

    if self._collective_ops_in_use:
      if isinstance(value, values.DistributedValues):
        value_int32 = True in {
            dtypes.as_dtype(v.dtype) == dtypes.int32 for v in value.values
        }
      else:
        value_int32 = dtypes.as_dtype(value.dtype) == dtypes.int32
      if value_int32:
        return cross_device_ops_lib.ReductionToOneDevice()
      else:
        return self._collective_ops

    return self._cross_device_ops or self._inferred_cross_device_ops

  def _gather_to_implementation(self, value, destinations, axis, options):
    if not isinstance(value, values.DistributedValues):
      # ReductionToOneDevice._gather accepts DistributedValues only.
      return value
    return self._get_cross_device_ops(value)._gather(  # pylint: disable=protected-access
        value,
        destinations=destinations,
        axis=axis,
        options=self._communication_options.merge(options))

  def _reduce_to(self, reduce_op, value, destinations, options):
    if (distribute_utils.is_mirrored(value) and
        reduce_op == reduce_util.ReduceOp.MEAN):
      return value
    assert not distribute_utils.is_mirrored(value)
    def get_values(value):
      if not isinstance(value, values.DistributedValues):
        # This function handles reducing values that are not PerReplica or
        # Mirrored values. For example, the same value could be present on all
        # replicas, in which case `value` would be a single value, or `value`
        # could be 0.
        return cross_device_ops_lib.reduce_non_distributed_value(
            reduce_op, value, destinations, self._num_replicas_in_sync)
      if self._use_merge_call() and self._collective_ops_in_use and ((
          not cross_device_ops_lib._devices_match(value, destinations) or  # pylint: disable=protected-access
          any("cpu" in d.lower()
              for d in cross_device_ops_lib.get_devices_from(destinations)))):
        return cross_device_ops_lib.ReductionToOneDevice().reduce(
            reduce_op, value, destinations)
      return self._get_cross_device_ops(value).reduce(
          reduce_op,
          value,
          destinations=destinations,
          options=self._communication_options.merge(options))

    return nest.map_structure(get_values, value)

  def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
    cross_device_ops = None
    for value, _ in value_destination_pairs:
      if cross_device_ops is None:
        cross_device_ops = self._get_cross_device_ops(value)
      elif cross_device_ops is not self._get_cross_device_ops(value):
        raise ValueError("Inputs to batch_reduce_to must be either all on "
                         "the host or all on the compute devices.")
    return cross_device_ops.batch_reduce(
        reduce_op,
        value_destination_pairs,
        options=self._communication_options.merge(options))

  def _update(self, var, fn, args, kwargs, group):
    # TODO(josh11b): In eager mode, use one thread per device.
    assert isinstance(var, values.DistributedVariable)
    updates = []
    for i, v in enumerate(var.values):
      name = "update_%d" % i
      with ops.device(v.device), \
           distribute_lib.UpdateContext(i), \
           ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates.append(
            fn(v, *distribute_utils.select_replica(i, args),
               **distribute_utils.select_replica(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
    """Implements `StrategyExtendedV2._replica_ctx_all_reduce`."""
    # This implementation avoids using `merge_call` and just launches collective
    # ops in one replica.
    if options is None:
      options = collective_util.Options()

    if context.executing_eagerly() or (
        not tf2.enabled()) or self._use_merge_call():
      # In eager mode, fall back to the default implementation that uses
      # `merge_call`. Replica functions run sequentially in eager mode, and
      # because collective ops block, execution would hang if collective ops
      # were launched sequentially.
      return super()._replica_ctx_all_reduce(reduce_op, value, options)

    replica_context = distribution_strategy_context.get_replica_context()
    assert replica_context, (
        "`StrategyExtended._replica_ctx_all_reduce` must be called in a "
        "replica context")
    return self._get_cross_device_ops(value)._all_reduce(  # pylint: disable=protected-access
        reduce_op,
        value,
        replica_context._replica_id,  # pylint: disable=protected-access
        options)

  def _replica_ctx_update(self, var, fn, args, kwargs, group):
    if self._use_merge_call():
      return super()._replica_ctx_update(var, fn, args, kwargs, group)

    replica_context = distribution_strategy_context.get_replica_context()
    assert replica_context
    replica_id = values_util.get_current_replica_id_as_int()
    name = "update_%d" % replica_id

    if isinstance(var, values.DistributedVariable):
      var = var._get_replica(replica_id)  # pylint: disable=protected-access

    with ops.device(var.device), ops.name_scope(name):
      result = fn(var, *args, **kwargs)
    return result

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    assert isinstance(colocate_with, tuple)
    # TODO(josh11b): In eager mode, use one thread per device.
    updates = []
    for i, d in enumerate(colocate_with):
      name = "update_%d" % i
      with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
        updates.append(
            fn(*distribute_utils.select_replica(i, args),
               **distribute_utils.select_replica(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def read_var(self, replica_local_var):
    """Read the aggregate value of a replica-local variable."""
    # pylint: disable=protected-access
    if distribute_utils.is_sync_on_read(replica_local_var):
      return replica_local_var._get_cross_replica()
    assert distribute_utils.is_mirrored(replica_local_var)
    return array_ops.identity(replica_local_var._get())
    # pylint: enable=protected-access

  def value_container(self, val):
    return distribute_utils.value_container(val)

  @property
  def _num_replicas_in_sync(self):
    return len(self._devices)

  @property
  def worker_devices(self):
    return self._devices

  @property
  def worker_devices_by_replica(self):
    return [[d] for d in self._devices]

  @property
  def parameter_devices(self):
    return self.worker_devices

  @property
  def experimental_between_graph(self):
    return False

  @property
  def experimental_should_init(self):
    return True

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  def non_slot_devices(self, var_list):
    del var_list
    # TODO(josh11b): Should this be the last logical device instead?
    return self._devices

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return False

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group

  def _get_replica_id_in_sync_group(self, replica_id):
    return replica_id
908