xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/input_lib.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Various classes representing distributed inputs."""
16
17import functools
18import sys
19import time
20
21import six
22
23from tensorflow.python.data.experimental.ops import batching
24from tensorflow.python.data.experimental.ops import cardinality as cardinality_lib
25from tensorflow.python.data.experimental.ops import distribute
26from tensorflow.python.data.ops import dataset_ops
27from tensorflow.python.data.ops import iterator_ops
28from tensorflow.python.data.ops import multi_device_iterator_ops
29from tensorflow.python.data.ops import optional_ops
30from tensorflow.python.distribute import device_util
31from tensorflow.python.distribute import distribute_lib
32from tensorflow.python.distribute import distribute_utils
33from tensorflow.python.distribute import distribution_strategy_context
34from tensorflow.python.distribute import input_ops
35from tensorflow.python.distribute import reduce_util
36from tensorflow.python.distribute import values
37from tensorflow.python.distribute.distribute_lib import InputReplicationMode
38from tensorflow.python.eager import context
39from tensorflow.python.eager import monitoring
40from tensorflow.python.framework import composite_tensor
41from tensorflow.python.framework import device as tf_device
42from tensorflow.python.framework import dtypes
43from tensorflow.python.framework import errors
44from tensorflow.python.framework import ops
45from tensorflow.python.framework import sparse_tensor
46from tensorflow.python.framework import tensor_shape
47from tensorflow.python.framework import tensor_util
48from tensorflow.python.framework import type_spec
49from tensorflow.python.ops import array_ops
50from tensorflow.python.ops import control_flow_ops
51from tensorflow.python.ops import math_ops
52from tensorflow.python.ops.ragged import ragged_tensor
53from tensorflow.python.platform import tf_logging as logging
54from tensorflow.python.types import distribute as distribute_types
55from tensorflow.python.util import nest
56from tensorflow.python.util.compat import collections_abc
57from tensorflow.python.util.tf_export import tf_export
58from tensorflow.tools.docs import doc_controls
59
60
61_distributed_dataset_initialization_time_milliseconds = monitoring.Sampler(
62    "/tensorflow/api/distribution_strategy/"
63    "distributed_dataset_initialization_time_milliseconds",
64    monitoring.ExponentialBuckets(scale=1, growth_factor=2, bucket_count=26),
65    "Track the time (in milliseconds) to initialize distributed datasets.",
66    "strategy", "workers")
67
68_distributed_dataset_from_function_initialization_time_milliseconds = (
69    monitoring.Sampler(
70        "/tensorflow/api/distribution_strategy/"
71        "distributed_dataset_from_function_initialization_time_milliseconds",
72        monitoring.ExponentialBuckets(
73            scale=1, growth_factor=2, bucket_count=26),
74        "Track the time (in milliseconds) to initialize distributed datasets "
75        "from function.",
76        "strategy", "workers"))
77
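# A minimal sketch of how these samplers are fed later in this module; the
# strategy name, worker count, and duration below are illustrative values:
#
#   _distributed_dataset_initialization_time_milliseconds.get_cell(
#       "MirroredStrategy", "1").add(42)
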
78
79def get_iterator_spec_from_dataset(strategy, dataset):
80  """Returns an iterator spec from dataset function.
81
82  This function constructs type spec for iterator obtained from
83  iter(dataset).
84
85  Args:
86    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
87        handle last partial batch.
88    dataset: A tf.data.Dataset instance. If using a function that returns a
89      tf.data.Dataset instance, pass dataset_fn.structured_outputs.
90
91  Returns:
92    A type spec for the iterator of the given dataset instance.
93
94  """
95  # pylint: disable=protected-access
96  output_element_spec = dataset.element_spec
97  if isinstance(dataset._type_spec,
98                (DistributedDatasetSpec,
99                 DistributedDatasetsFromFunctionSpec)):
100    iterator_type_spec = DistributedIteratorSpec(
101        strategy.extended._input_workers_with_options(),
102        output_element_spec,
103        strategy.extended._container_strategy(),
104        options=None,
105        cardinality=dataset.cardinality,
106        enable_get_next_as_optional=True)
107  else:
108    if strategy.extended._num_gpus_per_worker:
109      logging.warning(
110          f"{strategy.extended._num_gpus_per_worker} GPUs "
111          "are allocated per worker. Please use DistributedDataset by "
112          "calling strategy.experimental_distribute_dataset or strategy."
113          "distribute_datasets_from_function to make best use of GPU "
114          "resources"
115      )
116    iterator_type_spec = iterator_ops.IteratorSpec(output_element_spec)
117  return iterator_type_spec
118  # pylint: enable=protected-access
119
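# A minimal usage sketch for get_iterator_spec_from_dataset, assuming a
# two-GPU MirroredStrategy; the names below are illustrative, not part of
# this module:
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   dist_dataset = strategy.experimental_distribute_dataset(
#       tf.data.Dataset.range(8).batch(4))
#   iterator_spec = get_iterator_spec_from_dataset(strategy, dist_dataset)
#
#   @tf.function(input_signature=[iterator_spec])
#   def train_step(iterator):
#     return next(iterator)
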
120
121@tf_export("distribute.DistributedIterator", v1=[])
122class DistributedIteratorInterface(collections_abc.Iterator,
123                                   distribute_types.Iterator):
124  """An iterator over `tf.distribute.DistributedDataset`.
125
126  `tf.distribute.DistributedIterator` is the primary mechanism for enumerating
127  elements of a `tf.distribute.DistributedDataset`. It supports the Python
128  Iterator protocol, which means it can be iterated over using a for-loop or by
129  fetching individual elements explicitly via `get_next()`.
130
131  You can create a `tf.distribute.DistributedIterator` by calling `iter` on
132  a `tf.distribute.DistributedDataset` or creating a python loop over a
133  `tf.distribute.DistributedDataset`.
134
135  Visit the [tutorial](https://www.tensorflow.org/tutorials/distribute/input)
136  on distributed input for more examples and caveats.
137  """
138
139  def get_next(self):
140    """Returns the next input from the iterator for all replicas.
141
142    Example use:
143
144    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
145    >>> dataset = tf.data.Dataset.range(100).batch(2)
146    >>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
147    >>> dist_dataset_iterator = iter(dist_dataset)
148    >>> @tf.function
149    ... def one_step(input):
150    ...   return input
151    >>> step_num = 5
152    >>> for _ in range(step_num):
153    ...   strategy.run(one_step, args=(dist_dataset_iterator.get_next(),))
154    >>> strategy.experimental_local_results(dist_dataset_iterator.get_next())
155    (<tf.Tensor: shape=(1,), dtype=int64, numpy=array([10])>,
156     <tf.Tensor: shape=(1,), dtype=int64, numpy=array([11])>)
157
158    Returns:
159      A single `tf.Tensor` or a `tf.distribute.DistributedValues` which contains
160      the next input for all replicas.
161
162    Raises:
163      `tf.errors.OutOfRangeError`: If the end of the iterator has been reached.
164    """
165    raise NotImplementedError(
166        "DistributedIterator.get_next() must be implemented in descendants.")
167
168  @property
169  def element_spec(self):
170    # pylint: disable=line-too-long
171    """The type specification of an element of `tf.distribute.DistributedIterator`.
172
173    Example usage:
174
175    >>> global_batch_size = 16
176    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
177    >>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size)
178    >>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
179    >>> distributed_iterator.element_spec
180    (PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
181                    TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)),
182     PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None),
183                    TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)))
184
185    Returns:
186      A nested structure of `tf.TypeSpec` objects matching the structure of an
187      element of this `tf.distribute.DistributedIterator`. This returned value
188      is typically a `tf.distribute.DistributedValues` object and specifies the
189      `tf.TensorSpec` of individual components.
190    """
191    raise NotImplementedError(
192        "DistributedIterator.element_spec() must be implemented in descendants")
193
194  def get_next_as_optional(self):
195    # pylint: disable=line-too-long
196    """Returns a `tf.experimental.Optional` that contains the next value for all replicas.
197
198    If the `tf.distribute.DistributedIterator` has reached the end of the
199    sequence, the returned `tf.experimental.Optional` will have no value.
200
201    Example usage:
202
203    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
204    >>> global_batch_size = 2
205    >>> steps_per_loop = 2
206    >>> dataset = tf.data.Dataset.range(10).batch(global_batch_size)
207    >>> distributed_iterator = iter(
208    ...     strategy.experimental_distribute_dataset(dataset))
209    >>> def step_fn(x):
210    ...   # train the model with inputs
211    ...   return x
212    >>> @tf.function
213    ... def train_fn(distributed_iterator):
214    ...   for _ in tf.range(steps_per_loop):
215    ...     optional_data = distributed_iterator.get_next_as_optional()
216    ...     if not optional_data.has_value():
217    ...       break
218    ...     per_replica_results = strategy.run(step_fn, args=(optional_data.get_value(),))
219    ...     tf.print(strategy.experimental_local_results(per_replica_results))
220    >>> train_fn(distributed_iterator)
221    ... # ([0 1], [2 3])
222    ... # ([4], [])
223
224    Returns:
225      A `tf.experimental.Optional` object representing the next value from the
226      `tf.distribute.DistributedIterator` (if it has one) or no value.
227    """
228    # pylint: enable=line-too-long
229    raise NotImplementedError(
230        "get_next_as_optional() not implemented in descendants")
231
232
233@tf_export("distribute.DistributedDataset", v1=[])
234class DistributedDatasetInterface(collections_abc.Iterable,
235                                  distribute_types.Iterable):
236  # pylint: disable=line-too-long
237  """Represents a dataset distributed among devices and machines.
238
239  A `tf.distribute.DistributedDataset` could be thought of as a "distributed"
240  dataset. When you use `tf.distribute` API to scale training to multiple
241  devices or machines, you also need to distribute the input data, which leads
242  to a `tf.distribute.DistributedDataset` instance, instead of a
243  `tf.data.Dataset` instance in the non-distributed case. In TF 2.x,
244  `tf.distribute.DistributedDataset` objects are Python iterables.
245
246  Note: `tf.distribute.DistributedDataset` instances are *not* of type
247  `tf.data.Dataset`. They only support the two usages mentioned below:
248  iteration and `element_spec`. We don't support any other APIs to transform or
249  inspect the dataset.
250
251  There are two APIs to create a `tf.distribute.DistributedDataset` object:
252  `tf.distribute.Strategy.experimental_distribute_dataset(dataset)` and
253  `tf.distribute.Strategy.distribute_datasets_from_function(dataset_fn)`.
254  *When to use which?* When you have a `tf.data.Dataset` instance, and the
255  regular batch splitting (i.e. re-batch the input `tf.data.Dataset` instance
256  with a new batch size that is equal to the global batch size divided by the
257  number of replicas in sync) and autosharding (i.e. the
258  `tf.data.experimental.AutoShardPolicy` options) work for you, use the former
259  API. Otherwise, if you are *not* using a canonical `tf.data.Dataset` instance,
260  or you would like to customize the batch splitting or sharding, you can wrap
261  this logic in a `dataset_fn` and use the latter API. Both APIs handle
262  prefetching to the device for the user. For more details and examples,
263  follow the links to the APIs.
264
265
266  There are two main usages of a `DistributedDataset` object:
267
268  1. Iterate over it to generate the input for a single device or multiple
269  devices, which is a `tf.distribute.DistributedValues` instance. To do this,
270  you can:
271
272    * use a pythonic for-loop construct:
273
274      >>> global_batch_size = 4
275      >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
276      >>> dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(4).batch(global_batch_size)
277      >>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
278      >>> @tf.function
279      ... def train_step(input):
280      ...   features, labels = input
281      ...   return labels - 0.3 * features
282      >>> for x in dist_dataset:
283      ...   # train_step trains the model using the dataset elements
284      ...   loss = strategy.run(train_step, args=(x,))
285      ...   print("Loss is", loss)
286      Loss is PerReplica:{
287        0: tf.Tensor(
288      [[0.7]
289       [0.7]], shape=(2, 1), dtype=float32),
290        1: tf.Tensor(
291      [[0.7]
292       [0.7]], shape=(2, 1), dtype=float32)
293      }
294
295      Placing the loop inside a `tf.function` will give a performance boost.
296      However, `break` and `return` are currently not supported if the loop is
297      placed inside a `tf.function`. We also don't support placing the loop
298      inside a `tf.function` when using
299      `tf.distribute.experimental.MultiWorkerMirroredStrategy` or
300      `tf.distribute.experimental.TPUStrategy` with multiple workers.
301
302    * use `__iter__` to create an explicit iterator, which is of type
303      `tf.distribute.DistributedIterator`
304
305      >>> global_batch_size = 4
306      >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
307      >>> train_dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(50).batch(global_batch_size)
308      >>> train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
309      >>> @tf.function
310      ... def distributed_train_step(dataset_inputs):
311      ...   def train_step(input):
312      ...     loss = tf.constant(0.1)
313      ...     return loss
314      ...   per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
315      ...   return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
316      >>> EPOCHS = 2
317      >>> STEPS = 3
318      >>> for epoch in range(EPOCHS):
319      ...   total_loss = 0.0
320      ...   num_batches = 0
321      ...   dist_dataset_iterator = iter(train_dist_dataset)
322      ...   for _ in range(STEPS):
323      ...     total_loss += distributed_train_step(next(dist_dataset_iterator))
324      ...     num_batches += 1
325      ...   average_train_loss = total_loss / num_batches
326      ...   template = ("Epoch {}, Loss: {:.4f}")
327      ...   print(template.format(epoch+1, average_train_loss))
328      Epoch 1, Loss: 0.2000
329      Epoch 2, Loss: 0.2000
330
331
332    To achieve a performance improvement, you can also wrap the `strategy.run`
333    call in a `tf.range` loop inside a `tf.function`. This runs multiple steps
334    in a `tf.function`. AutoGraph will convert it to a `tf.while_loop` on the
335    worker. However, it is less flexible compared with running a single step
336    inside a `tf.function`. For example, you cannot run things eagerly or run
337    arbitrary Python code within the steps.
338
339
340  2. Inspect the `tf.TypeSpec` of the data generated by `DistributedDataset`.
341
342    `tf.distribute.DistributedDataset` generates
343    `tf.distribute.DistributedValues` as input to the devices. If you pass the
344    input to a `tf.function` and would like to specify the shape and type of
345    each Tensor argument to the function, you can pass a `tf.TypeSpec` object to
346    the `input_signature` argument of the `tf.function`. To get the
347    `tf.TypeSpec` of the input, you can use the `element_spec` property of the
348    `tf.distribute.DistributedDataset` or `tf.distribute.DistributedIterator`
349    object.
350
351    For example:
352
353    >>> global_batch_size = 4
354    >>> epochs = 1
355    >>> steps_per_epoch = 1
356    >>> mirrored_strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
357    >>> dataset = tf.data.Dataset.from_tensors(([2.])).repeat(100).batch(global_batch_size)
358    >>> dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)
359    >>> @tf.function(input_signature=[dist_dataset.element_spec])
360    ... def train_step(per_replica_inputs):
361    ...   def step_fn(inputs):
362    ...     return tf.square(inputs)
363    ...   return mirrored_strategy.run(step_fn, args=(per_replica_inputs,))
364    >>> for _ in range(epochs):
365    ...   iterator = iter(dist_dataset)
366    ...   for _ in range(steps_per_epoch):
367    ...     output = train_step(next(iterator))
368    ...     print(output)
369    PerReplica:{
370      0: tf.Tensor(
371    [[4.]
372     [4.]], shape=(2, 1), dtype=float32),
373      1: tf.Tensor(
374    [[4.]
375     [4.]], shape=(2, 1), dtype=float32)
376    }
377
378
379  Visit the [tutorial](https://www.tensorflow.org/tutorials/distribute/input)
380  on distributed input for more examples and caveats.
381  """
382
383  def __iter__(self):
384    """Creates an iterator for the `tf.distribute.DistributedDataset`.
385
386    The returned iterator implements the Python Iterator protocol.
387
388    Example usage:
389
390    >>> global_batch_size = 4
391    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
392    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).repeat().batch(global_batch_size)
393    >>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
394    >>> print(next(distributed_iterator))
395    PerReplica:{
396      0: tf.Tensor([1 2], shape=(2,), dtype=int32),
397      1: tf.Tensor([3 4], shape=(2,), dtype=int32)
398    }
399
400    Returns:
401      A `tf.distribute.DistributedIterator` instance for the given
402      `tf.distribute.DistributedDataset` object to enumerate over the
403      distributed data.
404    """
405    raise NotImplementedError("Must be implemented in descendants")
406
407  @property
408  def element_spec(self):
409    """The type specification of an element of this `tf.distribute.DistributedDataset`.
410
411    Example usage:
412
413    >>> global_batch_size = 16
414    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
415    >>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size)
416    >>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
417    >>> dist_dataset.element_spec
418    (PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
419                    TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)),
420     PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None),
421                    TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)))
422
423    Returns:
424      A nested structure of `tf.TypeSpec` objects matching the structure of an
425      element of this `tf.distribute.DistributedDataset`. This returned value is
426      typically a `tf.distribute.DistributedValues` object and specifies the
427      `tf.TensorSpec` of individual components.
428    """
429    raise NotImplementedError(
430        "DistributedDataset.element_spec must be implemented in descendants.")
431
432  @doc_controls.do_not_generate_docs
433  def reduce(self, initial_state, reduce_func):
434    raise NotImplementedError(
435        "DistributedDataset.reduce must be implemented in descendants.")
436
437
438class InputWorkers(object):
439  """A 1-to-many mapping from input worker devices to compute devices."""
440
441  # TODO(ishark): Remove option canonicalize_devices and make all the callers
442  # pass canonicalized or raw device strings as relevant from strategy.
443  def __init__(self,
444               worker_device_pairs,
445               canonicalize_devices=True):
446    """Initialize an `InputWorkers` object.
447
448    Args:
449      worker_device_pairs: A sequence of pairs: `(input device, a tuple of
450        compute devices fed by that input device)`.
451      canonicalize_devices: Whether to canonicalize devices for workers fully or
452        partially. If False, it will partially canonicalize devices by removing
453        job and task.
454    """
455    self._worker_device_pairs = worker_device_pairs
456    self._input_worker_devices = tuple(d for d, _ in self._worker_device_pairs)
457    self._canonicalize_devices = canonicalize_devices
458    if canonicalize_devices:
459      self._fed_devices = tuple(
460          tuple(device_util.canonicalize(d)
461                for d in f)
462          for _, f in self._worker_device_pairs)
463    else:
464      self._fed_devices = tuple(
465          tuple(device_util.canonicalize_without_job_and_task(d)
466                for d in f)
467          for _, f in self._worker_device_pairs)
468
469  @property
470  def num_workers(self):
471    return len(self._input_worker_devices)
472
473  @property
474  def worker_devices(self):
475    return self._input_worker_devices
476
477  def compute_devices_for_worker(self, worker_index):
478    return self._fed_devices[worker_index]
479
480  def __repr__(self):
481    devices = self.worker_devices
482    debug_repr = ",\n".join("  %d %s: %s" %
483                            (i, devices[i], self._fed_devices[i])
484                            for i in range(len(devices)))
485    return "%s:{\n%s}" % (self.__class__.__name__, debug_repr)
486
487  def serialize(self):
488    return (self._worker_device_pairs, self._canonicalize_devices)
489
490  def deserialize(self, serialized):
491    return InputWorkers(*serialized)
492
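# A minimal sketch of the mapping an InputWorkers instance describes; the
# device strings are illustrative:
#
#   input_workers = InputWorkers(
#       worker_device_pairs=[
#           ("/job:worker/task:0/device:CPU:0",
#            ("/job:worker/task:0/device:GPU:0",
#             "/job:worker/task:0/device:GPU:1")),
#       ])
#   input_workers.num_workers                    # 1
#   input_workers.compute_devices_for_worker(0)  # the two (canonicalized) GPUs
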
493
494def _calculate_replicas_with_values(strategy, input_workers, optional_list):
495  """Calcualates the number of replicas that have values.
496
497  Args:
498    strategy: the `tf.distribute.Strategy`.
499    input_workers: the `InputWorkers`.
500    optional_list: a list of lists of `tf.experimental.Optional`. The values
501      from each compute device, grouped by the input device.
502
503  Returns:
504    A scalar Tensor.
505  """
506  worker_has_values = []
507  for worker, optionals in zip(input_workers.worker_devices, optional_list):
508    with ops.device(worker):
509      device_has_values = [
510          math_ops.cast(v.has_value(), dtypes.int64) for v in optionals
511      ]
512      worker_has_values.append(
513          math_ops.reduce_sum(device_has_values, keepdims=True))
514  client_has_values = math_ops.reduce_sum(worker_has_values, keepdims=True)
515  if strategy.extended._in_multi_worker_mode():  # pylint: disable=protected-access
516    global_has_values = strategy.reduce(
517        reduce_util.ReduceOp.SUM, client_has_values, axis=None)
518    return array_ops.reshape(global_has_values, [])
519  else:
520    return array_ops.reshape(client_has_values, [])
521
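# A sketch of the `optional_list` structure consumed above, assuming one input
# worker feeding two GPUs; the names are illustrative:
#
#   optional_list = [
#       [optional_from_gpu0, optional_from_gpu1],  # grouped by input worker 0
#   ]
#   num_with_values = _calculate_replicas_with_values(
#       strategy, input_workers, optional_list)  # scalar Tensor in [0, 2]
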
522
523def _is_statically_shaped(element_spec):
524  """Test if an iterator output is statically shaped.
525
526  For sparse and ragged tensors this only tests the batch dimension.
527
528  Args:
529    element_spec: a nested structure of `tf.TypeSpec`. The element spec of the
530      dataset of the iterator.
531
532  Returns:
533    True if the shape is static, False otherwise.
534  """
535
536  for spec in nest.flatten(element_spec):
537    if isinstance(
538        spec, (sparse_tensor.SparseTensorSpec, ragged_tensor.RaggedTensorSpec)):
539      # For sparse or ragged tensor, we should only check the first
540      # dimension for the purpose of get_next_as_optional. This is because
541      # when these tensors get batched by dataset only the batch dimension
542      # is set.
543      if spec.shape.rank > 0 and spec.shape.as_list()[0] is None:
544        return False
545    else:
546      for component in spec._flat_tensor_specs:  # pylint: disable=protected-access
547        if not component.shape.is_fully_defined():
548          return False
549  return True
550
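# Illustrative behavior of the predicate above (specs chosen as examples):
#
#   _is_statically_shaped(tf.TensorSpec([4, 2], tf.float32))     # True
#   _is_statically_shaped(tf.TensorSpec([None, 2], tf.float32))  # False
#   # For sparse/ragged specs only the batch dimension is inspected:
#   _is_statically_shaped(
#       tf.SparseTensorSpec(shape=[4, None], dtype=tf.int64))    # True
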
551
552class DistributedIteratorBase(DistributedIteratorInterface):
553  """Common implementation for all input iterators."""
554
555  # pylint: disable=super-init-not-called
556  def __init__(self, input_workers, iterators, strategy, cardinality,
557               enable_get_next_as_optional):
558    assert isinstance(input_workers, InputWorkers)
559    if not input_workers.worker_devices:
560      raise ValueError("Should have at least one worker for input iterator.")
561
562    self._iterators = iterators
563    self._input_workers = input_workers
564    self._strategy = strategy
565    self._cardinality = cardinality
566    self._enable_get_next_as_optional = enable_get_next_as_optional
567
568  def next(self):
569    return self.__next__()
570
571  def __next__(self):
572    try:
573      return self.get_next()
574    except errors.OutOfRangeError:
575      raise StopIteration
576
577  def __iter__(self):
578    return self
579
580  def get_next_as_optional(self):
581    # Ideally get_next_as_optional() should be consistent with get_next(), but
582    # we used to always do partial batch handling in get_next_as_optional(). We
583    # are keeping this behavior for now until we understantd the impact.
584
585    # Skip partial batch handling when the dataset is infinite or empty, as
586    # there won't be any partial batches in those cases. This gives the user
587    # more static shapes as it avoids the tf.cond. Note that for empty datasets,
588    # we can only skip in single client mode, as the dataset can be non-empty on
589    # other workers.
590    if self._cardinality == cardinality_lib.INFINITE:
591      return optional_ops.Optional.from_value(
592          self._get_next_no_partial_batch_handling())
593    if (self._cardinality == 0 and
594        not self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
595      return optional_ops.Optional.empty(self._element_spec)
596
597    optional_list = []
598    for i, worker in enumerate(self._input_workers.worker_devices):
599      with ops.device(worker):
600        optional_list.append(self._iterators[i].get_next_as_optional_list())
601
602    def _create_optional_with_dummy():
603      value_list = _get_value_or_dummy(
604          self._input_workers, optional_list, produce_dummy=True)
605      per_replica = _create_per_replica(value_list, self._strategy)
606      return optional_ops.Optional.from_value(per_replica)
607
608    def _create_empty_optional():
609      return optional_ops.Optional.empty(self._element_spec)
610
611    num_replicas_with_values = _calculate_replicas_with_values(
612        self._strategy, self._input_workers, optional_list)
613
614    return control_flow_ops.cond(
615        num_replicas_with_values > 0,
616        _create_optional_with_dummy,
617        _create_empty_optional,
618        strict=True)
619
620  def get_next(self, name=None):
621    """Returns the next input from the iterator for all replicas."""
622    with distribution_strategy_context.enter_or_assert_strategy(
623        self._strategy):
624      if distribution_strategy_context.get_replica_context() is not None:
625        raise ValueError("next(iterator) should be called from outside of "
626                         "replica_fn. e.g. strategy.run(replica_fn, "
627                         "args=(next(iterator),))")
628
629    if not self._enable_get_next_as_optional:
630      return self._get_next_no_partial_batch_handling(name)
631
632    optional_list = []
633    for i, worker in enumerate(self._input_workers.worker_devices):
634      with ops.device(worker):
635        optional_list.append(self._iterators[i].get_next_as_optional_list())
636    num_replicas_with_values = _calculate_replicas_with_values(
637        self._strategy, self._input_workers, optional_list)
638
639    def _value_or_dummy():
640      value_list = _get_value_or_dummy(
641          self._input_workers, optional_list, produce_dummy=True)
642      return _create_per_replica(value_list, self._strategy)
643
644    def _eof():
645      # Optional.get_value raises InvalidArgumentError when there's no value,
646      # so we need to call GetNext to raise OutOfRangeError.
647      return self._get_next_no_partial_batch_handling()
648
649    return control_flow_ops.cond(
650        num_replicas_with_values > 0, _value_or_dummy, _eof, strict=True)
651
652  def _get_next_no_partial_batch_handling(self, name=None):
653    replicas = []
654    for i, worker in enumerate(self._input_workers.worker_devices):
655      if name is not None:
656        d = tf_device.DeviceSpec.from_string(worker)
657        new_name = "%s_%s_%d" % (name, d.job, d.task)
658      else:
659        new_name = None
660      with ops.device(worker):
661        # Make `replicas` a flat list of values across all replicas.
662        replicas.extend(self._iterators[i].get_next_as_list(new_name))
663    return _create_per_replica(replicas, self._strategy)
664
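# The iterator protocol wiring above, in eager-style pseudocode (illustrative):
#
#   for batch in dist_iterator:   # __next__ -> get_next()
#     ...                         # the loop ends when get_next() raises
#                                 # tf.errors.OutOfRangeError, which __next__
#                                 # converts to StopIteration
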
665
666class DistributedDatasetAndIteratorSpec(type_spec.TypeSpec):
667  """Common Type specification for `DistributedDataset and DistributedDatasetsFromFunction."""
668
669  __slots__ = [
670      "_input_workers", "_element_spec", "_strategy", "_cardinality",
671      "_enable_get_next_as_optional", "_options", "_canonicalize_devices"
672  ]
673
674  def __init__(self,
675               input_workers,
676               element_spec,
677               strategy,
678               options,
679               cardinality=cardinality_lib.UNKNOWN,
680               enable_get_next_as_optional=None):
681    # We don't want to allow deserialization of this class because we don't
682    # serialize the strategy object. Currently the only place where
683    # _deserialize is called is when we save/restore using SavedModel.
684    if isinstance(input_workers, tuple):
685      raise NotImplementedError("DistributedIteratorSpec does not have support "
686                                "for deserialization.")
687    else:
688      self._input_workers = input_workers
689      self._element_spec = element_spec
690      self._strategy = strategy
691      self._cardinality = cardinality
692      self._enable_get_next_as_optional = enable_get_next_as_optional
693      self._options = options
694      if self._strategy:
695        self._canonicalize_devices = getattr(self._strategy,
696                                             "_canonicalize_devices", True)
697      else:
698        self._canonicalize_devices = True
699
700  def _serialize(self):
701    # We cannot serialize the strategy object so we convert it to an id that we
702    # can use for comparison.
703    return (self._input_workers.serialize(), self._element_spec,
704            id(self._strategy), id(self._options))
705
706  def _deserialize(self):
707    raise ValueError(
708        f"Deserialization is currently unsupported for {type(self)}.")
709
710  def sanity_check_type(self, other):
711    """Returns the most specific TypeSpec compatible with `self` and `other`.
712
713    Args:
714      other: A `TypeSpec`.
715
716    Raises:
717      ValueError: If there is no TypeSpec that is compatible with both `self`
718        and `other`.
719    """
720    # pylint: disable=protected-access
721    if type(self) is not type(other):
722      raise ValueError("No TypeSpec is compatible with both %s and %s" %
723                       (self, other))
724    if self._input_workers.serialize() != other._input_workers.serialize():
725      raise ValueError("_input_workers is not compatible with both %s "
726                       "and %s" % (self, other))
727    if self._strategy is not other._strategy:
728      raise ValueError("tf.distribute strategy is not compatible with both %s "
729                       "and %s" % (self, other))
730
731  def is_subtype_of(self, other):
732    """Returns True if `self` is subtype of `other`.
733
734    Args:
735      other: A `TypeSpec`.
736    """
737    try:
738      self.sanity_check_type(other)
739      nest.assert_same_structure(self._element_spec, other._element_spec)  # pylint: disable=protected-access
740    except (TypeError, ValueError):
741      return False
742
743    self_elements = nest.flatten(self._element_spec)
744    other_elements = nest.flatten(other._element_spec)  # pylint: disable=protected-access
745
746    return all(
747        self_element.is_subtype_of(other_element)
748        for (self_element, other_element) in zip(self_elements, other_elements))
749
750  def most_specific_common_supertype(self, others):
751    """Returns the most specific supertype of `self` and `others`.
752
753    Args:
754      others: A Sequence of `TypeSpec`.
755
756    Returns `None` if a supertype does not exist.
757    """
758    try:
759      for other in others:
760        self.sanity_check_type(other)
761        nest.assert_same_structure(self._element_spec, other._element_spec)  # pylint: disable=protected-access
762    except (TypeError, ValueError):
763      return None
764
765    self_elements = nest.flatten(self._element_spec)
766    others_elements = [nest.flatten(other._element_spec) for other in others]  # pylint: disable=protected-access
767    common_elements = [None] * len(self_elements)
768
769    for i, self_element in enumerate(self_elements):
770      common_elements[i] = self_element.most_specific_common_supertype(
771          [other_elements[i] for other_elements in others_elements])
772      if common_elements[i] is None:
773        return None
774    common_element_spec = nest.pack_sequence_as(self._element_spec,
775                                                common_elements)
776    return type(self)(
777        self._input_workers,
778        common_element_spec,
779        self._strategy,
780        self._options,
781        cardinality=self._cardinality,
782        enable_get_next_as_optional=self._enable_get_next_as_optional)
783
784  def _with_tensor_ranks_only(self):
785    element_spec = nest.map_structure(
786        lambda s: s._with_tensor_ranks_only(),  # pylint: disable=protected-access
787        self._element_spec)
788    return type(self)(
789        self._input_workers,
790        element_spec,
791        self._strategy,
792        self._options,
793        cardinality=self._cardinality,
794        enable_get_next_as_optional=self._enable_get_next_as_optional)
795
796  # TODO(b/206014848): Remove once names are not used.
797  def _without_tensor_names(self):
798    element_spec = nest.map_structure(
799        lambda s: s._without_tensor_names(),  # pylint: disable=protected-access
800        self._element_spec)
801    return type(self)(
802        self._input_workers,
803        element_spec,
804        self._strategy,
805        self._options,
806        cardinality=self._cardinality,
807        enable_get_next_as_optional=self._enable_get_next_as_optional)
808
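# How the element-wise supertype computation used above behaves for plain
# tensor specs (an illustrative example, not part of this module):
#
#   tf.TensorSpec([4, 2]).most_specific_common_supertype(
#       [tf.TensorSpec([8, 2])])
#   # -> TensorSpec(shape=(None, 2), dtype=tf.float32, name=None)
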
809
810class DistributedIteratorSpec(DistributedDatasetAndIteratorSpec):
811  """Type specification for `DistributedIterator`."""
812
813  @property
814  def value_type(self):
815    return DistributedIterator
816
817  @property
818  def _component_specs(self):
819    specs = []
820    worker_device_pairs = self._input_workers._worker_device_pairs  # pylint: disable=protected-access
821
822    for i, (input_device, compute_devices) in enumerate(worker_device_pairs):
823      element_spec = nest.map_structure(
824          functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
825      specs.append(
826          _SingleWorkerDatasetIteratorSpec(input_device, compute_devices,
827                                           element_spec, self._options,
828                                           self._canonicalize_devices))
829    return specs
830
831  def _to_components(self, value):
832    return value._iterators  # pylint: disable=protected-access
833
834  def _from_components(self, components):
835    return DistributedIterator(
836        input_workers=self._input_workers,
837        iterators=None,
838        components=components,
839        element_spec=self._element_spec,
840        strategy=self._strategy,
841        cardinality=self._cardinality,
842        enable_get_next_as_optional=self._enable_get_next_as_optional,
843        options=self._options)
844
845  @staticmethod
846  def from_value(value):
847    # pylint: disable=protected-access
848    return DistributedIteratorSpec(
849        value._input_workers,
850        value._element_spec,
851        value._strategy,
852        value._options,
853        cardinality=value._cardinality,
854        enable_get_next_as_optional=value._enable_get_next_as_optional)
855
856
857class DistributedIterator(DistributedIteratorBase,
858                          composite_tensor.CompositeTensor):
859  """Input Iterator for a distributed dataset."""
860
861  def __init__(self,
862               input_workers=None,
863               iterators=None,
864               strategy=None,
865               components=None,
866               element_spec=None,
867               cardinality=cardinality_lib.UNKNOWN,
868               enable_get_next_as_optional=False,
869               options=None):
870    if input_workers is None:
871      raise ValueError("`input_workers` should be "
872                       "provided.")
873
874    error_message = ("Either `input_workers` or "
875                     "both `components` and `element_spec` need to be "
876                     "provided.")
877    self._options = options
878
879    if iterators is None:
880      if (components is None or element_spec is None):
881        raise ValueError(error_message)
882      self._element_spec = element_spec
883      self._input_workers = input_workers
884      self._iterators = components
885      self._strategy = strategy
886      self._cardinality = cardinality
887      self._enable_get_next_as_optional = enable_get_next_as_optional
888    else:
889      if (components is not None and element_spec is not None):
890        raise ValueError(error_message)
891
892      super(DistributedIterator,
893            self).__init__(input_workers, iterators, strategy, cardinality,
894                           enable_get_next_as_optional)
895
896  @property
897  def element_spec(self):
898    # When partial batch handling is enabled, always set the batch dimension to
899    # None, otherwise we just follow element_spec of the underlying dataset
900    # (whose batch dimension may also be None). This is because with partial
901    # batch handling we could always produce empty batches.
902    if (self._enable_get_next_as_optional and
903        self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
904      return nest.map_structure(
905          _rebatch_as_dynamic, self._element_spec, expand_composites=False)
906    return self._element_spec
907
908  @property
909  def _type_spec(self):
910    # Note that we use actual element_spec instead of the rebatched-as-dynamic
911    # one to create DistributedIteratorSpec, to be consistent with the
912    # underlying iterators' specs.
913    return DistributedIteratorSpec(self._input_workers, self._element_spec,
914                                   self._strategy,
915                                   self._options,
916                                   self._cardinality,
917                                   self._enable_get_next_as_optional)
918
919
920class _IterableInput(DistributedDatasetInterface):
921  """Base class for iterable inputs for distribution strategies."""
922
923  # pylint: disable=super-init-not-called
924  def __init__(self, input_workers):
925    assert isinstance(input_workers, InputWorkers)
926    self._input_workers = input_workers
927
928  def __iter__(self):
929    raise NotImplementedError("must be implemented in descendants")
930
931  def reduce(self, initial_state, reduce_fn):
932    """Execute a `reduce_fn` over all the elements of the input."""
933    iterator = iter(self)
934    optional_data = iterator.get_next_as_optional()
935
936    def cond(optional_data, state):
937      del state  # Unused.
938      return optional_data.has_value()
939
940    def loop_body(optional_data, state):
941      """Executes `reduce_fn` in a loop till the dataset is empty."""
942      state = reduce_fn(state, optional_data.get_value())
943      optional_data = iterator.get_next_as_optional()
944      return optional_data, state
945
946    optional_data, final_state = control_flow_ops.while_loop(
947        cond,
948        loop_body, [optional_data, initial_state],
949        parallel_iterations=1,
950        return_same_structure=True)
951    return final_state
952
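# The tf.while_loop in _IterableInput.reduce above is equivalent to this
# eager-style sketch (illustrative pseudocode):
#
#   state = initial_state
#   opt = iterator.get_next_as_optional()
#   while opt.has_value():
#     state = reduce_fn(state, opt.get_value())
#     opt = iterator.get_next_as_optional()
#   return state
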
953
954class DistributedDatasetSpec(DistributedDatasetAndIteratorSpec):
955  """Type specification for `DistributedDataset."""
956
957  @property
958  def value_type(self):
959    return DistributedDataset
960
961  @property
962  def _component_specs(self):
963    specs = []
964    worker_device_pairs = self._input_workers._worker_device_pairs  # pylint: disable=protected-access
965
966    for i, _ in enumerate(worker_device_pairs):
967      element_spec = nest.map_structure(
968          functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
969      specs.append(dataset_ops.DatasetSpec(element_spec))
970    return specs
971
972  def _to_components(self, value):
973    return value._cloned_datasets  # pylint: disable=protected-access
974
975  def _from_components(self, components):
976    return DistributedDataset(
977        input_workers=self._input_workers,
978        strategy=self._strategy,
979        components=components,
980        element_spec=self._element_spec,
981        enable_get_next_as_optional=self._enable_get_next_as_optional,
982        options=self._options)
983
984  @staticmethod
985  def from_value(value):
986    # pylint: disable=protected-access
987    return DistributedDatasetSpec(
988        value._input_workers,
989        value._element_spec,
990        value._strategy,
991        value._options,
992        enable_get_next_as_optional=value._enable_get_next_as_optional)
993    # pylint: enable=protected-access
994
995
996class DistributedDataset(_IterableInput, composite_tensor.CompositeTensor):
997  """Distributed dataset that supports prefetching to multiple devices."""
998
999  def __init__(self,
1000               input_workers,
1001               strategy,
1002               dataset=None,
1003               num_replicas_in_sync=None,
1004               input_context=None,
1005               components=None,
1006               element_spec=None,
1007               enable_get_next_as_optional=None,
1008               build=True,
1009               options=None):
1010    """Distribute the dataset on all workers.
1011
1012    If `num_replicas_in_sync` is not None, we split each batch of the dataset
1013    into `num_replicas_in_sync` smaller batches, to be distributed among that
1014    worker's replicas, so that the batch size for a global step (across all
1015    workers and replicas) is as expected.
1016
1017    Args:
1018      input_workers: an `InputWorkers` object.
1019      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
1020        handle last partial batch.
1021      dataset: `tf.data.Dataset` that will be used as the input source. Either
1022        dataset or components field should be passed when constructing
1023        DistributedDataset. Use this when constructing DistributedDataset from a
1024        new `tf.data.Dataset`. Use components when constructing using
1025        DistributedDatasetSpec.
1026      num_replicas_in_sync: Optional integer. If this is not None, the value
1027        is used to decide how to rebatch datasets into smaller batches so that
1028        the total batch size for each step (across all workers and replicas)
1029        adds up to `dataset`'s batch size.
1030      input_context: `InputContext` for sharding. Only pass this in for between
1031        graph multi-worker cases where there is only one `input_worker`. In
1032        these cases, we will shard based on the `input_pipeline_id` and
1033        `num_input_pipelines` in the `InputContext`.
1034      components: datasets when DistributedDataset is constructed from
1035        DistributedDatasetSpec. Either the dataset or the components field
1036        should be passed.
1037      element_spec: element spec for DistributedDataset when constructing from
1038        DistributedDatasetSpec. This will be used to set the element_spec for
1039        DistributedDataset and verified against element_spec from components.
1040      enable_get_next_as_optional: this is required when components is passed
1041        instead of dataset.
1042      build: whether to build underlying datasets when this object is created.
1043        This is only useful for `ParameterServerStrategy` now.
1044      options: `tf.distribute.InputOptions` used to control options on how this
1045        dataset is distributed.
1046    """
1047    super(DistributedDataset, self).__init__(input_workers=input_workers)
1048    if input_workers is None or strategy is None:
1049      raise ValueError("input_workers and strategy are required arguments")
1050    if dataset is not None and components is not None:
1051      raise ValueError("Only one of dataset or components should be present")
1052    if dataset is None and components is None:
1053      raise ValueError("At least one of dataset or components should be passed")
1054
1055    self._input_workers = input_workers
1056    self._strategy = strategy
1057    self._options = options
1058    self._input_context = input_context
1059    self._num_replicas_in_sync = num_replicas_in_sync
1060
1061    if dataset is not None:
1062      self._original_dataset = dataset
1063      self._built = False
1064      if build:
1065        self.build()
1066    else:
1067      if not build:
1068        raise ValueError(
1069            "When constructing DistributedDataset with components, build "
1070            "should not be False. This is an internal error. Please file a "
1071            "bug.")
1072      if enable_get_next_as_optional is None:
1073        raise ValueError(
1074            "When constructing DistributedDataset with components, " +
1075            "enable_get_next_as_optional should also be passed")
1076      self._cloned_datasets = components
1077      self._cardinality = _cardinality(self._cloned_datasets[0])
1078      self._enable_get_next_as_optional = enable_get_next_as_optional
1079
1080      assert element_spec is not None
1081      if element_spec != _create_distributed_tensor_spec(
1082          self._strategy, self._cloned_datasets[0].element_spec):
1083        raise ValueError("Mismatched element_spec from the passed components")
1084      self._element_spec = element_spec
1085
1086      self._built = True
1087
1088  def build(self, dataset_to_replace=None):
1089    assert not self._built
1090    dataset = dataset_to_replace or self._original_dataset
1091    self._cardinality = _cardinality(dataset)
1092    self._enable_get_next_as_optional = _enable_get_next_as_optional(
1093        self._strategy, dataset, self._cardinality)
1094    distribute_start_time_ns = time.time_ns()
1095    self._create_cloned_datasets_from_dataset(dataset, self._input_context,
1096                                              self._input_workers,
1097                                              self._strategy,
1098                                              self._num_replicas_in_sync)
1099    if context.executing_eagerly():
1100      # Records the time to initialize the distributed dataset.
1101      context.async_wait()
1102      distribute_duration_ms = (time.time_ns() -
1103                                distribute_start_time_ns) // 1_000_000
1104      _distributed_dataset_initialization_time_milliseconds.get_cell(
1105          self._strategy.__class__.__name__,
1106          str(self._input_workers.num_workers)).add(distribute_duration_ms)
1107    self._element_spec = _create_distributed_tensor_spec(
1108        self._strategy, self._cloned_datasets[0].element_spec)
1109    self._built = True
1110
1111  @property
1112  def cardinality(self):
1113    if not self._built:
1114      raise ValueError(
1115          "Cannot get the cardinality of a dataset that is not built")
1116    return self._cardinality
1117
1118  def _create_cloned_datasets_from_dataset(self, dataset, input_context,
1119                                           input_workers, strategy,
1120                                           num_replicas_in_sync):
1121    # We clone and shard the dataset on each worker. The current setup tries to
1122    # shard the dataset by files if possible so that each worker sees a
1123    # different subset of files. If that is not possible, it will attempt to
1124    # shard the final input such that each worker will run the entire
1125    # preprocessing pipeline and only receive its own shard of the dataset.
1126
1127    # Additionally, we rebatch the dataset on each worker into
1128    # `num_replicas_in_sync` smaller batches to be distributed among that
1129    # worker's replicas, so that the batch size for a global step (across all
1130    # workers and replicas) adds up to the original dataset's batch size.
1131    if num_replicas_in_sync is not None:
1132      num_workers = input_context.num_input_pipelines if input_context else len(
1133          input_workers.worker_devices)
1134      rebatch_fn = self._make_rebatch_fn(dataset, num_workers,
1135                                         num_replicas_in_sync)
1136    else:
1137      rebatch_fn = None
1138    self._cloned_datasets = []
1139    if input_context:
1140      # Between-graph where we rely on the input_context for sharding
1141      assert input_workers.num_workers == 1
1142      if rebatch_fn is not None:
1143        dataset = rebatch_fn(dataset, input_context.input_pipeline_id)
1144      dataset = input_ops.auto_shard_dataset(dataset,
1145                                             input_context.num_input_pipelines,
1146                                             input_context.input_pipeline_id,
1147                                             num_replicas_in_sync)
1148      self._cloned_datasets.append(dataset)
1149    else:
1150      replicated_ds = distribute.replicate(dataset,
1151                                           input_workers.worker_devices)
1152      for i, worker in enumerate(input_workers.worker_devices):
1153        with ops.device(worker):
1154          cloned_dataset = replicated_ds[worker]
1155          if rebatch_fn is not None:
1156            cloned_dataset = rebatch_fn(cloned_dataset, i)
1157          cloned_dataset = input_ops.auto_shard_dataset(
1158              cloned_dataset, len(input_workers.worker_devices), i,
1159              num_replicas_in_sync)
1160          self._cloned_datasets.append(cloned_dataset)
1161
1162  def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync):
1163    """Returns a callable that rebatches the input dataset.
1164
1165    Args:
1166      dataset: A `tf.data.Dataset` representing the dataset to be distributed.
1167      num_workers: An integer representing the number of workers to distribute
1168        `dataset` among.
1169      num_replicas_in_sync: An integer representing the number of replicas in
1170        sync across all workers.
1171    """
1172    if num_replicas_in_sync % num_workers:
1173      raise ValueError(
1174          "tf.distribute expects every worker to have the same number of "
1175          "replicas. However, encountered `num_replicas_in_sync` ({}) that "
1176          "cannot be divided by `num_workers` ({})".format(
1177              num_replicas_in_sync, num_workers))
1178
1179    num_replicas_per_worker = num_replicas_in_sync // num_workers
1180    with ops.colocate_with(dataset._variant_tensor):  # pylint: disable=protected-access
1181      batch_size = distribute.compute_batch_size(dataset)
1182
1183    def rebatch_fn(dataset, worker_index):
1184      try:
1185        # pylint: disable=protected-access
1186        def apply_rebatch():
1187          batch_sizes = distribute.batch_sizes_for_worker(
1188              batch_size, num_workers, num_replicas_per_worker, worker_index)
1189          return distribute._RebatchDataset(
1190              dataset, batch_sizes).prefetch(num_replicas_per_worker)
1191
1192        def apply_legacy_rebatch():
1193          return distribute._LegacyRebatchDataset(
1194              dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker)
1195
1196        with ops.colocate_with(dataset._variant_tensor):
1197          return control_flow_ops.cond(
1198              math_ops.not_equal(batch_size, -1),
1199              true_fn=apply_rebatch,
1200              false_fn=apply_legacy_rebatch)
1201      except errors.InvalidArgumentError as e:
1202        if "without encountering a batch" in str(e):
1203          six.reraise(
1204              ValueError,
1205              ValueError(
1206                  "Call the `batch` method on the input Dataset in order to be "
1207                  "able to split your input across {} replicas.\n Please see "
1208                  "the tf.distribute.Strategy guide. {}".format(
1209                      num_replicas_in_sync, e)),
1210              sys.exc_info()[2])
1211        else:
1212          raise
1213
1214    return rebatch_fn
1215
1216  def __iter__(self):
1217    if not (context.executing_eagerly() or
1218            ops.get_default_graph().building_function):
1219      raise RuntimeError("__iter__() is only supported inside of tf.function "
1220                         "or when eager execution is enabled.")
1221    if not self._built:
1222      raise ValueError("To use this dataset, you need to pass this dataset to "
1223                       "ClusterCoordinator.create_per_worker_dataset.")
1224
1225    canonicalize_devices = getattr(self._strategy, "_canonicalize_devices",
1226                                   True)
1227
1228    worker_iterators = _create_iterators_per_worker(
1229        self._cloned_datasets,
1230        self._input_workers,
1231        options=self._options,
1232        canonicalize_devices=canonicalize_devices)
1233    iterator = DistributedIterator(
1234        self._input_workers,
1235        worker_iterators,
1236        self._strategy,
1237        cardinality=self._cardinality,
1238        enable_get_next_as_optional=self._enable_get_next_as_optional,
1239        options=self._options)
1240    iterator._element_spec = self._element_spec  # pylint: disable=protected-access
1241
1242    # When async eager is enabled, sometimes the iterator may not finish
1243    # initialization before being passed to a multi-device function. Add a sync
1244    # point here to make sure all underlying iterators are initialized.
1245    if context.executing_eagerly():
1246      context.async_wait()
1247
1248    return iterator
1249
1250  @property
1251  def element_spec(self):
1252    """The type specification of an element of this dataset."""
1253    # When partial batch handling is enabled, always set the batch dimension to
1254    # None, otherwise we just follow element_spec of the underlying dataset
1255    # (whose batch dimension may also be None). This is because with partial
1256    # batch handling we could always produce empty batches.
1257    if (self._enable_get_next_as_optional and
1258        self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
1259      return nest.map_structure(
1260          _rebatch_as_dynamic, self._element_spec, expand_composites=False)
1261    return self._element_spec
1262
1263  @property
1264  def _type_spec(self):
1265    return DistributedDatasetSpec(
1266        self._input_workers,
1267        self._element_spec,
1268        self._strategy,
1269        self._options,
1270        enable_get_next_as_optional=self._enable_get_next_as_optional)
1271
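# A worked example of the rebatching performed by
# DistributedDataset._make_rebatch_fn above (numbers are illustrative):
#
#   global batch size        = 16
#   num_workers              = 2
#   num_replicas_in_sync     = 4
#   num_replicas_per_worker  = 4 // 2 = 2
#   per-replica batch size   = 16 // 4 = 4
#   each worker therefore processes 2 * 4 = 8 elements per global step.
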
1272
1273class DistributedDatasetsFromFunctionSpec(DistributedDatasetAndIteratorSpec):
1274  """Type specification for `DistributedDatasetsFromFunction."""
1275
1276  @property
1277  def value_type(self):
1278    return DistributedDatasetsFromFunction
1279
1280  @property
1281  def _component_specs(self):
1282    specs = []
1283    worker_device_pairs = self._input_workers._worker_device_pairs  # pylint: disable=protected-access
1284
1285    for i, _ in enumerate(worker_device_pairs):
1286      element_spec = nest.map_structure(
1287          functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
1288      specs.append(dataset_ops.DatasetSpec(element_spec))
1289    return specs
1290
1291  def _to_components(self, value):
1292    return value._datasets  # pylint: disable=protected-access
1293
1294  def _from_components(self, components):
1295    return DistributedDatasetsFromFunction(
1296        input_workers=self._input_workers,
1297        strategy=self._strategy,
1298        components=components,
1299        element_spec=self._element_spec,
1300        options=self._options)
1301
1302  @staticmethod
1303  def from_value(value):
1304    # pylint: disable=protected-access
1305    return DistributedDatasetsFromFunctionSpec(
1306        input_workers=value._input_workers,
1307        element_spec=value._element_spec,
1308        strategy=value._strategy,
1309        options=value._options)
1310
1311
1312# TODO(priyag): Add other replication modes.
1313class DistributedDatasetsFromFunction(_IterableInput,
1314                                      composite_tensor.CompositeTensor):
1315  """Inputs created from dataset function."""
1316
1317  def __init__(self,
1318               input_workers,
1319               strategy,
1320               input_contexts=None,
1321               dataset_fn=None,
1322               options=None,
1323               components=None,
1324               element_spec=None,
1325               build=True):
1326    """Makes an iterable from datasets created by the given function.
1327
1328    Args:
1329      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle the last partial batch.
1332      input_contexts: A list of `InputContext` instances to be passed to call(s)
1333        to `dataset_fn`. Length and order should match worker order in
1334        `worker_device_pairs`.
      dataset_fn: A function that returns a `Dataset` given an `InputContext`.
        Either dataset_fn or components should be passed to construct
        DistributedDatasetsFromFunction. Use dataset_fn when constructing
        DistributedDatasetsFromFunction from a function; use components when
        constructing it from a DistributedDatasetsFromFunctionSpec (see the
        example below).
1340      options: `tf.distribute.InputOptions` used to control options on how this
1341        dataset is distributed.
      components: datasets used when DistributedDatasetsFromFunction is
        constructed from DistributedDatasetsFromFunctionSpec. Only one of
        dataset_fn or components should be passed.
      element_spec: element spec of the DistributedDatasetsFromFunction when
        constructing from DistributedDatasetsFromFunctionSpec. This will be
        used to set the element_spec for DistributedDatasetsFromFunctionSpec
        and verified against the element_spec from components.
1349      build: whether to build underlying datasets when this object is created.
1350        This is only useful for `ParameterServerStrategy` now.
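
    A rough usage sketch (illustrative only; users normally reach this class
    through `tf.distribute.Strategy.distribute_datasets_from_function` rather
    than constructing it directly):

      strategy = tf.distribute.MirroredStrategy()

      def dataset_fn(input_context):
        batch_size = input_context.get_per_replica_batch_size(64)
        dataset = tf.data.Dataset.range(1000).batch(batch_size)
        return dataset.shard(input_context.num_input_pipelines,
                             input_context.input_pipeline_id)

      dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
      for batch in dist_dataset:
        strategy.run(lambda x: x, args=(batch,))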
1351    """
1352    super(DistributedDatasetsFromFunction, self).__init__(
1353        input_workers=input_workers)
1354    self._input_workers = input_workers
1355    self._strategy = strategy
1356    self._options = options
1357    if dataset_fn is not None and components is not None:
1358      raise ValueError("Only one of dataset_fn or components should be set")
1359    if dataset_fn is None and components is None:
1360      raise ValueError("At least one of dataset_fn or components should be set")
1361
1362    if dataset_fn is not None:
1363      if input_workers.num_workers != len(input_contexts):
1364        raise ValueError(
1365            "Number of input workers (%d) is not same as number of "
1366            "input_contexts (%d)" %
1367            (input_workers.num_workers, len(input_contexts)))
1368      self._input_contexts = input_contexts
1369      self._dataset_fn = dataset_fn
1370      self._built = False
1371      if build:
1372        self.build()
1373    else:
1374      if element_spec is None:
1375        raise ValueError(
1376            "element_spec should also be passed when passing components")
1377      if not build:
1378        raise ValueError(
1379            "When constructing DistributedDatasetFromFunction with components, "
1380            "build should not be False. This is an internal error. Please file "
1381            "a bug.")
1382      self._element_spec = element_spec
1383      self._datasets = components
1384      self._built = True
1385      self._cardinality = _cardinality(self._datasets[0])
1386      self._enable_get_next_as_optional = _enable_get_next_as_optional(
1387          self._strategy, self._datasets[0], self._cardinality)
1388
1389  def build(self):
1390    assert not self._built
1391    distribute_start_time_ns = time.time_ns()
1392    self._datasets, element_spec = (
1393        _create_datasets_from_function_with_input_context(
1394            self._input_contexts, self._input_workers, self._dataset_fn))
1395    if context.executing_eagerly():
1396      # Records the time to initialize the distributed dataset.
1397      context.async_wait()
1398      distribute_duration_ms = (time.time_ns() -
1399                                distribute_start_time_ns) // 1_000_000
1400      _distributed_dataset_from_function_initialization_time_milliseconds.get_cell(
1401          self._strategy.__class__.__name__,
1402          str(self._input_workers.num_workers)).add(distribute_duration_ms)
1403
1404    self._element_spec = _create_distributed_tensor_spec(
1405        self._strategy, element_spec)
1406    self._cardinality = _cardinality(self._datasets[0])
1407    self._enable_get_next_as_optional = _enable_get_next_as_optional(
1408        self._strategy, self._datasets[0], self._cardinality)
1409    self._built = True
1410
1411  @property
1412  def cardinality(self):
1413    if not self._built:
1414      raise ValueError(
1415          "Cannot get the cardinality of a dataset that is not built")
1416    return self._cardinality
1417
1418  def __iter__(self):
1419    if not (ops.executing_eagerly_outside_functions() or
1420            ops.get_default_graph().building_function):
1421      raise RuntimeError("__iter__() is only supported inside of tf.function "
1422                         "or when eager execution is enabled.")
1423
1424    if not self._built:
1425      raise ValueError("You need to use this dataset in "
1426                       "ClusterCoordinator.create_per_worker_dataset.")
1427
1428    canonicalize_devices = getattr(self._strategy, "_canonicalize_devices",
1429                                   True)
1430
1431    iterators = _create_iterators_per_worker(
1432        self._datasets,
1433        self._input_workers,
1434        options=self._options,
1435        canonicalize_devices=canonicalize_devices)
1436    iterator = DistributedIterator(
1437        input_workers=self._input_workers,
1438        iterators=iterators,
1439        strategy=self._strategy,
1440        cardinality=self._cardinality,
1441        enable_get_next_as_optional=self._enable_get_next_as_optional,
1442        options=self._options)
1443    iterator._element_spec = self._element_spec  # pylint: disable=protected-access
1444
    # When async eager is enabled, the iterator may not have finished
    # initializing before it is passed to a multi-device function, so add a
    # sync point here to make sure all underlying iterators are initialized.
1448    if context.executing_eagerly():
1449      context.async_wait()
1450
1451    return iterator
1452
1453  @property
1454  def element_spec(self):
1455    """The type specification of an element of this dataset."""
    # When partial batch handling is enabled, always set the batch dimension to
    # None; otherwise we just follow the element_spec of the underlying dataset
    # (whose batch dimension may also be None). This is because with partial
    # batch handling we could always produce empty batches.
1460    if (self._enable_get_next_as_optional and
1461        self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
1462      return nest.map_structure(
1463          _rebatch_as_dynamic, self._element_spec, expand_composites=False)
1464    return self._element_spec
1465
1466  @property
1467  def _type_spec(self):
1468    return DistributedDatasetsFromFunctionSpec(self._input_workers,
1469                                               self._element_spec,
1470                                               self._strategy, self._options)
1471
1472
1473def _dummy_tensor_fn(value_structure):
1474  """A function to create dummy tensors from `value_structure`."""
1475
1476  def create_dummy_tensor(spec):
1477    """Create a dummy tensor with possible batch dimensions set to 0."""
1478    if hasattr(spec, "_create_empty_value"):
      # A type spec may override the default dummy-value behavior by declaring
      # the `_create_empty_value(self)` method. This method must return a value
      # compatible with the type spec with batch dimensions set to 0, or fail
      # if such a value does not exist. This allows a composite tensor to
      # customize dummy value creation since, in general, its dummy value is
      # not composed from dummy components (e.g. the `row_splits` tensor of a
      # RaggedTensor is never allowed to be empty). See b/183969859 for more
      # discussion.
1486      # TODO(b/186079336): reconsider CompositeTensor support.
1487      return spec._create_empty_value()  # pylint: disable=protected-access
1488
1489    if isinstance(spec, ragged_tensor.RaggedTensorSpec):
1490      # Splice out the ragged dimensions.
1491      # pylint: disable=protected-access
1492      feature_shape = spec._shape[:1].concatenate(
1493          spec._shape[(1 + spec._ragged_rank):])
1494      feature_type = spec._dtype
1495      # pylint: enable=protected-access
1496    else:
1497      feature_shape = spec.shape
1498      feature_type = spec.dtype
    # Ideally we would set the batch dimension to 0; however, in
    # DistributionStrategy we don't know the batch dimension, so we guess it as
    # best we can. If the feature has unknown dimensions, we set them to 0. If
    # the feature shape is already static, we take the first dimension to be
    # the batch dimension and set it to 0.
1504    dims = ([dim if dim is not None else 0 for dim in feature_shape.as_list()]
1505            if feature_shape else [])
1506    if dims and (isinstance(spec, ragged_tensor.RaggedTensorSpec) or
1507                 feature_shape.is_fully_defined()):
1508      dims[0] = tensor_shape.Dimension(0)
1509
1510    if isinstance(spec, sparse_tensor.SparseTensorSpec):
1511      return sparse_tensor.SparseTensor(
1512          values=array_ops.zeros(0, feature_type),
1513          indices=array_ops.zeros((0, len(dims)), dtypes.int64),
1514          dense_shape=dims)
1515
1516    # Create the dummy tensor.
1517    dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
1518    if isinstance(spec, ragged_tensor.RaggedTensorSpec):
1519      # Reinsert the ragged dimensions with size 0.
1520      # pylint: disable=protected-access
1521      row_splits = array_ops.zeros(1, spec._row_splits_dtype)
1522      dummy_tensor = ragged_tensor.RaggedTensor.from_nested_row_splits(
1523          dummy_tensor, (row_splits,) * spec._ragged_rank, validate=False)
1524      # pylint: enable=protected-access
1525    return dummy_tensor
1526
1527  return nest.map_structure(create_dummy_tensor, value_structure)
1528
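# A minimal, commented-out sketch of the expected behavior of
# `_dummy_tensor_fn` (illustrative only; `spec` and `dummy` are hypothetical
# names):
#
#   spec = tf.TensorSpec(shape=[None, 8], dtype=tf.float32)
#   dummy = _dummy_tensor_fn(spec)
#   # `dummy` is equivalent to tf.zeros([0, 8], tf.float32): an empty batch
#   # that is still compatible with `spec`. Sparse and ragged specs get
#   # analogous empty values built from their component tensors.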
1529
1530def _get_value_or_dummy(input_workers, optional_list, produce_dummy):
1531  """Returns the value of the optionals or dummy values.
1532
1533  Args:
1534    input_workers: the `InputWorkers`.
1535    optional_list: a list of lists `tf.experimental.Optional`. The values from
1536      each compute device grouped by the input device.
1537    produce_dummy: a bool. Whether to produce dummy tensors when the optional
1538      doesn't have a value.
1539
1540  Returns:
1541    A flatten list of Tensors.
1542
1543  """
1544  value_list = []
1545  for i, worker in enumerate(input_workers.worker_devices):
1546    with ops.device(worker):
1547      devices = input_workers.compute_devices_for_worker(i)
1548      for j, device in enumerate(devices):
1549        with ops.device(device):
1550          if produce_dummy:
1551            # pylint: disable=cell-var-from-loop
1552            value_list.append(
1553                control_flow_ops.cond(
1554                    optional_list[i][j].has_value(),
1555                    lambda: optional_list[i][j].get_value(),  # pylint: disable=unnecessary-lambda
1556                    lambda: _dummy_tensor_fn(optional_list[i][j].element_spec),
1557                    strict=True,
1558                ))
1559            # pylint: enable=cell-var-from-loop
1560          else:
1561            value_list.append(optional_list[i][j].get_value())
1562  return value_list
1563
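# Shape sketch (illustrative): `optional_list[i][j]` holds the optional from
# worker i's j-th compute device, and the returned value_list flattens them in
# worker-major order. For example, with 2 workers and 2 devices each:
#
#   value_list == [w0_d0, w0_d1, w1_d0, w1_d1]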
1564
1565class _SingleWorkerDatasetIteratorBase(object):
1566  """Iterator for a single `tf.data.Dataset`."""
1567
1568  def __init__(self, dataset, worker, devices, options=None):
1569    """Create iterator for the `dataset` to fetch data to worker's `devices` .
1570
1571    A `MultiDeviceIterator`  or `OwnedMultiDeviceIterator` is used to prefetch
1572    input to the devices on the given worker.
1573
1574    Args:
1575      dataset: A `tf.data.Dataset` instance.
1576      worker: Worker on which ops should be created.
1577      devices: Distribute data from `dataset` to these devices.
1578      options: options.
1579    """
1580    self._dataset = dataset
1581    self._worker = worker
1582    self._devices = devices
1583    self._element_spec = dataset.element_spec
1584    self._options = options
1585    self._make_iterator()
1586
1587  def _make_iterator(self):
1588    raise NotImplementedError("must be implemented in descendants")
1589
1590  def _format_data_list_with_options(self, data_list):
1591    """Change the data in to a list type if required.
1592
1593    The OwnedMultiDeviceIterator returns the list data type,
1594    while the PER_REPLICA iterator (when used with prefetch disabled)
1595    returns without the enclosed list. This is to fix the inconsistency.
1596    Args:
1597      data_list: data_list
1598    Returns:
1599      list
1600    """
1601    if (self._options and self._options.experimental_replication_mode ==
1602        InputReplicationMode.PER_REPLICA and
1603        not self._options.experimental_fetch_to_device):
1604      return [data_list]
1605    else:
1606      return data_list
1607
1608  def get_next(self, device, name=None):
1609    """Get next element for the given device."""
1610    del name
1611    with ops.device(self._worker):
1612      if _should_use_multi_device_iterator(self._options):
1613        return self._iterator.get_next(device)
1614      else:
1615        return self._iterator.get_next()
1616
1617  def get_next_as_list(self, name=None):
1618    """Get next element from the underlying iterator.
1619
1620    Runs the iterator get_next() within a device scope. Since this doesn't use
1621    get_next_as_optional(), it is considerably faster than get_next_as_list(),
1622    but it raises EOFError if any of the device doesn't get any data.
1623
1624    Args:
1625      name: not used.
1626
1627    Returns:
1628      A list consisting of the next data from each device.
1629    """
1630    del name
1631    with ops.device(self._worker):
1632      return self._format_data_list_with_options(self._iterator.get_next())
1633
1634  def get_next_as_optional_list(self):
1635    with ops.device(self._worker):
1636      return self._format_data_list_with_options(
1637          self._iterator.get_next_as_optional())
1638
1639
1640class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
1641  """Type specification for `_SingleWorkerOwnedDatasetIterator`."""
1642
1643  __slots__ = [
1644      "_worker", "_devices", "_element_spec", "_options",
1645      "_canonicalize_devices"
1646  ]
1647
1648  def __init__(self, worker, devices, element_spec, options,
1649               canonicalize_devices=True):
1650    self._worker = worker
1651    if canonicalize_devices:
1652      self._devices = tuple(device_util.canonicalize(d) for d in devices)
1653    else:
1654      self._devices = tuple(
1655          device_util.canonicalize_without_job_and_task(d) for d in devices)
1656    self._element_spec = element_spec
    # `self._options` is intentionally made non-`None` for proper serialization.
1658    self._options = (options if options is not None else
1659                     distribute_lib.InputOptions())
1660    self._canonicalize_devices = canonicalize_devices
1661
1662  @property
1663  def value_type(self):
1664    return _SingleWorkerOwnedDatasetIterator
1665
1666  def _serialize(self):
1667    return (self._worker, self._devices, self._element_spec, self._options,
1668            self._canonicalize_devices)
1669
1670  def _get_multi_device_iterator_spec(self, specs):
1671    device_scope = device_util.canonicalize(self._worker, device_util.current())
1672    host_device = device_util.get_host_for_device(device_scope)
    # The source_device used when creating the iterator governs the worker
    # device in the iterator spec.
1675    worker = host_device
1676    specs.append(
1677        multi_device_iterator_ops.MultiDeviceIteratorSpec(
1678            self._devices, worker, element_spec=self._element_spec))
1679
1680  @property
1681  def _component_specs(self):
1682    specs = []
1683    if _should_use_multi_device_iterator(self._options):
1684      self._get_multi_device_iterator_spec(specs)
1685    else:
1686      specs.append(iterator_ops.IteratorSpec(element_spec=self._element_spec))
1687    return specs
1688
1689  def _to_components(self, value):
1690    return [value._iterator]  # pylint: disable=protected-access
1691
1692  def _from_components(self, components):
1693    return _SingleWorkerOwnedDatasetIterator(
1694        dataset=None,
1695        worker=self._worker,
1696        devices=self._devices,
1697        components=components,
1698        element_spec=self._element_spec,
1699        options=self._options,
1700        canonicalize_devices=self._canonicalize_devices)
1701
1702  @staticmethod
1703  def from_value(value):
1704    # pylint: disable=protected-access
1705    return _SingleWorkerDatasetIteratorSpec(value._worker, value._devices,
1706                                            value._element_spec, value._options,
1707                                            value._canonicalize_devices)
1708
1709
1710class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
1711                                        composite_tensor.CompositeTensor):
1712  """Iterator for a DistributedDataset instance."""
1713
1714  def __init__(self,
1715               dataset=None,
1716               worker=None,
1717               devices=None,
1718               components=None,
1719               element_spec=None,
1720               options=None,
1721               canonicalize_devices=None):
1722    """Create iterator for the `dataset` to fetch data to worker's `devices` .
1723
1724    `OwnedMultiDeviceIterator` is used to prefetch input to the devices on the
1725    given worker. The lifetime of this iterator is tied to the encompassing
1726    python object. Once we go out of scope of the python object or return from
1727    a tf.function the underlying iterator resource is deleted.
1728
1729    Args:
1730      dataset: A `tf.data.Dataset` instance.
1731      worker: Worker on which ops should be created.
1732      devices: Distribute data from `dataset` to these devices.
1733      components: Tensor components to construct the
1734        _SingleWorkerOwnedDatasetIterator from.
1735      element_spec: A nested structure of `TypeSpec` objects that represents the
1736      type specification of elements of the iterator.
1737      options: `tf.distribute.InputOptions` used to control options on how this
1738      dataset is distributed.
1739      canonicalize_devices: Whether to canonicalize devices for workers fully or
1740      partially. If False, it will partially canonicalize devices by removing
1741      job and task.
1742    """
1743    if worker is None or devices is None:
1744      raise ValueError("Both `worker` and `devices` should be provided")
1745
1746    error_message = ("Either `dataset` or both `components` and `element_spec` "
1747                     "need to be provided.")
1748
1749    self._options = options
1750    self._canonicalize_devices = canonicalize_devices
1751    if dataset is None:
1752      if (components is None or element_spec is None):
1753        raise ValueError(error_message)
1754      self._element_spec = element_spec
1755      self._worker = worker
1756      self._devices = devices
1757      self._iterator = components[0]
1758    else:
1759      if (components is not None or element_spec is not None):
1760        raise ValueError(error_message)
1761      super(_SingleWorkerOwnedDatasetIterator,
1762            self).__init__(dataset, worker, devices, self._options)
1763
1764  def _create_owned_multi_device_iterator(self):
1765    # If the worker devices are already canonicalized, canonicalizing again
1766    # would have no impact.
1767    # For strategies running on remote workers such as PS Strategy, the device
1768    # scope will be derived from current worker, if used under init_scope().
1769    device_scope = device_util.canonicalize(self._worker,
1770                                            device_util.current())
1771    host_device = device_util.get_host_for_device(device_scope)
1772    with ops.device(device_scope):
1773      if self._options is not None:
1774        self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
1775            self._dataset,
1776            self._devices,
1777            source_device=host_device,
1778            max_buffer_size=self._options
1779            .experimental_per_replica_buffer_size,
1780            prefetch_buffer_size=self._options
1781            .experimental_per_replica_buffer_size)
1782      else:
1783        self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
1784            self._dataset, self._devices, source_device=host_device)
1785
1786  def _make_iterator(self):
1787    """Make appropriate iterator on the dataset."""
1788    if not self._worker:
1789      raise ValueError("Worker device must be specified when creating an "
1790                       "owned iterator.")
1791    if _should_use_multi_device_iterator(self._options):
1792      self._create_owned_multi_device_iterator()
1793    else:
1794      with ops.device(self._worker):
1795        self._iterator = iter(self._dataset)
1796
1797  @property
1798  def element_spec(self):
1799    return self._element_spec
1800
1801  @property
1802  def _type_spec(self):
1803    return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices,
1804                                            self._element_spec, self._options,
1805                                            self._canonicalize_devices)
1806
1807  @property
1808  def output_classes(self):
1809    """Returns the class of each component of an element of this iterator.
1810
1811    The expected values are `tf.Tensor` and `tf.SparseTensor`.
1812
1813    Returns:
1814      A nested structure of Python `type` objects corresponding to each
1815      component of an element of this dataset.
1816    """
1817    return nest.map_structure(
1818        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
1819        self._element_spec)
1820
1821  @property
1822  def output_shapes(self):
1823    """Returns the shape of each component of an element of this iterator.
1824
1825    Returns:
1826      A nested structure of `tf.TensorShape` objects corresponding to each
1827      component of an element of this dataset.
1828    """
1829    return nest.map_structure(
1830        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
1831        self._element_spec)
1832
1833  @property
1834  def output_types(self):
1835    """Returns the type of each component of an element of this iterator.
1836
1837    Returns:
1838      A nested structure of `tf.DType` objects corresponding to each component
1839      of an element of this dataset.
1840    """
1841    return nest.map_structure(
1842        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
1843        self._element_spec)
1844
1845
1846def _create_iterators_per_worker(worker_datasets,
1847                                 input_workers,
1848                                 options=None,
1849                                 canonicalize_devices=False):
1850  """Create a multidevice iterator on each of the workers."""
1851  assert isinstance(input_workers, InputWorkers)
1852  assert len(worker_datasets) == len(input_workers.worker_devices)
1853  iterators = []
1854  for i, worker in enumerate(input_workers.worker_devices):
1855    with ops.device(worker):
1856      worker_devices = input_workers.compute_devices_for_worker(i)
1857      iterator = _SingleWorkerOwnedDatasetIterator(
1858          dataset=worker_datasets[i],
1859          worker=worker,
1860          devices=worker_devices,
1861          options=options,
1862          canonicalize_devices=canonicalize_devices)
1863      iterators.append(iterator)
1864  return iterators
1865
1866
1867def _create_datasets_from_function_with_input_context(input_contexts,
1868                                                      input_workers,
1869                                                      dataset_fn):
1870  """Create device datasets per worker given a dataset function."""
1871  datasets = []
1872  for i, ctx in enumerate(input_contexts):
1873    worker = input_workers.worker_devices[i]
1874    with ops.device(worker):
1875      dataset = dataset_fn(ctx)
1876      datasets.append(dataset)
1877  return datasets, dataset.element_spec
1878
1879
1880# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
1881def _get_batched_dataset(d):
1882  """Get the batched dataset from `d`."""
1883  # pylint: disable=protected-access
1884  if isinstance(d, dataset_ops.DatasetV1Adapter):
1885    d = d._dataset
1886
1887  if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
1888    return d
1889  elif isinstance(d, (dataset_ops.PrefetchDataset,
1890                      dataset_ops._OptionsDataset)):
1891    return _get_batched_dataset(d._input_dataset)
1892
  raise ValueError(
      "Unable to get batched dataset from the input dataset. `batch` or "
      "`map_and_batch` needs to be the last operation on the dataset. "
      "The batch operation can be followed by a prefetch.")
1897
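# A commented-out sketch of what `_get_batched_dataset` accepts (illustrative
# only; `f` is a placeholder map function):
#
#   _get_batched_dataset(tf.data.Dataset.range(8).batch(2).prefetch(1))  # ok
#   _get_batched_dataset(tf.data.Dataset.range(8).batch(2).map(f))  # ValueError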
1898
1899def _get_batched_dataset_attributes(d):
1900  """Get `batch_size`, `drop_remainder` of dataset."""
1901  # pylint: disable=protected-access
1902  assert isinstance(d,
1903                    (dataset_ops.BatchDataset, batching._MapAndBatchDataset))
1904  if isinstance(d, dataset_ops.BatchDataset):
1905    batch_size = d._batch_size
1906    drop_remainder = d._drop_remainder
1907  elif isinstance(d, batching._MapAndBatchDataset):
1908    batch_size = d._batch_size_t
1909    drop_remainder = d._drop_remainder_t
1910  # pylint: enable=protected-access
1911
1912  if tensor_util.is_tf_type(batch_size):
1913    batch_size = tensor_util.constant_value(batch_size)
1914
1915  if tensor_util.is_tf_type(drop_remainder):
1916    drop_remainder = tensor_util.constant_value(drop_remainder)
1917
1918  return batch_size, drop_remainder
1919
1920
1921# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
1922def _get_dataset_attributes(dataset):
1923  """Get the underlying attributes from the dataset object."""
1924  # pylint: disable=protected-access
1925
1926  # First, get batch_size and drop_remainder from the dataset. We need
1927  # to walk back the dataset creation process and find the batched version in
1928  # order to get the attributes.
1929  batched_dataset = _get_batched_dataset(dataset)
1930  batch_size, drop_remainder = _get_batched_dataset_attributes(batched_dataset)
1931
  # Second, the prefetch buffer size should be obtained from the original
  # dataset.
1933  prefetch_buffer = None
1934  if isinstance(dataset, dataset_ops.PrefetchDataset):
1935    prefetch_buffer = dataset._buffer_size
1936  elif (isinstance(dataset, dataset_ops.DatasetV1Adapter)
1937        and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
1938    prefetch_buffer = dataset._dataset._buffer_size
1939
1940  return batch_size, drop_remainder, prefetch_buffer
1941
1942
1943def _should_use_multi_device_iterator(options):
1944  """Determine whether to use multi_device_iterator_ops."""
1945  if (options is None or
1946      options.experimental_replication_mode == InputReplicationMode.PER_WORKER
1947      or
1948      (options.experimental_replication_mode == InputReplicationMode.PER_REPLICA
1949       and options.experimental_fetch_to_device)):
1950    return True
1951  return False
1952
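# A commented-out behavior sketch for `_should_use_multi_device_iterator`
# (illustrative only): the default options and PER_WORKER replication use a
# MultiDeviceIterator, while PER_REPLICA replication only does so when
# experimental_fetch_to_device is True.
#
#   opts = distribute_lib.InputOptions(
#       experimental_replication_mode=InputReplicationMode.PER_REPLICA,
#       experimental_fetch_to_device=False)
#   _should_use_multi_device_iterator(opts)  # -> False
#   _should_use_multi_device_iterator(None)  # -> True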
1953
1954class MultiStepContext(object):
1955  """A context object that can be used to capture things when running steps.
1956
  This context object is useful when running multiple steps at a time using the
  `experimental_run_steps_on_iterator` API. For example, it allows the user's
  step function to specify which outputs to emit at what frequency. Currently
  it supports capturing output from the last step, as well as capturing
  non-tensor outputs. In the future it may be extended to support other use
  cases, such as emitting output every N steps.
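
  For example, a step function run via `experimental_run_steps_on_iterator`
  might record a reduced loss from the last step (a rough sketch; `inputs` and
  `compute_loss` are placeholders):

    def step_fn(ctx, inputs):
      loss = compute_loss(inputs)
      ctx.set_last_step_output(
          "loss", loss, reduce_op=tf.distribute.ReduceOp.MEAN)
      return loss

  After the steps run, the reduced value is available as
  `ctx.last_step_outputs["loss"]`.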
1963  """
1964
1965  def __init__(self):
1966    """Initialize an output context.
1967
1968    Returns:
1969      A context object.
1970    """
1971    self._last_step_outputs = {}
1972    self._last_step_outputs_reduce_ops = {}
1973    self._non_tensor_outputs = {}
1974
1975  @property
1976  def last_step_outputs(self):
1977    """A dictionary consisting of outputs to be captured on last step.
1978
1979    Keys in the dictionary are names of tensors to be captured, as specified
1980    when `set_last_step_output` is called.
1981    Values in the dictionary are the tensors themselves. If
1982    `set_last_step_output` was called with a `reduce_op` for this output,
1983    then the value is the reduced value.
1984
1985    Returns:
1986      A dictionary with last step outputs.
1987    """
1988    return self._last_step_outputs
1989
1990  def _set_last_step_outputs(self, outputs):
1991    """Replace the entire dictionary of last step outputs."""
1992    if not isinstance(outputs, dict):
1993      raise ValueError("Need a dictionary to set last_step_outputs.")
1994    self._last_step_outputs = outputs
1995
1996  def set_last_step_output(self, name, output, reduce_op=None):
1997    """Set `output` with `name` to be outputted from the last step.
1998
1999    Args:
2000      name: String, name to identify the output. Doesn't need to match tensor
2001        name.
2002      output: The tensors that should be outputted with `name`. See below for
2003        actual types supported.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context; optional in a cross-replica context.
        When present, the outputs from all the replicas are reduced using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For example, if using MirroredStrategy and a reduction is set, `output`
        must be a `PerReplica` value.
        The reduce method is also recorded in the
        `_last_step_outputs_reduce_ops` dictionary so that it can later be
        determined whether the outputs were already reduced.
2015    """
2016    if distribution_strategy_context.in_cross_replica_context():
2017      self._last_step_outputs_reduce_ops[name] = reduce_op
2018      if reduce_op is None:
2019        self._last_step_outputs[name] = output
2020      else:
2021        distribution = distribution_strategy_context.get_strategy()
2022        self._last_step_outputs[name] = distribution.reduce(reduce_op, output,
2023                                                            axis=None)
2024    else:
2025      assert reduce_op is not None
2026      def merge_fn(distribution, value):
2027        self._last_step_outputs[name] = distribution.reduce(reduce_op, value,
2028                                                            axis=None)
2029        # Setting this inside the `merge_fn` because all replicas share the same
2030        # context object, so it's more robust to set it only once (even if all
2031        # the replicas are trying to set the same value).
2032        self._last_step_outputs_reduce_ops[name] = reduce_op
2033
2034      distribution_strategy_context.get_replica_context().merge_call(
2035          merge_fn, args=(output,))
2036
2037  @property
2038  def non_tensor_outputs(self):
2039    """A dictionary consisting of any non tensor outputs to be captured."""
2040    return self._non_tensor_outputs
2041
2042  def set_non_tensor_output(self, name, output):
2043    """Set `output` with `name` to be captured as a non tensor output."""
2044    if distribution_strategy_context.in_cross_replica_context():
2045      self._non_tensor_outputs[name] = output
2046    else:
2047      def merge_fn(distribution, value):
        # NOTE(priyag): For non-tensor outputs, we simply return all the values
        # in a list, as reduction doesn't make sense for non-tensors.
2050        self._non_tensor_outputs[name] = (
2051            distribution.experimental_local_results(value))
2052      distribution_strategy_context.get_replica_context().merge_call(
2053          merge_fn, args=(output,))
2054
2055
2056def _create_distributed_tensor_spec(strategy, tensor_spec):
2057  """Create a `tf.TypeSpec` for a given strategy and input `tensor_spec`.
2058
2059  Args:
2060    strategy: The given `tf.distribute` strategy.
2061    tensor_spec: `tf.TensorSpec` of a given value. The batch dimension of the
2062      shape should be None if you have partial batches.
2063
2064  Returns:
    A `tf.TypeSpec` that matches the values produced by a given strategy. This
    can be a `tf.TensorSpec` or a `PerReplicaSpec`.
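
  For example (illustrative), with a strategy that has two replicas and an
  input spec of `tf.TensorSpec([None, 4], tf.float32)`, the result is roughly
  `PerReplicaSpec(TensorSpec([None, 4]), TensorSpec([None, 4]))`.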
2067  """
2068  num_replicas = len(strategy.extended.worker_devices)
2069
  # For a one-device strategy that is not MultiWorkerMirroredStrategy, return
  # the tensor_spec as is, since we don't wrap the output in PerReplica in this
  # case.
2073  # TODO(b/166464552): remove after we always wrap for all strategies.
2074  if not _always_wrap(strategy):
2075    return tensor_spec
2076
2077  # For other cases we assume the input to tf.function is a per replica type.
2078  def _get_value_per_replica(tensor_spec_per_input):
2079    value_specs = [tensor_spec_per_input for _ in range(num_replicas)]
2080    return values.PerReplicaSpec(*value_specs)
2081
2082  return nest.map_structure(_get_value_per_replica, tensor_spec)
2083
2084
2085def _replace_per_replica_spec(spec, i):
2086  """If `spec` is a `PerReplicaSpec`, then return its `i`th value_spec."""
2087  if isinstance(spec, values.PerReplicaSpec):
2088    return spec._value_specs[i]  # pylint: disable=protected-access
2089  else:
2090    return spec
2091
2092
2093def _cardinality(dataset):
2094  """Returns the cardinality of the dataset."""
2095  if context.executing_eagerly():
2096    with ops.device(dataset._variant_tensor.device):  # pylint: disable=protected-access
2097      return dataset.cardinality().numpy()
2098  return cardinality_lib.UNKNOWN
2099
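# Illustrative: in eager mode `_cardinality(tf.data.Dataset.range(5))` returns
# 5, while during graph building the value cannot be evaluated, so the function
# conservatively returns cardinality_lib.UNKNOWN.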
2100
2101def _enable_get_next_as_optional(strategy, dataset, cardinality):
2102  """Returns whether to enable using partial batch handling."""
2103  # TODO(b/133073708): we currently need a flag to control the usage because
2104  # there is a performance difference between get_next() and
2105  # get_next_as_optional(). And we only enable get_next_as_optional when the
2106  # output shapes are not static.
2107  #
2108  # TODO(rxsang): We want to always enable the get_next_as_optional behavior
2109  # when user passed input_fn instead of dataset.
2110  if not getattr(
2111      strategy.extended, "enable_partial_batch_handling",
2112      getattr(strategy.extended, "experimental_enable_get_next_as_optional",
2113              False)):
2114    return False
2115
2116  # If the dataset is infinite, we don't need to enable last partial batch
2117  # support. Note that we can only evaluate the cardinality of the dataset in
  # eager mode.
2119  if cardinality == cardinality_lib.INFINITE:
2120    return False
2121
2122  return not _is_statically_shaped(
2123      dataset.element_spec) or strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access
2124
2125
2126def _create_per_replica(value_list, strategy):
2127  """Creates a PerReplica.
2128
2129  For strategies other than OneDeviceStrategy, it creates a PerReplica whose
2130  type spec is set to the element spec of the dataset. This helps avoid
  retracing for partial batches. Retracing is problematic for multi-client
  setups when different clients retrace at different times, since retracing
  changes the collective keys in the tf.function and causes mismatches among
  clients.
2134
2135  For single client strategies, this simply calls distribute_utils.regroup().
2136
2137  Args:
2138    value_list: a list of values, one for each replica.
2139    strategy: the `tf.distribute.Strategy`.
2140
2141  Returns:
2142    a structure of PerReplica.
2143
2144  """
2145  # TODO(b/166464552): always wrap for all one device strategies as well.
2146  always_wrap = _always_wrap(strategy)
2147  per_replicas = distribute_utils.regroup(value_list, always_wrap=always_wrap)
2148  return per_replicas
2149
2150
2151def _always_wrap(strategy):
2152  """Returns whether to always wrap the values in a DistributedValues."""
2153  return strategy.extended._in_multi_worker_mode() or len(  # pylint: disable=protected-access
2154      strategy.extended.worker_devices) > 1
2155
2156
2157def _rebatch_as_dynamic(per_replica_spec):
2158  """Rebatch the spec to have a dynamic batch dimension."""
2159  assert isinstance(per_replica_spec, values.PerReplicaSpec), per_replica_spec
2160
2161  # pylint: disable=protected-access
2162  def _rebatch(spec):
2163    # Rebatch if possible.
2164    try:
2165      return spec._unbatch()._batch(None)
2166    except ValueError:
2167      pass
2168    return spec
2169
2170  return values.PerReplicaSpec(
2171      *nest.map_structure(_rebatch, per_replica_spec._value_specs))
2172  # pylint: enable=protected-access
2173
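
# A commented-out sketch of the effect of `_rebatch_as_dynamic` (illustrative
# only):
#
#   spec = values.PerReplicaSpec(tf.TensorSpec([32, 4], tf.float32))
#   _rebatch_as_dynamic(spec)
#   # -> PerReplicaSpec(tf.TensorSpec([None, 4], tf.float32)); each component
#   #    spec's static batch dimension is made dynamic via
#   #    _unbatch()._batch(None).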