# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Distribution Strategy-related dataset transformations."""
import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops.options import ExternalStatePolicy
from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util.tf_export import tf_export

SHARD_HINT = -1
tf_export("data.experimental.SHARD_HINT").export_constant(
    __name__, "SHARD_HINT")
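# Illustrative note (usage as documented for `AutoShardPolicy.HINT`): a user
# pipeline may call `dataset.shard(SHARD_HINT, SHARD_HINT)` to mark where
# auto-sharding should place its shard; the placeholder is then replaced with
# the actual worker count and index during the rewrite.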


class _AutoShardDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that shards the `Dataset` automatically.

  This dataset takes in an existing dataset and tries to automatically figure
  out how to shard the dataset in a multi-worker scenario using graph rewrites.

  If the AutoShardPolicy is set to FILE, it walks up the dataset graph until
  it finds a reader dataset, then inserts a ShardDataset op before that node
  so that each worker only sees some files.

  If the AutoShardPolicy is set to DATA, it inserts a ShardDataset op at the
  end of the input pipeline, before any terminal PrefetchDataset if there is
  one. Additionally, if there is a RebatchDatasetV2 in the input pipeline, it
  is rewritten to a legacy RebatchDataset for correctness reasons, since
  RebatchDatasetV2 is incompatible with data sharding.

  If the AutoShardPolicy is set to AUTO, it tries to do file-based sharding.
  If it cannot find a reader dataset, it falls back to doing data-based
  sharding.

  If the AutoShardPolicy is set to OFF, it does nothing.

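  For illustration, a minimal sketch of how this wrapper is typically driven
  by a user-set policy (assumes an existing `dataset`; the worker count and
  index below are made-up values):

  ```python
  options = tf.data.Options()
  options.experimental_distribute.auto_shard_policy = (
      tf.data.experimental.AutoShardPolicy.DATA)
  dataset = dataset.with_options(options)
  # Worker 1 of 4 sees only its shard of the data.
  dataset = _AutoShardDataset(dataset, num_workers=4, index=1)
  ```
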
  Attributes:
    num_workers: Total number of workers to shard this dataset across.
    index: The current worker index (out of the total number of workers) this
      dataset is for.
    num_replicas: The total number of replicas across all workers. This is used
      only when sharding by data (either DATA or AUTO) in order to rewrite
      RebatchDatasetV2 to RebatchDataset.

  Raises:
    NotFoundError: If we cannot find a suitable reader dataset to begin
      automatically sharding the dataset.
  """

  def __init__(self, input_dataset, num_workers, index, num_replicas=None):
    self._input_dataset = input_dataset

    self._element_spec = input_dataset.element_spec
    variant_tensor = ged_ops.auto_shard_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_workers=num_workers,
        index=index,
        auto_shard_policy=int(
            input_dataset.options().experimental_distribute.auto_shard_policy),
        num_replicas=num_replicas,
        **self._flat_structure)
    super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


def _AutoShardDatasetV1(input_dataset, num_workers, index, num_replicas=None):  # pylint: disable=invalid-name
  return dataset_ops.DatasetV1Adapter(
      _AutoShardDataset(input_dataset, num_workers, index, num_replicas))


class _RebatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that rebatches elements from its input into new batch sizes.

  `_RebatchDataset(input_dataset, batch_sizes)` is functionally equivalent to
  `input_dataset.unbatch().batch(N)`, where the value of N cycles through the
  `batch_sizes` input list. The elements produced by this dataset have the same
  rank as the elements of the input dataset.

  For example:

  ```python
  ds = tf.data.Dataset.range(8)
  ds = ds.batch(4)
  ds = _RebatchDataset(ds, batch_sizes=[2, 1, 1])
  for elem in ds:
    print(elem)
  >> [0, 1], [2], [3], [4, 5], [6], [7]

  ds = tf.data.Dataset.range(16)
  ds = ds.batch(4)
  ds = _RebatchDataset(ds, batch_sizes=[6])
  for elem in ds:
    print(elem)
  >> [0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11], [12, 13, 14, 15]
  ```
  """

  def __init__(self, input_dataset, batch_sizes, drop_remainder=False):
    """Creates a _RebatchDataset.

    Args:
      input_dataset: `Dataset` to rebatch.
      batch_sizes: A `tf.int64` scalar or vector, representing the size of
        batches to produce. If this argument is a vector, these values are
        cycled through in order.
      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
        whether the last batch should be dropped in the case it has fewer than
        `batch_sizes[cycle_index]` elements; the default behavior is not to
        drop the smaller batch.
    """
    self._input_dataset = input_dataset
    self._batch_sizes = ops.convert_to_tensor(
        batch_sizes, dtype=dtypes.int64, name="batch_sizes")
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
    new_batch_dim = self._compute_static_batch_dim()

    # pylint: disable=protected-access
    self._element_spec = nest.map_structure(
        lambda ts: ts._unbatch()._batch(new_batch_dim),
        dataset_ops.get_structure(input_dataset))
    # pylint: enable=protected-access

    # auto_shard rewrite assumes that there's normalize_to_dense before
    # rebatch_dataset.
    # LINT.IfChange
    input_dataset = dataset_ops.normalize_to_dense(input_dataset)
    variant_tensor = ged_ops.rebatch_dataset_v2(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        batch_sizes=batch_sizes,
        drop_remainder=drop_remainder,
        **self._flat_structure)
    # LINT.ThenChange(//tensorflow/core/grappler/optimizers/data/auto_shard.cc)
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)

  def _compute_static_batch_dim(self):
    """Computes the static batch dimension of a dataset if it can be determined.

    Given the _RebatchDataset parameters, determines the batch dimension of this
    dataset statically. Returns None if this cannot be determined or is
    variable.

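    For example (illustrative): if the input dataset produces batches of 6
    elements and `batch_sizes` is `[2, 2, 2]`, the static batch dimension is 2;
    if `batch_sizes` is `[2, 3]`, the batch dimension varies and None is
    returned.
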
    Returns:
      An integer representing the batch dimension of the dataset. If it cannot
      be determined statically, returns None.

    Raises:
      ValueError: The `batch_sizes` parameter is malformed, `input_dataset` is
        not batched, or `input_dataset` batch sizes are incompatible with each
        other.
    """
    new_batch_dim = tensor_util.constant_value(self._batch_sizes)
    if new_batch_dim is None:
      return None

    if isinstance(new_batch_dim, np.ndarray):
      if len(new_batch_dim.shape) == 1:
        if np.all(new_batch_dim == new_batch_dim[0]):
          new_batch_dim = new_batch_dim[0]
        else:
          return None
      elif len(new_batch_dim.shape) > 1:
        raise ValueError(
            f"Invalid `batch_sizes`. Expected `batch_sizes` to be a scalar or "
            f"a vector. Received `batch_sizes` of rank "
            f"{len(new_batch_dim.shape)}.")

    if self._may_form_partial_batches(new_batch_dim):
      return None

    return new_batch_dim

  def _may_form_partial_batches(self, desired_batch_size):
    """Returns whether this dataset may form partial batches."""
    if tensor_util.constant_value(self._drop_remainder):
      return False

    def get_batch_dim(type_spec):
      try:
        shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
      except NotImplementedError:
        return None
      if not isinstance(shape, tensor_shape.TensorShape):
        return None
      if shape.rank is None:
        return None
      if len(shape) < 1:
        raise ValueError("Invalid `input_dataset`. Expected a dataset whose "
                         "elements have rank >= 1 but found a dataset with "
                         "scalar elements. Fix the issue by adding the `batch` "
                         "transformation to the dataset.")
      return shape.dims[0].value

    input_batch_dims = [
        get_batch_dim(ts)
        for ts in nest.flatten(dataset_ops.get_structure(self._input_dataset))
    ]
    known_input_batch_dims = [d for d in input_batch_dims if d is not None]

    if not known_input_batch_dims:
      return True

    known_input_batch_dims = np.asarray(known_input_batch_dims)
    if not np.all(known_input_batch_dims == known_input_batch_dims[0]):
      raise ValueError(
          f"Invalid `input_dataset`. Expected all components to have the "
          f"same batch dimension, but the batch dimensions are "
          f"{known_input_batch_dims}.")

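    # For example (illustrative): a known input batch dimension of 8 with a
    # desired batch size of 3 gives 8 % 3 != 0, so partial batches may form.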
    return known_input_batch_dims[0] % desired_batch_size != 0

  @property
  def element_spec(self):
    return self._element_spec


class _LegacyRebatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that divides its input batches into `num_replicas` sub-batches.

  For each batch in the input dataset, _LegacyRebatchDataset will produce
  `num_replicas` smaller batches whose sizes add up to the original batch size.

  For example:

  ```python
  ds = tf.data.Dataset.range(8)
  ds = ds.batch(4)
  ds = _LegacyRebatchDataset(ds, num_replicas=3)
  for elem in ds:
    print(elem)
  >> [0, 1], [2, 3], [], [4, 5], [6, 7], []
  ```
  """

  def __init__(self, input_dataset, num_replicas):
    """Creates a _LegacyRebatchDataset.

    Args:
      input_dataset: `Dataset` to rebatch.
      num_replicas: A `tf.int64` scalar, representing the number of sub-batches
        to split each batch from `input_dataset` into.
    """

    def recalculate_batch_size(type_spec):
      """Recalculates the output_shape after dividing it by num_replicas."""
      output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
      if not isinstance(output_shape, tensor_shape.TensorShape):
        return None

      # If the output shape is unknown, we set the batch dimension to unknown.
      if output_shape.rank is None:
        return None

      if len(output_shape) < 1:
        raise ValueError(
            "Invalid `input_dataset`. Expected a dataset whose elements "
            "have rank >= 1 but found a dataset whose elements are scalars. "
            "Fix the issue by adding the `batch` transformation to the "
            "dataset.")
      output_dims = [d.value for d in output_shape.dims]

      if output_dims[0] is not None and output_dims[0] % num_replicas == 0:
        return output_dims[0] // num_replicas

      # Set the batch dimension to unknown. If the global batch size is not
      # divisible by num_replicas, the minibatches may have different sizes.
      return None

    def rebatch(type_spec):
      # pylint: disable=protected-access
      batch_size = recalculate_batch_size(type_spec)
      return type_spec._unbatch()._batch(batch_size)
      # pylint: enable=protected-access

    self._element_spec = nest.map_structure(
        rebatch, dataset_ops.get_structure(input_dataset))

    # auto_shard rewrite assumes that there's normalize_to_dense before
    # rebatch_dataset.
    # LINT.IfChange
    input_dataset = dataset_ops.normalize_to_dense(input_dataset)
    variant_tensor = ged_ops.rebatch_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_replicas=num_replicas,
        **self._flat_structure)
    # LINT.ThenChange(//tensorflow/core/grappler/optimizers/data/auto_shard.cc)
    super(_LegacyRebatchDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


class _RemoteDataset(dataset_ops.DatasetSource):
  """Creates a dataset on a given `device` given a graph def."""

  def __init__(self, graph_def, device, element_spec):
    self._elem_spec = element_spec
    with ops.device(device):
      variant_tensor = ged_ops.dataset_from_graph(graph_def)
    super(_RemoteDataset, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._elem_spec


def replicate(dataset, devices):
  """A transformation that replicates `dataset` onto a list of devices.

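  For example, a minimal sketch (the device strings below are placeholders):

  ```python
  dataset = tf.data.Dataset.range(100)
  devices = ["/job:worker/replica:0/task:0/device:CPU:0",
             "/job:worker/replica:0/task:1/device:CPU:0"]
  replicated = replicate(dataset, devices)
  # Each worker then iterates over its own copy of the dataset.
  per_worker_dataset = replicated[devices[0]]
  ```
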
  Args:
    dataset: A `tf.data.Dataset` object.
    devices: A list of devices to replicate the dataset on.

  Returns:
    A dictionary mapping device name to a dataset on that device.
  """
  if not isinstance(dataset, dataset_ops.DatasetV2):
    raise TypeError(
        f"Invalid `dataset`. Expected a `tf.data.Dataset` object but "
        f"got {type(dataset)}.")

  # pylint: disable=protected-access
  dataset_device = dataset._variant_tensor.device

  datasets = {}
  if len(devices) == 1 and devices[0] == dataset_device:
    datasets[devices[0]] = dataset
    return datasets

  with ops.colocate_with(dataset._variant_tensor):
    dataset = dataset._apply_debug_options()
    graph_def = dataset._as_serialized_graph(
        strip_device_assignment=True,
        external_state_policy=ExternalStatePolicy.WARN)
  for device in devices:
    ds = _RemoteDataset(graph_def, device, dataset.element_spec)
    datasets[device] = ds
  return datasets


def batch_sizes_for_worker(global_batch_size, num_workers,
                           num_replicas_per_worker, worker_index):
  """Determines how to rebatch a dataset for the given worker.

  Given the global batch size, number of workers, number of replicas per
  worker, and worker index, returns the correct batch sizes for rebatching a
  dataset on worker `worker_index` of `num_workers`, such that each global step
  (across all workers and replicas) will consume `global_batch_size` elements.
  The returned value should be passed as the `batch_sizes` input parameter to
  `tf.data.experimental.rebatch()`. The returned batch sizes meet the following
  constraints:

  Let G = global_batch_size, W = num_workers, R = num_replicas_per_worker
  (A) for any worker, len(batch_sizes) == W * R
  (B) for any worker, sum(batch_sizes) == G
  (C) for any global step (i.e. R iterations on each worker), the sum of the
      batch sizes consumed by replicas across all workers is G.
  (D) any two batch sizes of any two replicas differ by at most one.

  For example, suppose we have G = 7, W = 2, R = 2, and suppose we have two
  files which each contain 7 elements:

  ```python
  # WORKER 0
  batch_sizes_0 = batch_sizes_for_worker(global_batch_size=7,
                                         num_workers=2,
                                         num_replicas_per_worker=2,
                                         worker_index=0)
  print(batch_sizes_0)
  >> [2, 2, 2, 1]

  dataset_0 = tf.data.Dataset.from_tensor_slices(["file_a", "file_b"])
  dataset_0 = dataset_0.shard(num_shards=2, index=0)
  dataset_0 = dataset_0.batch(7)
  dataset_0 = dataset_0.apply(tf.data.experimental.rebatch(batch_sizes_0))
  for elem in dataset_0:
    print(elem)
  >> [[A0, A1], [A2, A3], [A4, A5], [A6]]

  # WORKER 1
  batch_sizes_1 = batch_sizes_for_worker(global_batch_size=7,
                                         num_workers=2,
                                         num_replicas_per_worker=2,
                                         worker_index=1)
  print(batch_sizes_1)
  >> [2, 1, 2, 2]

  dataset_1 = tf.data.Dataset.from_tensor_slices(["file_a", "file_b"])
  dataset_1 = dataset_1.shard(num_shards=2, index=1)
  dataset_1 = dataset_1.batch(7)
  dataset_1 = dataset_1.apply(tf.data.experimental.rebatch(batch_sizes_1))
  for elem in dataset_1:
    print(elem)
  >> [[B0, B1], [B2], [B3, B4], [B5, B6]]
  ```

  The above example will produce the following elements:

  Step 1:
    Worker 0 Replica 0: [A0, A1]
    Worker 0 Replica 1: [A2, A3]
    Worker 1 Replica 0: [B0, B1]
    Worker 1 Replica 1: [B2]
  Total batch size = 7

  Step 2:
    Worker 0 Replica 0: [A4, A5]
    Worker 0 Replica 1: [A6]
    Worker 1 Replica 0: [B3, B4]
    Worker 1 Replica 1: [B5, B6]
  Total batch size = 7

  Args:
    global_batch_size: A `tf.int64` scalar, representing the global batch size.
    num_workers: An integer representing the number of workers the dataset will
      be distributed across.
    num_replicas_per_worker: An integer representing the number of replicas per
      worker. All workers are assumed to have the same number of replicas.
    worker_index: An integer index of the worker to be rebatched.

  Returns:
    A `tf.int64` vector, representing the batch sizes to rebatch the dataset
    into.
  """
  # Constraint (A)
  num_subbatches = num_workers * num_replicas_per_worker

  offset = worker_index * num_replicas_per_worker

  const_value = tensor_util.constant_value(global_batch_size)
  if const_value is not None:
    # Use the constant global batch size for further calculations.
    global_batch_size = const_value

  # Let N = W * R. Constraints (B) and (D) jointly mean that the sub-batches
  # should have size either floor(G/N) or ceil(G/N). Namely, of the N
  # sub-batches a global batch is split into, G - N * floor(G/N) of them will
  # have size ceil(G/N), and the rest will have size floor(G/N).
  floor = global_batch_size // num_subbatches
  num_ceil = global_batch_size - (num_subbatches * floor)
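  # For example (illustrative): G = 7 with W = 2 and R = 2 gives N = 4,
  # floor = 1 and num_ceil = 3, so worker 0 is assigned batch sizes
  # [2, 2, 2, 1].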

  # For worker 0, we assign the first num_ceil sub-batches to have size
  # ceil(G/N), and the remainder to have size floor(G/N). The other workers
  # will each be offset by R * worker_index in order to meet constraint (C).
  if const_value is not None:
    # If the global batch size is a known constant value, we return a constant
    # tensor directly instead of manipulating it with TF ops. This allows for
    # better downstream shape inference.
    worker_0 = [floor + 1] * num_ceil + [floor] * (num_subbatches - num_ceil)
    return ops.convert_to_tensor(
        worker_0[offset:] + worker_0[:offset],
        dtype=dtypes.int64,
        name="batch_sizes")

  worker_0 = array_ops.ones(num_subbatches, dtype=dtypes.int64)
  worker_0 = floor * worker_0 + array_ops.concat(
      [
          array_ops.ones(num_ceil, dtype=dtypes.int64),
          array_ops.zeros(num_subbatches - num_ceil, dtype=dtypes.int64)
      ],
      axis=0)

  return array_ops.concat([worker_0[offset:], worker_0[:offset]], axis=0)


def compute_batch_size(dataset):
  """An operation that returns the batch size of the dataset.

  This op tries to infer the batch size statically by walking up the dataset
  tree from the final dataset node and returning the batch size of the first
  batching dataset (such as from `.batch()` or `.padded_batch()`) that it
  encounters. This differs from using the `element_spec` of a dataset in that
  it does not account for partial batches.

  This operation may fail if it encounters contradictory batch sizes (for
  example, if the dataset is created by zipping together two datasets with
  different batch sizes), if there are no explicit batching transformations, or
  if there are operations downstream from the batching transformation that may
  modify its batch size. In these cases, it returns -1.

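  For example, a minimal sketch of the intended behavior:

  ```python
  ds = tf.data.Dataset.range(10).batch(4)
  # Yields a tf.int64 tensor equal to 4, even though the final batch is a
  # partial batch of 2 elements.
  batch_size = compute_batch_size(ds)
  ```
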
  Args:
    dataset: A `tf.data.Dataset` object.

  Returns:
    A `tf.int64` Tensor representing the batch size of the dataset sans partial
    batches. If this cannot be inferred statically, the value of this tensor
    will be -1.
  """

  def get_static_batch_dim(type_spec):
    try:
      output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
    except NotImplementedError:
      return None
    if not isinstance(output_shape, tensor_shape.TensorShape):
      return None
    if output_shape.rank is None:
      return None
    return output_shape.dims[0].value

  batch_dims = [
      get_static_batch_dim(type_spec)
      for type_spec in nest.flatten(dataset_ops.get_structure(dataset))
  ]

  if all(d is not None for d in batch_dims):
    if all(d == batch_dims[0] for d in batch_dims):
      # If all batch dimensions are known and equal, return that directly.
      batch_dim = batch_dims[0]
    else:
      # If all batch dimensions are known but not all equal, return -1.
      batch_dim = -1

    return constant_op.constant(
        batch_dim, dtype=dtypes.int64, name="static_batch_size")

  # If any batch dimensions are unknown, use compute_batch_size op.
  return ged_ops.compute_batch_size(dataset._variant_tensor)  # pylint: disable=protected-access


_AutoShardDatasetV1.__doc__ = _AutoShardDataset.__doc__