1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TPU Strategy."""
16
17import atexit
18import collections
19import contextlib
20import copy
21import functools
22import weakref
23
24from absl import logging
25import numpy as np
26
27from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
28from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
29from tensorflow.python.autograph.impl import api as autograph
30from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
31from tensorflow.python.distribute import device_util
32from tensorflow.python.distribute import distribute_lib
33from tensorflow.python.distribute import distribute_utils
34from tensorflow.python.distribute import input_lib
35from tensorflow.python.distribute import input_util
36from tensorflow.python.distribute import numpy_dataset
37from tensorflow.python.distribute import reduce_util
38from tensorflow.python.distribute import tpu_replicated_variable
39from tensorflow.python.distribute import tpu_util
40from tensorflow.python.distribute import tpu_values
41from tensorflow.python.distribute import values
42from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
43from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
44from tensorflow.python.eager import context
45from tensorflow.python.eager import def_function
46from tensorflow.python.eager import function
47from tensorflow.python.framework import constant_op
48from tensorflow.python.framework import device_spec
49from tensorflow.python.framework import dtypes
50from tensorflow.python.framework import indexed_slices
51from tensorflow.python.framework import ops
52from tensorflow.python.framework import sparse_tensor
53from tensorflow.python.framework import tensor_shape
54from tensorflow.python.framework import tensor_util
55from tensorflow.python.ops import array_ops
56from tensorflow.python.ops import control_flow_ops
57from tensorflow.python.ops import math_ops
58from tensorflow.python.ops import resource_variable_ops
59from tensorflow.python.ops import variables as variables_lib
60from tensorflow.python.ops.ragged import ragged_tensor
61from tensorflow.python.tpu import device_assignment as device_assignment_lib  # pylint: disable=unused-import
62from tensorflow.python.tpu import tpu
63from tensorflow.python.tpu import tpu_hardware_feature
64from tensorflow.python.tpu import tpu_strategy_util
65from tensorflow.python.tpu import training_loop
66from tensorflow.python.tpu.ops import tpu_ops
67from tensorflow.python.util import deprecation
68from tensorflow.python.util import nest
69from tensorflow.python.util import tf_inspect
70from tensorflow.python.util.tf_export import tf_export
71
72
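# XLA op-by-op execution limits how many inputs a single op may take, so values
# gathered or reduced across more than this many replicas are combined in
# groups (see the grouped-op workaround in `_gather_to_implementation` below).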
73_XLA_OP_BY_OP_INPUTS_LIMIT = 200
74
75
76@contextlib.contextmanager
77def maybe_init_scope():
78  if ops.executing_eagerly_outside_functions():
79    yield
80  else:
81    with ops.init_scope():
82      yield
83
84
85def validate_run_function(fn):
86  """Validate the function passed into strategy.run."""
87
88  # We allow three types of functions/objects passed into TPUStrategy
89  # run in eager mode:
90  #   1. a user annotated tf.function
91  #   2. a ConcreteFunction, this is mostly what you get from loading a saved
92  #      model.
93  #   3. a callable object and the `__call__` method itself is a tf.function.
94  #
95  # Otherwise we raise an error, because we don't support running
96  # `strategy.run` eagerly in TPUStrategy.
97
98  if context.executing_eagerly() \
99      and not isinstance(fn, def_function.Function) \
100      and not isinstance(fn, function.ConcreteFunction) \
101      and not (callable(fn) and isinstance(fn.__call__, def_function.Function)):
102    raise NotImplementedError(
103        "TPUStrategy.run(fn, ...) does not support pure eager "
104        "execution. please make sure the function passed into "
105        "`strategy.run` is a `tf.function` or "
106        "`strategy.run` is called inside a `tf.function` if "
107        "eager behavior is enabled.")
108
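# Example (illustrative sketch, not part of this module): the three forms
# accepted by `validate_run_function` above when eager execution is enabled.
# The names `strategy`, `replica_fn` and `Runner` are assumptions made for
# illustration only.
#
#   @tf.function
#   def replica_fn(x):                      # 1. a user-annotated tf.function
#     return x * 2.0
#
#   concrete = replica_fn.get_concrete_function(
#       tf.TensorSpec([], tf.float32))      # 2. a ConcreteFunction
#
#   class Runner:
#     @tf.function
#     def __call__(self, x):                # 3. a callable whose __call__ is a
#       return x * 2.0                      #    tf.function
#
#   strategy.run(replica_fn, args=(tf.constant(1.0),))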
109
110def _maybe_partial_apply_variables(fn, args, kwargs):
111  """Inspects arguments to partially apply any DistributedVariable.
112
113  This avoids an automatic cast of the current variable value to tensor.
114
115  Note that a variable may be captured implicitly with Python scope instead of
116  passing it to run(), but supporting run() keeps behavior consistent
117  with MirroredStrategy.
118
119  Since positional arguments must be applied from left to right, this function
120  does some tricky function inspection to move variable positional arguments
121  into kwargs. As a result of this, we can't support passing Variables as *args,
122  nor as args to functions which combine both explicit positional arguments and
123  *args.
124
125  Args:
126    fn: The function to run, as passed to run().
127    args: Positional arguments to fn, as passed to run().
128    kwargs: Keyword arguments to fn, as passed to run().
129
130  Returns:
131    A tuple of the function (possibly wrapped), args, kwargs (both
132    possibly filtered, with members of args possibly moved to kwargs).
133    If no variables are found, this function is a noop.
134
135  Raises:
136    ValueError: If the function signature makes unsupported use of *args, or if
137      too many arguments are passed.
138  """
139
140  def is_distributed_var(x):
141    flat = nest.flatten(x)
142    return flat and isinstance(flat[0], values.DistributedVariable)
143
144  # We will split kwargs into two dicts, one of which will be applied now.
145  var_kwargs = {}
146  nonvar_kwargs = {}
147
148  if kwargs:
149    var_kwargs = {k: v for k, v in kwargs.items() if is_distributed_var(v)}
150  if var_kwargs:
151    nonvar_kwargs = {
152        k: v for k, v in kwargs.items() if not is_distributed_var(v)
153    }
154
155  # Dump the argument names of `fn` to a list. This will include both positional
156  # and keyword arguments, but since positional arguments come first we can
157  # look up names of positional arguments by index.
158  positional_args = []
159  index_of_star_args = None
160  for i, p in enumerate(tf_inspect.signature(fn).parameters.values()):
161    # Class methods define "self" as first argument, but we don't pass "self".
162    # Note that this is a heuristic, as a method can name its first argument
163    # something else, and a function can define a first argument "self" as well.
164    # In both of these cases, using a Variable will fail with an unfortunate
165    # error about the number of arguments.
166    # inspect.ismethod() seems not to work here, possibly due to the use of
167    # tf.function().
168    if i == 0 and p.name == "self":
169      continue
170
171    if p.kind == tf_inspect.Parameter.POSITIONAL_OR_KEYWORD:
172      positional_args.append(p.name)
173
174    elif p.kind == tf_inspect.Parameter.VAR_POSITIONAL:
175      # We'll raise an error later if a variable is passed to *args, since we
176      # can neither pass it by name nor partially apply it. This case only
177      # happens once at most.
178      index_of_star_args = i
179
180    elif p.kind == tf_inspect.Parameter.POSITIONAL_ONLY:
181      # This is a rare Python feature, indicating a / in the arg list.
182      if var_kwargs or any(is_distributed_var(a) for a in args):
183        raise ValueError(
184            "Mixing Variables and positional-only parameters not supported by "
185            f"TPUStrategy. Received {len(var_kwargs)} DistributedVariables in "
186            f"**kwargs and {sum(is_distributed_var(a) for a in args)} in *args,"
187            " expected zero for both."
188        )
189      return fn, args, kwargs
190
191  star_args = []
192  have_seen_var_arg = False
193
194  for i, a in enumerate(args):
195    if is_distributed_var(a):
196      if index_of_star_args is not None and i >= index_of_star_args:
197        raise ValueError(
198            "TPUStrategy.run() cannot handle Variables passed to *args. "
199            "Either name the function argument, or capture the Variable "
200            "implicitly.")
201      if len(positional_args) <= i:
202        raise ValueError(
203            "Too many positional arguments passed to call to TPUStrategy.run()."
204        )
205      var_kwargs[positional_args[i]] = a
206      have_seen_var_arg = True
207    else:
208      if index_of_star_args is not None and i >= index_of_star_args:
209        if have_seen_var_arg:
210          raise ValueError(
211              "TPUStrategy.run() cannot handle both Variables and a mix of "
212              "positional args and *args. Either remove the *args, or capture "
213              "the Variable implicitly.")
214        else:
215          star_args.append(a)
216          continue
217
218      if len(positional_args) <= i:
219        raise ValueError(
220            "Too many positional arguments passed to call to TPUStrategy.run()."
221        )
222      nonvar_kwargs[positional_args[i]] = a
223
224  if var_kwargs:
225    return functools.partial(fn, **var_kwargs), star_args, nonvar_kwargs
226  return fn, args, kwargs
227
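# Example (illustrative sketch) of the rewrite performed by
# `_maybe_partial_apply_variables` above; `step_fn`, `v` and `x` are
# assumptions for illustration only.
#
#   def step_fn(var, inputs):
#     return var.assign_add(tf.reduce_sum(inputs))
#
#   # With `v` a DistributedVariable and `x` a per-replica input,
#   # strategy.run(step_fn, args=(v, x)) is roughly rewritten to
#   #   fn     = functools.partial(step_fn, var=v)
#   #   args   = ()
#   #   kwargs = {"inputs": x}
#   # so that `v` is applied by name and is not cast to a tensor.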
228
229@tf_export("distribute.TPUStrategy", v1=[])
230class TPUStrategyV2(distribute_lib.Strategy):
231  """Synchronous training on TPUs and TPU Pods.
232
233  To construct a TPUStrategy object, you need to run the
234  initialization code as below:
235
236  >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
237  >>> tf.config.experimental_connect_to_cluster(resolver)
238  >>> tf.tpu.experimental.initialize_tpu_system(resolver)
239  >>> strategy = tf.distribute.TPUStrategy(resolver)
240
241  While using distribution strategies, the variables created within the
242  strategy's scope will be replicated across all the replicas and can be kept in
243  sync using all-reduce algorithms.
244
245  To run TF2 programs on TPUs, you can either use `.compile` and
246  `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
247  training loop by calling `strategy.run` directly. Note that
248  TPUStrategy doesn't support pure eager execution, so please make sure the
249  function passed into `strategy.run` is a `tf.function` or
250  `strategy.run` is called inside a `tf.function` if eager
251  behavior is enabled. See more details in https://www.tensorflow.org/guide/tpu.
252
253  `distribute_datasets_from_function` and
254  `experimental_distribute_dataset` APIs can be used to distribute the dataset
255  across the TPU workers when writing your own training loop. If you are using
256  `fit` and `compile` methods available in `tf.keras.Model`, then Keras will
257  handle the distribution for you.
258
259  An example of writing customized training loop on TPUs:
260
261  >>> with strategy.scope():
262  ...   model = tf.keras.Sequential([
263  ...     tf.keras.layers.Dense(2, input_shape=(5,)),
264  ...   ])
265  ...   optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
266
267  >>> def dataset_fn(ctx):
268  ...   x = np.random.random((2, 5)).astype(np.float32)
269  ...   y = np.random.randint(2, size=(2, 1))
270  ...   dataset = tf.data.Dataset.from_tensor_slices((x, y))
271  ...   return dataset.repeat().batch(1, drop_remainder=True)
272  >>> dist_dataset = strategy.distribute_datasets_from_function(
273  ...     dataset_fn)
274  >>> iterator = iter(dist_dataset)
275
276  >>> @tf.function()
277  ... def train_step(iterator):
278  ...
279  ...   def step_fn(inputs):
280  ...     features, labels = inputs
281  ...     with tf.GradientTape() as tape:
282  ...       logits = model(features, training=True)
283  ...       loss = tf.keras.losses.sparse_categorical_crossentropy(
284  ...           labels, logits)
285  ...
286  ...     grads = tape.gradient(loss, model.trainable_variables)
287  ...     optimizer.apply_gradients(zip(grads, model.trainable_variables))
288  ...
289  ...   strategy.run(step_fn, args=(next(iterator),))
290
291  >>> train_step(iterator)
292
293  For advanced use cases like model parallelism, you can set the
294  `experimental_device_assignment` argument when creating TPUStrategy to specify
295  the number of replicas and the number of logical devices. Below is an example
296  that initializes the TPU system with 2 logical devices and 1 replica.
297
298  >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
299  >>> tf.config.experimental_connect_to_cluster(resolver)
300  >>> topology = tf.tpu.experimental.initialize_tpu_system(resolver)
301  >>> device_assignment = tf.tpu.experimental.DeviceAssignment.build(
302  ...     topology,
303  ...     computation_shape=[1, 1, 1, 2],
304  ...     num_replicas=1)
305  >>> strategy = tf.distribute.TPUStrategy(
306  ...     resolver, experimental_device_assignment=device_assignment)
307
308  Then you can run a `tf.add` operation only on logical device 0.
309
310  >>> @tf.function()
311  ... def step_fn(inputs):
312  ...   features, _ = inputs
313  ...   output = tf.add(features, features)
314  ...
315  ...   # Add operation will be executed on logical device 0.
316  ...   output = strategy.experimental_assign_to_logical_device(output, 0)
317  ...   return output
318  >>> dist_dataset = strategy.distribute_datasets_from_function(
319  ...     dataset_fn)
320  >>> iterator = iter(dist_dataset)
321  >>> strategy.run(step_fn, args=(next(iterator),))
322
323  `experimental_spmd_xla_partitioning` enables the experimental XLA SPMD feature
324  for model parallelism. This flag can reduce the compilation time and HBM
325  requirements. When running in this mode, every input tensor must either be
326  partitioned (via `strategy.experimental_split_to_logical_devices`) or fully
327  replicated (via `strategy.experimental_replicate_to_logical_devices`) to all
328  logical devices. Calling `strategy.experimental_assign_to_logical_device`
329  in this mode will result in a `ValueError`.
330  """
331
332  def __init__(self,
333               tpu_cluster_resolver=None,
334               experimental_device_assignment=None,
335               experimental_spmd_xla_partitioning=False):
336    """Synchronous training in TPU donuts or Pods.
337
338    Args:
339      tpu_cluster_resolver: A
340        `tf.distribute.cluster_resolver.TPUClusterResolver` instance, which
341        provides information about the TPU cluster. If None, it will assume
342        running on a local TPU worker.
343      experimental_device_assignment: Optional
344        `tf.tpu.experimental.DeviceAssignment` to specify the placement of
345        replicas on the TPU cluster.
346      experimental_spmd_xla_partitioning: If True, enable the SPMD (Single
347        Program Multiple Data) mode in XLA compiler. This flag only affects the
348        performance of XLA compilation and the HBM requirement of the compiled
349        TPU program. Caveat: if this flag is True, calling
350        `tf.distribute.TPUStrategy.experimental_assign_to_logical_device` will
351        result in a ValueError.
352    """
353    super(TPUStrategyV2, self).__init__(
354        TPUExtended(
355            self,
356            tpu_cluster_resolver,
357            device_assignment=experimental_device_assignment,
358            use_spmd_for_xla_partitioning=experimental_spmd_xla_partitioning))
359    distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
360    distribute_lib.distribution_strategy_replica_gauge.get_cell(
361        "num_workers").set(self.extended.num_hosts)
362    distribute_lib.distribution_strategy_replica_gauge.get_cell(
363        "num_replicas_per_worker").set(self.extended.num_replicas_per_host)
364    # Packed variable is used to reduce the overhead of function execution.
365    # For a DistributedVariable, only one variable handle is captured into a
366    # function graph. It's only supported in eager mode.
367    # Packed variable is currently not supported when SPMD is enabled.
368    # TODO(b/202047549): enable Packed variable in SPMD mode.
369    self._enable_packed_variable_in_eager_mode = (
370        not experimental_spmd_xla_partitioning)
371
372  def run(self, fn, args=(), kwargs=None, options=None):
373    """Run the computation defined by `fn` on each TPU replica.
374
375    Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
376    `tf.distribute.DistributedValues`, such as those produced by a
377    `tf.distribute.DistributedDataset` from
378    `tf.distribute.Strategy.experimental_distribute_dataset` or
379    `tf.distribute.Strategy.distribute_datasets_from_function`,
380    when `fn` is executed on a particular replica, it will be executed with the
381    component of `tf.distribute.DistributedValues` that correspond to that
382    replica.
383
384    `fn` may call `tf.distribute.get_replica_context()` to access members such
385    as `all_reduce`.
386
387    All arguments in `args` or `kwargs` should either be nests of tensors or
388    `tf.distribute.DistributedValues` containing tensors or composite tensors.
389
390    Example usage:
391
392    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
393    >>> tf.config.experimental_connect_to_cluster(resolver)
394    >>> tf.tpu.experimental.initialize_tpu_system(resolver)
395    >>> strategy = tf.distribute.TPUStrategy(resolver)
396    >>> @tf.function
397    ... def run():
398    ...   def value_fn(value_context):
399    ...     return value_context.num_replicas_in_sync
400    ...   distributed_values = (
401    ...       strategy.experimental_distribute_values_from_function(value_fn))
402    ...   def replica_fn(input):
403    ...     return input * 2
404    ...   return strategy.run(replica_fn, args=(distributed_values,))
405    >>> result = run()
406
407    Args:
408      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
409      args: (Optional) Positional arguments to `fn`.
410      kwargs: (Optional) Keyword arguments to `fn`.
411      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
412        the options to run `fn`.
413
414    Returns:
415      Merged return value of `fn` across replicas. The structure of the return
416      value is the same as the return value from `fn`. Each element in the
417      structure can either be `tf.distribute.DistributedValues`, `Tensor`
418      objects, or `Tensor`s (for example, if running on a single replica).
419    """
420    validate_run_function(fn)
421
422    fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs)
423
424    # Note: the target function is converted to graph even when in Eager mode,
425    # so autograph is on by default here.
426    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
427    options = options or distribute_lib.RunOptions()
428    return self.extended.tpu_run(fn, args, kwargs, options)
429
430  @property
431  def cluster_resolver(self):
432    """Returns the cluster resolver associated with this strategy.
433
434    `tf.distribute.TPUStrategy` provides the associated
435    `tf.distribute.cluster_resolver.ClusterResolver`. If the user provides one
436    in `__init__`, that instance is returned; if the user does not, a default
437    `tf.distribute.cluster_resolver.TPUClusterResolver` is provided.
438    """
439    return self.extended._tpu_cluster_resolver  # pylint: disable=protected-access
440
441  def experimental_assign_to_logical_device(self, tensor, logical_device_id):
442    """Adds annotation that `tensor` will be assigned to a logical device.
443
444    This adds an annotation to `tensor` specifying that operations on
445    `tensor` will be invoked on logical core device id `logical_device_id`.
446    When model parallelism is used, the default behavior is that all ops
447    are placed on the zeroth logical device.
448
449    ```python
450
451    # Initializing TPU system with 2 logical devices and 4 replicas.
452    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
453    tf.config.experimental_connect_to_cluster(resolver)
454    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
455    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
456        topology,
457        computation_shape=[1, 1, 1, 2],
458        num_replicas=4)
459    strategy = tf.distribute.TPUStrategy(
460        resolver, experimental_device_assignment=device_assignment)
461    iterator = iter(inputs)
462
463    @tf.function()
464    def step_fn(inputs):
465      output = tf.add(inputs, inputs)
466
467      # Add operation will be executed on logical device 0.
468      output = strategy.experimental_assign_to_logical_device(output, 0)
469      return output
470
471    strategy.run(step_fn, args=(next(iterator),))
472    ```
473
474    Args:
475      tensor: Input tensor to annotate.
476      logical_device_id: Id of the logical core to which the tensor will be
477        assigned.
478
479    Raises:
480      ValueError: If the logical device id presented is not consistent with the
481        total number of partitions specified by the device assignment, or if the
482        TPUStrategy is constructed with `experimental_spmd_xla_partitioning=True`.
483
484    Returns:
485      Annotated tensor with identical value as `tensor`.
486    """
487    if self.extended._use_spmd_for_xla_partitioning:  # pylint: disable=protected-access
488      raise ValueError(
489          "Cannot assign a tensor to a logical device in SPMD mode. To disable "
490          "SPMD, Please construct the TPUStrategy with "
491          "`experimental_spmd_xla_partitioning=False`")
492
493    num_logical_devices_per_replica = self.extended._tpu_devices.shape[1]  # pylint: disable=protected-access
494    if (logical_device_id < 0 or
495        logical_device_id >= num_logical_devices_per_replica):
496      raise ValueError("`logical_device_id` to assign must be lower than the "
497                       "total number of logical devices per replica. Received "
498                       "logical device id {} but there are only {} "
499                       "logical devices per replica.".format(
500                           logical_device_id, num_logical_devices_per_replica))
501    return xla_sharding.assign_device(
502        tensor, logical_device_id, use_sharding_op=True)
503
504  def experimental_split_to_logical_devices(self, tensor, partition_dimensions):
505    """Adds annotation that `tensor` will be split across logical devices.
506
507    This adds an annotation to tensor `tensor` specifying that operations on
508    `tensor` will be split among multiple logical devices. Tensor `tensor` will
509    be split across dimensions specified by `partition_dimensions`.
510    Each dimension of `tensor` must be divisible by the corresponding value in
511    `partition_dimensions`.
512
513    For example, for a system with 8 logical devices, if `tensor` is an image
514    tensor with shape (batch_size, width, height, channel) and
515    `partition_dimensions` is [1, 2, 4, 1], then `tensor` will be split 2 ways
516    in the width dimension and 4 ways in the height dimension, and the split
517    tensor values will be fed into 8 logical devices.
518
519    ```python
520    # Initializing TPU system with 8 logical devices and 1 replica.
521    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
522    tf.config.experimental_connect_to_cluster(resolver)
523    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
524    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
525        topology,
526        computation_shape=[1, 2, 2, 2],
527        num_replicas=1)
528    # Construct the TPUStrategy. Since we are going to split the image across
529    # logical devices, here we set `experimental_spmd_xla_partitioning=True`
530    # so that the partitioning can be compiled in SPMD mode, which usually
531    # results in faster compilation and smaller HBM requirement if the size of
532    # input and activation tensors are much bigger than that of the model
533    # parameters. Note that this flag is suggested but not a hard requirement
534    # for `experimental_split_to_logical_devices`.
535    strategy = tf.distribute.TPUStrategy(
536        resolver, experimental_device_assignment=device_assignment,
537        experimental_spmd_xla_partitioning=True)
538
539    iterator = iter(inputs)
540
541    @tf.function()
542    def step_fn(inputs):
543      inputs = strategy.experimental_split_to_logical_devices(
544        inputs, [1, 2, 4, 1])
545
546      # model() function will be executed on 8 logical devices with `inputs`
547      # split 2 * 4  ways.
548      output = model(inputs)
549      return output
550
551    strategy.run(step_fn, args=(next(iterator),))
552    ```
553    Args:
554      tensor: Input tensor to annotate.
555      partition_dimensions: An unnested list of integers with the size equal to
556        rank of `tensor` specifying how `tensor` will be partitioned. The
557        product of all elements in `partition_dimensions` must be equal to the
558        total number of logical devices per replica.
559
560    Raises:
561      ValueError: 1) If the size of `partition_dimensions` does not equal the
562        rank of `tensor`, or 2) if the product of the elements of
563        `partition_dimensions` does not match the number of logical devices per
564        replica defined by the implementing DistributionStrategy's device
565        specification, or 3) if a known size of `tensor` is not divisible by the
566        corresponding value in `partition_dimensions`.
567
568    Returns:
569      Annotated tensor with identical value as `tensor`.
570    """
571    num_logical_devices_per_replica = self.extended._tpu_devices.shape[1]  # pylint: disable=protected-access
572    num_partition_splits = np.prod(partition_dimensions)
573    input_shape = tensor.shape
574    tensor_rank = len(input_shape)
575
576    if tensor_rank != len(partition_dimensions):
577      raise ValueError("Length of `partition_dimensions` must equal to the "
578                       "rank of `tensor.shape` ({}). Received "
579                       "len(partition_dimensions)={}.".format(
580                           tensor_rank, len(partition_dimensions)))
581
582    for dim_index, dim_size in enumerate(input_shape):
583      if dim_size is None:
584        continue
585
586      split_size = partition_dimensions[dim_index]
587      if dim_size % split_size != 0:
588        raise ValueError("Tensor shape at `partition_dimensions[{}]` must be "
589                         "divisible by corresponding value specified "
590                         "by `partition_dimensions` ({}). Received: {}.".format(
591                             dim_index, split_size, dim_size))
592
593    if num_partition_splits != num_logical_devices_per_replica:
594      raise ValueError(
595          "The product of `partition_dimensions` should be the same as the "
596          "number of logical devices (={}). Received `partition_dimensions`={},"
597          "and their product is {}.".format(num_logical_devices_per_replica,
598                                            partition_dimensions,
599                                            num_partition_splits))
600
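    # For example (illustrative): with `partition_dimensions=[1, 2, 4, 1]`,
    # `tile_assignment` below is `np.arange(8).reshape([1, 2, 4, 1])`, i.e.
    # logical cores 0..7 tiled 2 ways along width and 4 ways along height.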
601    tile_assignment = np.arange(num_partition_splits).reshape(
602        partition_dimensions)
603    return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
604
605  def experimental_replicate_to_logical_devices(self, tensor):
606    """Adds annotation that `tensor` will be replicated to all logical devices.
607
608    This adds an annotation to tensor `tensor` specifying that operations on
609    `tensor` will be invoked on all logical devices.
610
611    ```python
612    # Initializing TPU system with 8 logical devices and 1 replica.
613    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
614    tf.config.experimental_connect_to_cluster(resolver)
615    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
616    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
617        topology,
618        computation_shape=[1, 2, 2, 2],
619        num_replicas=1)
620    strategy = tf.distribute.TPUStrategy(
621        resolver, experimental_device_assignment=device_assignment)
622
623    iterator = iter(inputs)
624
625    @tf.function()
626    def step_fn(inputs):
627      images, labels = inputs
628      images = strategy.experimental_split_to_logical_devices(
629        images, [1, 2, 4, 1])
630
631      # model() function will be executed on 8 logical devices with `images`
632      # split 2 * 4 ways.
633      output = model(images)
634
635      # For loss calculation, all logical devices share the same logits
636      # and labels.
637      labels = strategy.experimental_replicate_to_logical_devices(labels)
638      output = strategy.experimental_replicate_to_logical_devices(output)
639      loss = loss_fn(labels, output)
640
641      return loss
642
643    strategy.run(step_fn, args=(next(iterator),))
644    ```
645    Args:
646      tensor: Input tensor to annotate.
647
648    Returns:
649      Annotated tensor with identical value as `tensor`.
650    """
651    return xla_sharding.replicate(tensor, use_sharding_op=True)
652
653
654@tf_export("distribute.experimental.TPUStrategy", v1=[])
655@deprecation.deprecated_endpoints("distribute.experimental.TPUStrategy")
656class TPUStrategy(distribute_lib.Strategy):
657  """Synchronous training on TPUs and TPU Pods.
658
659  To construct a TPUStrategy object, you need to run the
660  initialization code as below:
661
662  >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
663  >>> tf.config.experimental_connect_to_cluster(resolver)
664  >>> tf.tpu.experimental.initialize_tpu_system(resolver)
665  >>> strategy = tf.distribute.experimental.TPUStrategy(resolver)
666
667  While using distribution strategies, the variables created within the
668  strategy's scope will be replicated across all the replicas and can be kept in
669  sync using all-reduce algorithms.
670
671  To run TF2 programs on TPUs, you can either use `.compile` and
672  `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
673  training loop by calling `strategy.run` directly. Note that
674  TPUStrategy doesn't support pure eager execution, so please make sure the
675  function passed into `strategy.run` is a `tf.function` or
676  `strategy.run` is called inside a `tf.function` if eager
677  behavior is enabled.
678  """
679
680  def __init__(self,
681               tpu_cluster_resolver=None,
682               device_assignment=None):
683    """Synchronous training in TPU donuts or Pods.
684
685    Args:
686      tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
687        which provides information about the TPU cluster.
688      device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
689        specify the placement of replicas on the TPU cluster.
690    """
691    logging.warning(
692        "`tf.distribute.experimental.TPUStrategy` is deprecated, please use "
693        " the non experimental symbol `tf.distribute.TPUStrategy` instead.")
694
695    super(TPUStrategy, self).__init__(
696        TPUExtended(
697            self, tpu_cluster_resolver, device_assignment=device_assignment))
698    distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
699    distribute_lib.distribution_strategy_replica_gauge.get_cell(
700        "num_workers").set(self.extended.num_hosts)
701    distribute_lib.distribution_strategy_replica_gauge.get_cell(
702        "num_replicas_per_worker").set(self.extended.num_replicas_per_host)
703    # Packed variable is used to reduce the overhead of function execution.
704    # For a DistributedVariable, only one variable handle is captured into a
705    # function graph. It's only supported in eager mode.
706    self._enable_packed_variable_in_eager_mode = True
707
708  # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
709  # can use the default implementation.
710  # This implementation runs a single step. It does not use infeed or outfeed.
711  def run(self, fn, args=(), kwargs=None, options=None):
712    """See base class."""
713    validate_run_function(fn)
714
715    fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs)
716
717    # Note: the target function is converted to graph even when in Eager mode,
718    # so autograph is on by default here.
719    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
720    options = options or distribute_lib.RunOptions()
721    return self.extended.tpu_run(fn, args, kwargs, options)
722
723  @property
724  def cluster_resolver(self):
725    """Returns the cluster resolver associated with this strategy.
726
727    `tf.distribute.experimental.TPUStrategy` provides the
728    associated `tf.distribute.cluster_resolver.ClusterResolver`. If the user
729    provides one in `__init__`, that instance is returned; if the user does
730    not, a default
731    `tf.distribute.cluster_resolver.TPUClusterResolver` is provided.
732    """
733    return self.extended._tpu_cluster_resolver  # pylint: disable=protected-access
734
735
736@tf_export(v1=["distribute.experimental.TPUStrategy"])
737class TPUStrategyV1(distribute_lib.StrategyV1):
738  """TPU distribution strategy implementation."""
739
740  def __init__(self,
741               tpu_cluster_resolver=None,
742               steps_per_run=None,
743               device_assignment=None):
744    """Initializes the TPUStrategy object.
745
746    Args:
747      tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
748          which provides information about the TPU cluster.
749      steps_per_run: Number of steps to run on device before returning to the
750          host. Note that this can have side-effects on performance, hooks,
751          metrics, summaries etc.
752          This parameter is only used when Distribution Strategy is used with
753          estimator or keras.
754      device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
755          specify the placement of replicas on the TPU cluster. Currently only
756          supports the usecase of using a single core within a TPU cluster.
757    """
758    super(TPUStrategyV1, self).__init__(TPUExtended(
759        self, tpu_cluster_resolver, steps_per_run, device_assignment))
760    distribute_lib.distribution_strategy_gauge.get_cell("V1").set("TPUStrategy")
761    distribute_lib.distribution_strategy_replica_gauge.get_cell(
762        "num_workers").set(self.extended.num_hosts)
763    distribute_lib.distribution_strategy_replica_gauge.get_cell(
764        "num_replicas_per_worker").set(self.extended.num_replicas_per_host)
765    # Packed variable is used to reduce the overhead of function execution.
766    # For a DistributedVariable, only one variable handle is captured into a
767    # function graph. It's only supported in eager mode.
768    self._enable_packed_variable_in_eager_mode = True
769
770  @property
771  def steps_per_run(self):
772    """DEPRECATED: use .extended.steps_per_run instead."""
773    return self._extended.steps_per_run
774
775  # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
776  # can use the default implementation.
777  # This implementation runs a single step. It does not use infeed or outfeed.
778  def run(self, fn, args=(), kwargs=None, options=None):
779    """Run `fn` on each replica, with the given arguments.
780
781    Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
782    "per-replica" values, such as those produced by a "distributed `Dataset`",
783    when `fn` is executed on a particular replica, it will be executed with the
784    component of those "per-replica" values that correspond to that replica.
785
786    `fn` may call `tf.distribute.get_replica_context()` to access members such
787    as `all_reduce`.
788
789    All arguments in `args` or `kwargs` should either be nests of tensors or
790    per-replica objects containing tensors or composite tensors.
791
792    Users can pass strategy-specific options via the `options` argument. An
793    example that enables bucketizing dynamic shapes in `TPUStrategy.run`
794    is:
795
796    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
797    >>> tf.config.experimental_connect_to_cluster(resolver)
798    >>> tf.tpu.experimental.initialize_tpu_system(resolver)
799    >>> strategy = tf.distribute.experimental.TPUStrategy(resolver)
800
801    >>> options = tf.distribute.RunOptions(
802    ...     experimental_bucketizing_dynamic_shape=True)
803
804    >>> dataset = tf.data.Dataset.range(
805    ...    strategy.num_replicas_in_sync, output_type=dtypes.float32).batch(
806    ...        strategy.num_replicas_in_sync, drop_remainder=True)
807    >>> input_iterator = iter(strategy.experimental_distribute_dataset(dataset))
808
809    >>> @tf.function()
810    ... def step_fn(inputs):
811    ...  output = tf.reduce_sum(inputs)
812    ...  return output
813
814    >>> strategy.run(step_fn, args=(next(input_iterator),), options=options)
815
816    Args:
817      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
818      args: (Optional) Positional arguments to `fn`.
819      kwargs: (Optional) Keyword arguments to `fn`.
820      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
821        the options to run `fn`.
822
823    Returns:
824      Merged return value of `fn` across replicas. The structure of the return
825      value is the same as the return value from `fn`. Each element in the
826      structure can either be "per-replica" `Tensor` objects or `Tensor`s
827      (for example, if running on a single replica).
828    """
829    validate_run_function(fn)
830
831    fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs)
832
833    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
834    options = options or distribute_lib.RunOptions()
835    return self.extended.tpu_run(fn, args, kwargs, options)
836
837
838# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
839class TPUExtended(distribute_lib.StrategyExtendedV1):
840  """Implementation of TPUStrategy."""
841
842  def __init__(self,
843               container_strategy,
844               tpu_cluster_resolver=None,
845               steps_per_run=None,
846               device_assignment=None,
847               use_spmd_for_xla_partitioning=False):
848    super(TPUExtended, self).__init__(container_strategy)
849
850    if tpu_cluster_resolver is None:
851      tpu_cluster_resolver = TPUClusterResolver("")
852
853    if steps_per_run is None:
854      # TODO(frankchn): Warn when we are being used by DS/Keras and this is
855      # not specified.
856      steps_per_run = 1
857
858    # `self._tpu_function_cache` is a dict of `tf.function`s, thus if a
859    # `tf.function` is passed into `strategy.run` in eager mode, the
860    # `tf.function` won't get retraced.
861    self._tpu_function_cache = weakref.WeakKeyDictionary()
862
863    self._tpu_cluster_resolver = tpu_cluster_resolver
864    self._tpu_metadata = self._tpu_cluster_resolver.get_tpu_system_metadata()
865    self._device_assignment = device_assignment
866
867    tpu_devices_flat = [
868        d.name for d in self._tpu_metadata.devices if "device:TPU:" in d.name]
869
870    # `self._tpu_devices` is a two-dimensional NumPy array of strings. It is
871    # indexed using `[replica_id][logical_device_id]`.
872    if device_assignment is None:
873      self._tpu_devices = np.array(
874          [[d] for d in tpu_devices_flat], dtype=object)
875    else:
876      job_name = device_spec.DeviceSpecV2.from_string(tpu_devices_flat[0]).job
877
878      tpu_devices = []
879      for replica_id in range(device_assignment.num_replicas):
880        replica_devices = []
881
882        for logical_core in range(device_assignment.num_cores_per_replica):
883          replica_devices.append(
884              device_util.canonicalize(
885                  device_assignment.tpu_device(
886                      replica=replica_id,
887                      logical_core=logical_core,
888                      job=job_name)))
889
890        tpu_devices.append(replica_devices)
891      self._tpu_devices = np.array(tpu_devices, dtype=object)
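    # For example (illustrative device names): with 2 replicas and 2 logical
    # devices per replica, `self._tpu_devices` has shape (2, 2) and might look
    # like
    #   [["/job:worker/.../device:TPU:0", "/job:worker/.../device:TPU:1"],
    #    ["/job:worker/.../device:TPU:2", "/job:worker/.../device:TPU:3"]]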
892
893    self._host_device = device_util.get_host_for_device(self._tpu_devices[0][0])
894
895    # Preload the data onto the TPUs. Currently we always preload onto logical
896    # device 0 for each replica.
897    # TODO(cjfj): Create `InputWorkers` lazily, allowing users to place the
898    # input onto a different logical device?
899    self._device_input_worker_devices = collections.OrderedDict()
900    self._host_input_worker_devices = collections.OrderedDict()
901    for tpu_device in self._tpu_devices[:, 0]:
902      host_device = device_util.get_host_for_device(tpu_device)
903      self._device_input_worker_devices.setdefault(host_device, [])
904      self._device_input_worker_devices[host_device].append(tpu_device)
905      self._host_input_worker_devices.setdefault(host_device, [])
906      self._host_input_worker_devices[host_device].append(host_device)
907
908    # TODO(sourabhbajaj): Remove this once performance of running one step
909    # at a time is comparable to multiple steps.
910    self.steps_per_run = steps_per_run
911    self._require_static_shapes = True
912
913    self.experimental_enable_get_next_as_optional = True
914
915    self._logical_device_stack = [0]
916
917    if context.executing_eagerly():
918      # In async remote eager, we want to sync the executors before exiting the
919      # program.
920      atexit.register(context.async_wait)
921
922    # Flag to turn on VariablePolicy. Var policy is deprecated because there is
923    # another effort unifying DistributedVariables (see values_v2.py). SPMD XLA
924    # partitioning is not implemented for var policies.
925    # TODO(b/202048882): remove var policy from TPUStrategy.
926    self._use_var_policy = not use_spmd_for_xla_partitioning
927
928    # Flag to enable XLA SPMD partitioning.
929    self._use_spmd_for_xla_partitioning = use_spmd_for_xla_partitioning
930
931  def _validate_colocate_with_variable(self, colocate_with_variable):
932    distribute_utils.validate_colocate(colocate_with_variable, self)
933
934  def _make_dataset_iterator(self, dataset):
935    """Make iterators for each of the TPU hosts."""
936    input_workers = input_lib.InputWorkers(
937        tuple(self._device_input_worker_devices.items()))
938    return input_lib_v1.DatasetIterator(
939        dataset,
940        input_workers,
941        self._container_strategy(),
942        num_replicas_in_sync=self._num_replicas_in_sync)
943
944  def _make_input_fn_iterator(
945      self,
946      input_fn,
947      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
948    input_contexts = []
949    input_workers = input_lib.InputWorkers(
950        tuple(self._device_input_worker_devices.items()))
951    num_workers = input_workers.num_workers
952    for i in range(num_workers):
953      input_contexts.append(
954          distribute_lib.InputContext(
955              num_input_pipelines=num_workers,
956              input_pipeline_id=i,
957              num_replicas_in_sync=self._num_replicas_in_sync))
958    return input_lib_v1.InputFunctionIterator(input_fn, input_workers,
959                                              input_contexts,
960                                              self._container_strategy())
961
962  def _experimental_make_numpy_dataset(self, numpy_input, session):
963    return numpy_dataset.one_host_numpy_dataset(
964        numpy_input, numpy_dataset.SingleDevice(self._host_device),
965        session)
966
967  def _get_input_workers(self, options):
968    if not options or options.experimental_fetch_to_device:
969      return input_lib.InputWorkers(
970          tuple(self._device_input_worker_devices.items()))
971    else:
972      return input_lib.InputWorkers(
973          tuple(self._host_input_worker_devices.items()))
974
975  def _check_spec(self, element_spec):
976    if isinstance(element_spec, values.PerReplicaSpec):
977      element_spec = element_spec._component_specs  # pylint: disable=protected-access
978    specs = nest.flatten_with_joined_string_paths(element_spec)
979    for path, spec in specs:
980      if isinstance(spec, (sparse_tensor.SparseTensorSpec,
981                           ragged_tensor.RaggedTensorSpec)):
982        raise ValueError(
983            "Found tensor {} with spec {}. TPUStrategy does not support "
984            "distributed datasets with device prefetch when using sparse or "
985            "ragged tensors. If you intend to use sparse or ragged tensors, "
986            "please pass a tf.distribute.InputOptions object with "
987            "experimental_fetch_to_device set to False to your dataset "
988            "distribution function.".format(path, type(spec)))
989
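  # Example (illustrative) of how callers are expected to avoid the error
  # raised in `_check_spec` above when a dataset yields sparse or ragged
  # tensors:
  #
  #   options = tf.distribute.InputOptions(experimental_fetch_to_device=False)
  #   dist_dataset = strategy.experimental_distribute_dataset(
  #       dataset, options=options)
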
990  def _experimental_distribute_dataset(self, dataset, options):
991    if (options and options.experimental_replication_mode ==
992        distribute_lib.InputReplicationMode.PER_REPLICA):
993      raise NotImplementedError(
994          "InputReplicationMode.PER_REPLICA "
995          "is only supported in "
996          "`experimental_distribute_datasets_from_function`."
997      )
998    if options is None or options.experimental_fetch_to_device:
999      self._check_spec(dataset.element_spec)
1000
1001    return input_util.get_distributed_dataset(
1002        dataset,
1003        self._get_input_workers(options),
1004        self._container_strategy(),
1005        num_replicas_in_sync=self._num_replicas_in_sync,
1006        options=options)
1007
1008  def _distribute_datasets_from_function(self, dataset_fn, options):
1009    if (options and options.experimental_replication_mode ==
1010        distribute_lib.InputReplicationMode.PER_REPLICA):
1011      raise NotImplementedError(
1012          "InputReplicationMode.PER_REPLICA "
1013          "is only supported in "
1014          " `experimental_distribute_datasets_from_function` "
1015          "of tf.distribute.MirroredStrategy")
1016    input_workers = self._get_input_workers(options)
1017    input_contexts = []
1018    num_workers = input_workers.num_workers
1019    for i in range(num_workers):
1020      input_contexts.append(distribute_lib.InputContext(
1021          num_input_pipelines=num_workers,
1022          input_pipeline_id=i,
1023          num_replicas_in_sync=self._num_replicas_in_sync))
1024
1025    distributed_dataset = input_util.get_distributed_datasets_from_function(
1026        dataset_fn,
1027        input_workers,
1028        input_contexts,
1029        self._container_strategy(),
1030        options=options)
1031
1032    # We can only check after the dataset_fn is called.
1033    if options is None or options.experimental_fetch_to_device:
1034      self._check_spec(distributed_dataset.element_spec)
1035    return distributed_dataset
1036
1037  def _experimental_distribute_values_from_function(self, value_fn):
1038    per_replica_values = []
1039    for replica_id in range(self._num_replicas_in_sync):
1040      per_replica_values.append(
1041          value_fn(distribute_lib.ValueContext(replica_id,
1042                                               self._num_replicas_in_sync)))
1043    return distribute_utils.regroup(per_replica_values, always_wrap=True)
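
  # Example (illustrative): with 2 replicas and
  # `value_fn = lambda ctx: ctx.replica_id_in_sync_group`, the method above
  # returns a PerReplica value whose components are (0, 1).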
1044
1045  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
1046  # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
1047  # a mechanism to infer the outputs of `fn`. Pending b/110550782.
1048  def _experimental_run_steps_on_iterator(
1049      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
1050    # Wrap `fn` for repeat.
1051    if initial_loop_values is None:
1052      initial_loop_values = {}
1053    initial_loop_values = nest.flatten(initial_loop_values)
1054    ctx = input_lib.MultiStepContext()
1055
1056    def run_fn(inputs):
1057      """Single step on the TPU device."""
1058      fn_result = fn(ctx, inputs)
1059      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
1060      if flat_last_step_outputs:
1061        with ops.control_dependencies([fn_result]):
1062          return [array_ops.identity(f) for f in flat_last_step_outputs]
1063      else:
1064        return fn_result
1065
1066    # We capture the control_flow_context at this point, before we run `fn`
1067    # inside a while_loop and TPU replicate context. This is useful in cases
1068    # where we might need to exit these contexts and get back to the outer
1069    # context to do some things, for e.g. create an op which should be
1070    # evaluated only once at the end of the loop on the host. One such usage
1071    # is in creating metrics' value op.
1072    self._outer_control_flow_context = (
1073        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access
1074
1075    def rewrite_fn(*args):
1076      """The rewritten step fn running on TPU."""
1077      del args
1078
1079      per_replica_inputs = multi_worker_iterator.get_next()
1080      replicate_inputs = []
1081      for replica_id in range(self._num_replicas_in_sync):
1082        select_replica = lambda x: distribute_utils.select_replica(  # pylint: disable=g-long-lambda
1083            replica_id, x)   # pylint: disable=cell-var-from-loop
1084        replicate_inputs.append((nest.map_structure(
1085            select_replica, per_replica_inputs),))
1086
1087      replicate_outputs = tpu.replicate(
1088          run_fn,
1089          replicate_inputs,
1090          device_assignment=self._device_assignment,
1091          xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self
1092                                     ._use_spmd_for_xla_partitioning))
1093      # If run_fn has tensor outputs, tpu.replicate returns a list of lists. We
1094      # will flatten it in this case. If run_fn has no tensor outputs,
1095      # tpu.replicate returns a list of no_ops; we will keep the output as it
1096      # is.
1097      if isinstance(replicate_outputs[0], list):
1098        replicate_outputs = nest.flatten(replicate_outputs)
1099
1100      return replicate_outputs
1101
1102    # TODO(sourabhbajaj): The input to while loop should be based on the
1103    # output type of the step_fn
1104    assert isinstance(initial_loop_values, list)
1105    initial_loop_values = initial_loop_values * self._num_replicas_in_sync
1106
1107    # Put the while loop op on TPU host 0.
1108    with ops.device(self._host_device):
1109      if self.steps_per_run == 1:
1110        replicate_outputs = rewrite_fn()
1111      else:
1112        replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
1113                                                 initial_loop_values)
1114
1115    del self._outer_control_flow_context
1116    ctx.run_op = control_flow_ops.group(replicate_outputs)
1117
1118    if isinstance(replicate_outputs, list):
1119      # Filter out any ops from the outputs, typically this would be the case
1120      # when there were no tensor outputs.
1121      last_step_tensor_outputs = [
1122          x for x in replicate_outputs if not isinstance(x, ops.Operation)
1123      ]
1124
1125      # Outputs are currently of the structure (flattened)
1126      # [output0_device0, output1_device0, output2_device0,
1127      #  output0_device1, output1_device1, output2_device1,
1128      #  ...]
1129      # Convert this to the following structure instead: (grouped by output)
1130      # [[output0_device0, output0_device1],
1131      #  [output1_device0, output1_device1],
1132      #  [output2_device0, output2_device1]]
1133      output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
1134      last_step_tensor_outputs = [
1135          last_step_tensor_outputs[i::output_num] for i in range(output_num)
1136      ]
1137    else:
1138      # no tensors returned.
1139      last_step_tensor_outputs = []
1140
1141    _set_last_step_outputs(ctx, last_step_tensor_outputs)
1142    return ctx
1143
1144  def _call_for_each_replica(self, fn, args, kwargs):
1145    # TODO(jhseu): Consider making it so call_for_each_replica implies that
1146    # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly.
1147    with _TPUReplicaContext(self._container_strategy()):
1148      return fn(*args, **kwargs)
1149
1150  @contextlib.contextmanager
1151  def experimental_logical_device(self, logical_device_id):
1152    """Places variables and ops on the specified logical device."""
1153    num_logical_devices_per_replica = self._tpu_devices.shape[1]
1154    if logical_device_id >= num_logical_devices_per_replica:
1155      raise ValueError(
1156          "`logical_device_id` not in range (was {}, but there are only {} "
1157          "logical devices per replica).".format(
1158              logical_device_id, num_logical_devices_per_replica))
1159
1160    self._logical_device_stack.append(logical_device_id)
1161    try:
1162      if tpu_util.enclosing_tpu_context() is None:
1163        yield
1164      else:
1165        with ops.device(tpu.core(logical_device_id)):
1166          yield
1167    finally:
1168      self._logical_device_stack.pop()
1169
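  # Example (illustrative sketch): placing variable creation on logical device
  # 1 under model parallelism; `strategy` is assumed to be a TPUStrategy built
  # with a DeviceAssignment that has at least 2 logical devices per replica.
  #
  #   with strategy.scope():
  #     with strategy.extended.experimental_logical_device(1):
  #       v = tf.Variable(0.0)
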
1170  def _experimental_initialize_system(self):
1171    """Experimental method added to be used by Estimator.
1172
1173    This is a private method only to be used by Estimator. Other frameworks
1174    should directly be calling `tf.tpu.experimental.initialize_tpu_system`
1175    """
1176    tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver)
1177
1178  def _create_variable(self, next_creator, **kwargs):
1179    """Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
1180    if kwargs.pop("skip_mirrored_creator", False):
1181      return next_creator(**kwargs)
1182
1183    colocate_with = kwargs.pop("colocate_with", None)
1184    if colocate_with is None:
1185      devices = self._tpu_devices[:, self._logical_device_stack[-1]]
1186    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
1187      with ops.device(colocate_with.device):
1188        return next_creator(**kwargs)
1189    else:
1190      devices = colocate_with._devices  # pylint: disable=protected-access
1191
1192    num_replicas, num_cores_per_replica = self._tpu_devices.shape
1193
1194    def _create_mirrored_tpu_variables(**kwargs):
1195      """Returns a list of `tf.Variable`s.
1196
1197      The list contains `num_replicas` `tf.Variable`s and can be used to
1198      initialize a `TPUMirroredVariable`.
1199
1200      Args:
1201        **kwargs: the keyword arguments for creating a variable
1202      """
1203      initial_value = None
1204      value_list = []
1205      for i, d in enumerate(devices):
1206        with ops.device(d):
1207          if i == 0:
1208            initial_value = kwargs["initial_value"]
1209            # Note: some v1 code expects variable initializer creation to happen
1210            # inside an init_scope.
1211            with maybe_init_scope():
1212              initial_value = initial_value() if callable(
1213                  initial_value) else initial_value
1214
1215          if i > 0:
1216            # Give replicas meaningful distinct names:
1217            var0name = value_list[0].name.split(":")[0]
1218            # We append a / to variable names created on replicas with id > 0 to
1219            # ensure that we ignore the name scope and instead use the given
1220            # name as the absolute name of the variable.
1221            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
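            # For example (illustrative): a variable named "dense/kernel:0" on
            # replica 0 yields the name "dense/kernel/replica_1/" here for
            # replica 1; the trailing "/" makes the name absolute.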
1222          kwargs["initial_value"] = initial_value
1223
1224          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
1225            v = next_creator(**kwargs)
1226
1227          assert not isinstance(v, tpu_values.TPUMirroredVariable)
1228          value_list.append(v)
1229      return value_list
1230
1231    def _create_mirrored_tpu_replicated_variables(**kwargs):
1232      """Returns a list of `TPUReplicatedVariable`s.
1233
1234      The list consists of `num_replicas` `TPUReplicatedVariable`s and can be
1235      used to initialize a `TPUMirroredVariable`. Each `TPUReplicatedVariable`
1236      contains a list of `tf.Variable`s which are replicated to
1237      `num_cores_per_replica` logical cores to enable XLA SPMD compilation.
1238
1239      Args:
1240        **kwargs: the keyword arguments for creating a variable
1241      """
1242      initial_value = kwargs["initial_value"]
1243      # Note: some v1 code expects variable initializer creation to happen
1244      # inside an init_scope.
1245      with maybe_init_scope():
1246        initial_value = initial_value() if callable(
1247            initial_value) else initial_value
1248
1249      mirrored_replicated_var_list = []
1250
1251      for replica_id in range(num_replicas):
1252        replicated_var_list = []
1253        for logical_core_id in range(num_cores_per_replica):
1254          with ops.device(self._tpu_devices[replica_id][logical_core_id]):
1255            kwargs["initial_value"] = initial_value
1256            v = next_creator(**kwargs)
1257          replicated_var_list.append(v)
1258        replica_name = "{}/r:{}".format(kwargs["name"], replica_id)
1259        tpu_replicated_var = tpu_replicated_variable.TPUReplicatedVariable(
1260            variables=replicated_var_list, name=replica_name)
1261
1262        mirrored_replicated_var_list.append(tpu_replicated_var)
1263      return mirrored_replicated_var_list
1264
1265    if self._use_spmd_for_xla_partitioning and num_cores_per_replica > 1:
1266      real_creator = _create_mirrored_tpu_replicated_variables
1267    else:
1268      real_creator = _create_mirrored_tpu_variables
1269
1270    return distribute_utils.create_mirrored_variable(
1271        self._container_strategy(), real_creator,
1272        distribute_utils.TPU_VARIABLE_CLASS_MAPPING,
1273        distribute_utils.TPU_VARIABLE_POLICY_MAPPING, **kwargs)
1274
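  # A minimal sketch of how the creator above is exercised, assuming `strategy`
  # is an already constructed TPU strategy: any `tf.Variable` created under
  # `strategy.scope()` flows through `_create_variable` and comes back as a
  # mirrored variable with one component per replica (or, when XLA SPMD
  # partitioning is enabled and there is more than one core per replica, one
  # `TPUReplicatedVariable` per replica):
  #
  #   with strategy.scope():
  #     w = tf.Variable(tf.zeros([128, 10]), name="w")
  #   components = strategy.experimental_local_results(w)
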
1275  def _resource_creator_scope(self):
1276
1277    def lookup_creator(next_creator, *args, **kwargs):
1278      host_to_table = collections.OrderedDict()
1279      for host_device in self._device_input_worker_devices.keys():
1280        with ops.device(host_device):
1281          host_to_table[host_device] = next_creator(*args, **kwargs)
1282
1283      return values.PerWorkerResource(self._container_strategy(), host_to_table)
1284
1285    # TODO(b/194362531): Define creator(s) for other resources.
1286    return ops.resource_creator_scope("StaticHashTable", lookup_creator)
1287
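  # A minimal sketch of the lookup creator above, assuming `strategy` is an
  # already constructed TPU strategy: a `tf.lookup.StaticHashTable` created
  # under `strategy.scope()` is built once per input-worker host and wrapped
  # in a `PerWorkerResource`, so input preprocessing on each host reads its
  # local copy:
  #
  #   with strategy.scope():
  #     initializer = tf.lookup.KeyValueTensorInitializer(
  #         keys=tf.constant(["a", "b"]), values=tf.constant([0, 1]))
  #     table = tf.lookup.StaticHashTable(initializer, default_value=-1)
  #
  #   def preprocess(features):  # Hypothetical dataset map function.
  #     return table.lookup(features)
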
1288  def _gather_to_implementation(self, value, destinations, axis, options):
1289    if not isinstance(value, values.DistributedValues):
1290      return value
1291
1292    value_list = list(value.values)
1293    # pylint: disable=protected-access
1294    if isinstance(
1295        value,
1296        values.DistributedVariable) and value._packed_variable is not None:
1297      value_list = list(
1298          value._packed_variable.on_device(d)
1299          for d in value._packed_variable.devices)
1300    # pylint: enable=protected-access
1301
1302    # Currently XLA op-by-op mode has a limit on the number of inputs for a
1303    # single op, so we break one `concat` op into a group of `concat` ops to
1304    # work around the constraint.
1305    if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
1306      output = array_ops.concat(value_list, axis=axis)
1307    else:
1308      output = array_ops.concat(
1309          value_list[:_XLA_OP_BY_OP_INPUTS_LIMIT], axis=axis)
1310      for i in range(_XLA_OP_BY_OP_INPUTS_LIMIT, len(value_list),
1311                     _XLA_OP_BY_OP_INPUTS_LIMIT - 1):
1312        output = array_ops.concat(
1313            [output] + value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT - 1],
1314            axis=axis)
1315
1316    output = self._broadcast_output(destinations, output)
1317    return output
1318
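  # The chunking above keeps every `concat` call at or below
  # `_XLA_OP_BY_OP_INPUTS_LIMIT` inputs: the first call consumes a full chunk
  # of tensors, and every later call consumes the running `output` plus one
  # fewer than the limit. A plain-Python sketch of the same grouping (with a
  # hypothetical `chunked_concat` helper and list concatenation standing in
  # for `array_ops.concat`):
  #
  #   LIMIT = 200  # Stand-in for _XLA_OP_BY_OP_INPUTS_LIMIT.
  #
  #   def chunked_concat(tensors):
  #     output = list(tensors[:LIMIT])
  #     for i in range(LIMIT, len(tensors), LIMIT - 1):
  #       output = output + list(tensors[i:i + LIMIT - 1])
  #     return output
  #
  #   assert chunked_concat(list(range(500))) == list(range(500))
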
1319  def _broadcast_output(self, destinations, output):
1320    devices = cross_device_ops_lib.get_devices_from(destinations)
1321
1322    if len(devices) == 1:
1323      # If necessary, copy to requested destination.
1324      dest_canonical = device_util.canonicalize(devices[0])
1325      host_canonical = device_util.canonicalize(self._host_device)
1326
1327      if dest_canonical != host_canonical:
1328        with ops.device(dest_canonical):
1329          output = array_ops.identity(output)
1330    else:
1331      output = cross_device_ops_lib.simple_broadcast(output, destinations)
1332
1333    return output
1334
1335  def _reduce_to(self, reduce_op, value, destinations, options):
1336    if (isinstance(value, values.DistributedValues) or
1337        tensor_util.is_tf_type(value)
1338       ) and tpu_util.enclosing_tpu_context() is not None:
1339      if reduce_op == reduce_util.ReduceOp.MEAN:
1340        # TODO(jhseu):  Revisit once we support model-parallelism.
1341        # scalar_mul maintains the type of value: tensor or IndexedSlices.
1342        value = math_ops.scalar_mul((1./self._num_replicas_in_sync), value)
1343      elif reduce_op != reduce_util.ReduceOp.SUM:
1344        raise NotImplementedError(
1345            f"`reduce_op`={reduce_op} is not supported. Currently we only "
1346            "support ReduceOp.SUM and ReduceOp.MEAN in TPUStrategy.")
1347      return tpu_ops.cross_replica_sum(value)
1348
1349    if not isinstance(value, values.DistributedValues):
1350      # This function handles reducing values that are not PerReplica or
1351      # Mirrored values. For example, the same value could be present on all
1352      # replicas, in which case `value` would be a single value, or `value`
1353      # could be 0.
1354      return cross_device_ops_lib.reduce_non_distributed_value(
1355          reduce_op, value, destinations, self._num_replicas_in_sync)
1356
1357    value_list = value.values
1358    # pylint: disable=protected-access
1359    if isinstance(
1360        value,
1361        values.DistributedVariable) and value._packed_variable is not None:
1362      value_list = tuple(
1363          value._packed_variable.on_device(d)
1364          for d in value._packed_variable.devices)
1365    # pylint: enable=protected-access
1366
1367    # Currently XLA op-by-op mode has a limit on the number of inputs for a
1368    # single op, so we break one `add_n` op into a group of `add_n` ops to
1369    # work around the constraint.
1370    # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
1371    if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
1372      output = math_ops.add_n(value_list)
1373    else:
1374      output = array_ops.zeros_like(value_list[0], dtype=value_list[0].dtype)
1375      for i in range(0, len(value_list), _XLA_OP_BY_OP_INPUTS_LIMIT):
1376        output += math_ops.add_n(value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT])
1377
1378    if reduce_op == reduce_util.ReduceOp.MEAN:
1379      output *= (1. / len(value_list))
1380
1381    output = self._broadcast_output(destinations, output)
1382    return output
1383
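  # A minimal sketch of the user-facing path into `_reduce_to` above, assuming
  # `strategy` is an already constructed TPU strategy. `strategy.reduce` called
  # outside the replicated computation takes the chunked `add_n` path; the same
  # reduction requested from inside a replica context (for example via
  # `tf.distribute.get_replica_context().all_reduce`) takes the
  # `cross_replica_sum` branch instead:
  #
  #   per_replica = strategy.run(lambda: tf.constant(1.0))
  #   total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica,
  #                           axis=None)
  #   mean = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica,
  #                          axis=None)
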
1384  def _update(self, var, fn, args, kwargs, group):
1385    assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
1386        var, resource_variable_ops.BaseResourceVariable)
1387    if tpu_util.enclosing_tpu_context() is not None:
1388      if group:
1389        return fn(var, *args, **kwargs)
1390      else:
1391        return (fn(var, *args, **kwargs),)
1392
1393    # Inside `tf.function`, we don't expand PackedVariable in Python as it will
1394    # be expanded later during function instantiation in the runtime.
1395    packed_var = var._packed_variable  # pylint: disable=protected-access
1396    if packed_var is not None and not context.executing_eagerly():
1397      if group:
1398        return fn(packed_var, *args, **kwargs)
1399      else:
1400        return (fn(packed_var, *args, **kwargs),)
1401
1402    # Otherwise, we revert to MirroredStrategy behavior and update the variable
1403    # on each replica directly.
1404    updates = []
1405    values_and_devices = []
1406    if packed_var is not None:
1407      for device in packed_var.devices:
1408        values_and_devices.append((packed_var, device))
1409    else:
1410      for value in var.values:
1411        values_and_devices.append((value, value.device))
1412
1413    if (var.synchronization != variables_lib.VariableSynchronization.ON_READ and
1414        var.aggregation != variables_lib.VariableAggregation.NONE):
1415      distribute_utils.assert_mirrored(args)
1416      distribute_utils.assert_mirrored(kwargs)
1417    for i, value_and_device in enumerate(values_and_devices):
1418      value = value_and_device[0]
1419      device = value_and_device[1]
1420      name = "update_%d" % i
1421      with ops.device(device), \
1422           distribute_lib.UpdateContext(i), \
1423           ops.name_scope(name):
1424        # If args and kwargs are not mirrored, the value is returned as is.
1425        updates.append(
1426            fn(value, *distribute_utils.select_replica(i, args),
1427               **distribute_utils.select_replica(i, kwargs)))
1428    return distribute_utils.update_regroup(self, updates, group)
1429
1430  def read_var(self, var):
1431    assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
1432        var, resource_variable_ops.BaseResourceVariable)
1433    return var.read_value()
1434
1435  def value_container(self, value):
1436    return value
1437
1438  def _broadcast_to(self, tensor, destinations):
1439    del destinations
1440    # This is both a fast path for Python constants, and a way to delay
1441    # converting Python values to a tensor until we know what type it
1442    # should be converted to. Otherwise we have trouble with:
1443    #   global_step.assign_add(1)
1444    # since the `1` gets broadcast as an int32 but global_step is int64.
1445    if isinstance(tensor, (float, int)):
1446      return tensor
1447    if tpu_util.enclosing_tpu_context() is not None:
1448      broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
1449      result = tpu_ops.all_to_all(
1450          broadcast_tensor,
1451          concat_dimension=0,
1452          split_dimension=0,
1453          split_count=self._num_replicas_in_sync)
1454
1455      # This uses the broadcasted value from the first replica because the only
1456      # caller of this is ONLY_FIRST_REPLICA variable aggregation.
1457      return result[0]
1458    return tensor
1459
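  # A rough NumPy sketch of the `all_to_all` broadcast above, under the
  # documented AllToAll semantics (split the input along `split_dimension`,
  # send block j to replica j, and concatenate the received blocks along
  # `concat_dimension` in source-replica order). Every replica stacks its own
  # value, and after the exchange each replica holds the stack of all
  # replicas' values, so element 0 is replica 0's value:
  #
  #   import numpy as np
  #   num_replicas = 3
  #   per_replica = [np.array(10.0 * r) for r in range(num_replicas)]
  #   inputs = [np.stack([t] * num_replicas) for t in per_replica]
  #   received = [np.stack([inputs[s][r] for s in range(num_replicas)])
  #               for r in range(num_replicas)]
  #   assert all(out[0] == per_replica[0] for out in received)
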
1460  @property
1461  def num_hosts(self):
1462    if self._device_assignment is None:
1463      return self._tpu_metadata.num_hosts
1464
1465    return len(set([self._device_assignment.host_device(r)
1466                    for r in range(self._device_assignment.num_replicas)]))
1467
1468  @property
1469  def num_replicas_per_host(self):
1470    if self._device_assignment is None:
1471      return self._tpu_metadata.num_of_cores_per_host
1472
1473    # TODO(sourabhbajaj): Remove this method once we use inputs and remove
1474    # infeed, as the computation of num_replicas_per_host is not a constant
1475    # when using device_assignment. This is a temporary workaround to support
1476    # StatefulRNN, as everything is 1 in that case.
1477    # This method needs to take host_id as input for correct computation.
1478    max_models_per_host = (self._tpu_metadata.num_of_cores_per_host //
1479                           self._device_assignment.num_cores_per_replica)
1480    return min(self._device_assignment.num_replicas, max_models_per_host)
1481
1482  @property
1483  def _num_replicas_in_sync(self):
1484    if self._device_assignment is None:
1485      return self._tpu_metadata.num_cores
1486    return self._device_assignment.num_replicas
1487
1488  @property
1489  def experimental_between_graph(self):
1490    return False
1491
1492  @property
1493  def experimental_should_init(self):
1494    return True
1495
1496  @property
1497  def should_checkpoint(self):
1498    return True
1499
1500  @property
1501  def should_save_summary(self):
1502    return True
1503
1504  @property
1505  def worker_devices(self):
1506    return tuple(self._tpu_devices[:, self._logical_device_stack[-1]])
1507
1508  @property
1509  def parameter_devices(self):
1510    return self.worker_devices
1511
1512  @property
1513  def tpu_hardware_feature(self):
1514    """Return the `tf.tpu.experimental.HardwareFeature` class."""
1515    return tpu_hardware_feature.HardwareFeature(
1516        self._tpu_cluster_resolver.tpu_hardware_feature)
1517
1518  def non_slot_devices(self, var_list):
1519    return self._host_device
1520
1521  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
1522    del colocate_with
1523    with ops.device(self._host_device), distribute_lib.UpdateContext(None):
1524      result = fn(*args, **kwargs)
1525      if group:
1526        return result
1527      else:
1528        return nest.map_structure(self._local_results, result)
1529
1530  def _configure(self,
1531                 session_config=None,
1532                 cluster_spec=None,
1533                 task_type=None,
1534                 task_id=None):
1535    del cluster_spec, task_type, task_id
1536    if session_config:
1537      session_config.CopyFrom(self._update_config_proto(session_config))
1538
1539  def _update_config_proto(self, config_proto):
1540    updated_config = copy.deepcopy(config_proto)
1541    updated_config.isolate_session_state = True
1542    cluster_spec = self._tpu_cluster_resolver.cluster_spec()
1543    if cluster_spec:
1544      updated_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
1545    return updated_config
1546
1547  # TODO(priyag): Delete this once all strategies use global batch size.
1548  @property
1549  def _global_batch_size(self):
1550    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.
1551
1552    `make_input_fn_iterator` assumes per-replica batching.
1553
1554    Returns:
1555      Boolean.
1556    """
1557    return True
1558
1559  def tpu_run(self, fn, args, kwargs, options=None):
1560    func = self._tpu_function_creator(fn, options)
1561    return func(args, kwargs)
1562
1563  def _tpu_function_creator(self, fn, options):
1564    if context.executing_eagerly() and fn in self._tpu_function_cache:
1565      return self._tpu_function_cache[fn]
1566
1567    strategy = self._container_strategy()
1568
1569    def tpu_function(args, kwargs):
1570      """TF Function used to replicate the user computation."""
1571      logging.vlog(1,
1572                   "`TPUStrategy.run` is called with [args: %s] [kwargs: %s]",
1573                   args, kwargs)
1574
1575      if kwargs is None:
1576        kwargs = {}
1577
1578      # Used to re-structure flattened output tensors from `tpu.replicate()`
1579      # into a structured format.
1580      result = [[]]
1581
1582      def replicated_fn(replica_id, replica_args, replica_kwargs):
1583        """Wraps user function to provide replica ID and `Tensor` inputs."""
1584        with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
1585          result[0] = fn(*replica_args, **replica_kwargs)
1586        return result[0]
1587
1588      replicate_inputs = []  # By replica.
1589      for i in range(strategy.num_replicas_in_sync):
1590        replicate_inputs.append(
1591            [constant_op.constant(i, dtype=dtypes.int32),
1592             distribute_utils.select_replica(i, args),
1593             distribute_utils.select_replica(i, kwargs)])
1594
1595      # Construct and pass `maximum_shapes` so that we can support dynamic
1596      # shapes using the dynamic padder.
1597      if options.experimental_enable_dynamic_batch_size and replicate_inputs:
1598        maximum_shapes = []
1599        flattened_list = nest.flatten(replicate_inputs[0])
1600        for input_tensor in flattened_list:
1601          if tensor_util.is_tf_type(input_tensor):
1602            rank = input_tensor.shape.rank
1603          else:
1604            rank = np.ndim(input_tensor)
1605          if rank is None:
1606            raise ValueError(
1607                "input tensor {} to TPUStrategy.run() has unknown rank, "
1608                "which is not allowed".format(input_tensor))
1609          maximum_shape = tensor_shape.TensorShape([None] * rank)
1610          maximum_shapes.append(maximum_shape)
1611        maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
1612                                               maximum_shapes)
1613      else:
1614        maximum_shapes = None
1615
1616      if options.experimental_bucketizing_dynamic_shape:
1617        padding_spec = tpu.PaddingSpec.POWER_OF_TWO
1618      else:
1619        padding_spec = None
1620
1621      with strategy.scope():
1622        xla_options = options.experimental_xla_options or tpu.XLAOptions(
1623            use_spmd_for_xla_partitioning=self._use_spmd_for_xla_partitioning)
1624        replicate_outputs = tpu.replicate(
1625            replicated_fn,
1626            replicate_inputs,
1627            device_assignment=self._device_assignment,
1628            maximum_shapes=maximum_shapes,
1629            padding_spec=padding_spec,
1630            xla_options=xla_options)
1631
1632      # Remove all no-ops that may have been added during `tpu.replicate()`.
1633      filter_ops = lambda x: [o for o in x if not isinstance(o, ops.Operation)]
1634      if isinstance(result[0], list):
1635        result[0] = filter_ops(result[0])
1636
1637      # Workaround for `tpu.replicate` behavior when single `Tensor` returned.
1638      if result[0] is None or isinstance(result[0], ops.Operation):
1639        replicate_outputs = [None] * len(replicate_outputs)
1640      else:
1641        replicate_outputs = [
1642            nest.pack_sequence_as(result[0], filter_ops(nest.flatten(output)))
1643            for output in replicate_outputs
1644        ]
1645      return distribute_utils.regroup(replicate_outputs)
1646
1647    if context.executing_eagerly():
1648      tpu_function = def_function.function(tpu_function)
1649      self._tpu_function_cache[fn] = tpu_function
1650    return tpu_function
1651
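  # A minimal sketch of the public entry point into the machinery above,
  # assuming `strategy` is an already constructed TPU strategy and `dataset`
  # is a `tf.data.Dataset` (both names are placeholders):
  #
  #   dist_dataset = strategy.experimental_distribute_dataset(dataset)
  #   iterator = iter(dist_dataset)
  #
  #   @tf.function
  #   def train_step(batch):
  #     def replica_fn(x):
  #       return tf.reduce_sum(x)
  #     per_replica = strategy.run(replica_fn, args=(batch,))
  #     return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica,
  #                            axis=None)
  #
  #   result = train_step(next(iterator))
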
1652  def _in_multi_worker_mode(self):
1653    """Whether this strategy indicates working in multi-worker settings."""
1654    # TPUStrategy has a different distributed training structure: the whole
1655    # cluster should be treated as a single worker from a higher-level (e.g.
1656    # Keras) library's point of view.
1657    # TODO(rchao): Revisit this as we design a fault-tolerance solution for
1658    # TPUStrategy.
1659    return False
1660
1661  def _get_local_replica_id(self, replica_id_in_sync_group):
1662    return replica_id_in_sync_group
1663
1664
1665def _make_axis_nonnegative(axis, rank):
1666  # Convert a potentially negative `axis` to a non-negative one.
1667  if isinstance(axis, int):
1668    if axis >= 0:
1669      return axis
1670    else:
1671      return axis + rank
1672  else:
1673    return array_ops.where_v2(
1674        math_ops.greater_equal(axis, 0),
1675        axis,
1676        axis + rank)
1677
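# A worked example of the helper above: for a rank-3 value, `axis=-1` maps to
# `-1 + 3 == 2`, while a non-negative axis is returned unchanged. When `axis`
# is a Tensor, the same arithmetic is done with `where_v2` so it stays
# graph-compatible.
#
#   assert _make_axis_nonnegative(-1, rank=3) == 2
#   assert _make_axis_nonnegative(1, rank=3) == 1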
1678
1679# List of Tensor dtypes supported by cross_replica_sum().
1680_DTYPES_SUPPORTED_BY_CROSS_REPLICA_SUM = (
1681    dtypes.bfloat16,
1682    dtypes.float16,
1683    dtypes.float32,
1684    dtypes.float64,
1685    dtypes.int32,
1686    dtypes.uint32,
1687)
1688
1689
1690class _TPUReplicaContext(distribute_lib.ReplicaContext):
1691  """Replication Context class for TPU Strategy."""
1692
1693  # TODO(sourabhbajaj): Call for each replica should be updating this.
1694  # TODO(b/118385803): Always properly initialize replica_id.
1695  def __init__(self, strategy, replica_id_in_sync_group=0):
1696    distribute_lib.ReplicaContext.__init__(
1697        self, strategy, replica_id_in_sync_group=replica_id_in_sync_group)
1698
1699  @property
1700  def devices(self):
1701    distribute_lib.require_replica_context(self)
1702    ds = self._strategy
1703    replica_id = tensor_util.constant_value(self.replica_id_in_sync_group)
1704
1705    if replica_id is None:  # Non-constant `Tensor` inside `tpu.replicate`.
1706      # TODO(cjfj): Return other devices when model parallelism is supported.
1707      return (tpu.core(0),)
1708    else:
1709      return (ds.extended.worker_devices[replica_id],)
1710
1711  def experimental_logical_device(self, logical_device_id):
1712    """Places variables and ops on the specified logical device."""
1713    return self.strategy.extended.experimental_logical_device(logical_device_id)
1714
1715  def _compute_all_gather_output_shape(self, value_shape, value_rank, axis):
1716    if isinstance(value_rank, int):
1717      output_shape = list(value_shape)
1718      output_shape[axis] *= self.num_replicas_in_sync
1719    else:
1720      output_shape = array_ops.where_v2(
1721          math_ops.equal(math_ops.range(value_rank), axis),
1722          value_shape * self.num_replicas_in_sync,
1723          value_shape)
1724    return output_shape
1725
1726  def all_gather(self, value, axis, experimental_hints=None):
1727    del experimental_hints
1728    for v in nest.flatten(value):
1729      if isinstance(v, indexed_slices.IndexedSlices):
1730        raise NotImplementedError("all_gather does not support IndexedSlices")
1731
1732    def _all_gather_tensor(value, axis):
1733      value = ops.convert_to_tensor(value)
1734
1735      # Compute the shape and rank of the input tensor. Use static
1736      # shapes when possible to help with shape inference in graph mode, but
1737      # fall back on dynamic shapes when necessary.
1738      if value.shape.rank is None:
1739        value_rank = array_ops.rank(value)
1740        value_shape = array_ops.shape(value)
1741      else:
1742        value_rank = value.shape.rank
1743        value_shape = value.shape.as_list()
1744        value_shape_tensor = array_ops.shape(value)
1745        for i in range(len(value_shape)):
1746          if value_shape[i] is None:
1747            value_shape[i] = value_shape_tensor[i]
1748
1749      # In the code below, we will insert a new "replica" dimension immediately
1750      # *before* `axis`. To ensure that it's inserted before and not after, we
1751      # must make `axis` non-negative.
1752      axis = _make_axis_nonnegative(axis, value_rank)
1753
1754      # Create a list or 1D int Tensor such as
1755      #     [1, 1, ..., 1, num_replicas_in_sync, 1, ..., 1],
1756      # which is equal to `num_replicas_in_sync` at index `axis`
1757      # and is equal to 1 everywhere else.
1758      if isinstance(value_rank, int):
1759        replica_broadcast_shape = [1] * (value_rank + 1)
1760        replica_broadcast_shape[axis] = self.num_replicas_in_sync
1761      else:
1762        replica_broadcast_shape = array_ops.where_v2(
1763            math_ops.equal(math_ops.range(value_rank+1), axis),
1764            self.num_replicas_in_sync,
1765            1)
1766
1767      output_shape = self._compute_all_gather_output_shape(
1768          value_shape, value_rank, axis)
1769
1770      if value.dtype in _DTYPES_SUPPORTED_BY_CROSS_REPLICA_SUM:
1771        # Optimized all_gather implementation based on cross_replica_sum().
1772        replica_id_mask = array_ops.one_hot(
1773            self.replica_id_in_sync_group, self.num_replicas_in_sync)
1774        replica_id_mask = array_ops.reshape(
1775            replica_id_mask, replica_broadcast_shape)
1776        replica_id_mask = math_ops.cast(replica_id_mask, value.dtype)
1777
1778        gathered_value = array_ops.expand_dims(value, axis) * replica_id_mask
1779        gathered_value = self.all_reduce(
1780            reduce_util.ReduceOp.SUM, gathered_value)
1781        return array_ops.reshape(gathered_value, output_shape)
1782      else:
1783        # value.dtype isn't supported by cross_replica_sum(), so we fall back
1784        # on a less efficient implementation based on all_to_all().
1785
1786        # The underlying AllToAllOp first splits the input value and then does
1787        # cross-replica communication and concatenation of the result, so we
1788        # tile the local tensor along the new replica dimension here first.
1789        inputs = array_ops.expand_dims(value, axis=axis)
1790        inputs = array_ops.tile(inputs, replica_broadcast_shape)
1791        unordered_output = tpu_ops.all_to_all(
1792            inputs,
1793            concat_dimension=axis,
1794            split_dimension=axis,
1795            split_count=self.num_replicas_in_sync)
1796
1797        # Re-order since xla.replica_id and ReplicaContext.replica_id mismatch.
1798        # Start by computing a permutation -- a 1D Tensor which maps
1799        #     tensor[xla.replica_id] = ReplicaContext.replica_id
1800        concat_replica_id = array_ops.reshape(
1801            self.replica_id_in_sync_group, [1])
1802        concat_replica_id = array_ops.tile(
1803            concat_replica_id, [self.num_replicas_in_sync])
1804        xla_to_replica_context_id = tpu_ops.all_to_all(
1805            concat_replica_id,
1806            concat_dimension=0,
1807            split_dimension=0,
1808            split_count=self.num_replicas_in_sync)
1809
1810        # Now invert the mapping to get
1811        #    tensor[ReplicaContext.replica_id] = xla.replica_id
1812        replica_context_to_xla_id = math_ops.argmax(
1813            array_ops.one_hot(xla_to_replica_context_id,
1814                              self.num_replicas_in_sync),
1815            axis=0)
1816
1817        # Reorder the output elements so that they're sorted based on
1818        # ReplicaContext.replica_id instead of xla.replica_id.
1819        sorted_with_extra_dim = array_ops.gather(
1820            unordered_output, replica_context_to_xla_id, axis=axis)
1821        return array_ops.reshape(sorted_with_extra_dim, output_shape)
1822
1823    ys = [_all_gather_tensor(t, axis=axis) for t in nest.flatten(value)]
1824    return nest.pack_sequence_as(value, ys)
1825
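  # A rough NumPy sketch of the one-hot `cross_replica_sum` trick used in
  # `_all_gather_tensor` above, assuming two replicas gathering along axis 0:
  # each replica multiplies its expanded value by a one-hot mask over the new
  # replica dimension, so the cross-replica sum (modeled with `sum` below)
  # contains every replica's slice at that replica's index.
  #
  #   import numpy as np
  #   num_replicas = 2
  #   values = [np.array([1., 2.]), np.array([3., 4.])]  # One per replica.
  #   masked = []
  #   for replica_id, v in enumerate(values):
  #     mask = np.eye(num_replicas)[replica_id].reshape(num_replicas, 1)
  #     masked.append(np.expand_dims(v, 0) * mask)
  #   gathered = sum(masked)
  #   assert (gathered.reshape(-1) == np.array([1., 2., 3., 4.])).all()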
1826
1827def _set_last_step_outputs(ctx, last_step_tensor_outputs):
1828  """Sets the last step outputs on the given context."""
1829  # Convert replicate_outputs to the original dict structure of
1830  # last_step_outputs.
1831  last_step_tensor_outputs_dict = nest.pack_sequence_as(
1832      ctx.last_step_outputs, last_step_tensor_outputs)
1833
1834  for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
1835    output = last_step_tensor_outputs_dict[name]
1836    # For outputs that aren't reduced, return a PerReplica of all values. Else
1837    # take the first value from the list as each value should be the same.
1838    if reduce_op is None:
1839      last_step_tensor_outputs_dict[name] = values.PerReplica(output)
1840    else:
1841      # TODO(priyag): Should this return the element or a list with 1 element?
1842      last_step_tensor_outputs_dict[name] = output[0]
1843  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
1844