# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Batching dataset transformations."""
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import structured_function
from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.dense_to_ragged_batch")
def dense_to_ragged_batch(batch_size,
                          drop_remainder=False,
                          row_splits_dtype=dtypes.int64):
  """A transformation that batches ragged elements into `tf.RaggedTensor`s.

  This transformation combines multiple consecutive elements of the input
  dataset into a single element.

  Like `tf.data.Dataset.batch`, the components of the resulting element will
  have an additional outer dimension, which will be `batch_size` (or
  `N % batch_size` for the last element if `batch_size` does not divide the
  number of input elements `N` evenly and `drop_remainder` is `False`). If
  your program depends on the batches having the same outer dimension, you
  should set the `drop_remainder` argument to `True` to prevent the smaller
  batch from being produced.

  Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
  different shapes:

  *  If an input element is a `tf.Tensor` whose static `tf.TensorShape` is
     fully defined, then it is batched as normal.
  *  If an input element is a `tf.Tensor` whose static `tf.TensorShape` contains
     one or more axes with unknown size (i.e., `shape[i]=None`), then the output
     will contain a `tf.RaggedTensor` that is ragged up to the last axis with
     unknown size.
  *  If an input element is a `tf.RaggedTensor` or any other type, then it is
     batched as normal.

  Example:

  >>> dataset = tf.data.Dataset.from_tensor_slices(np.arange(6))
  >>> dataset = dataset.map(lambda x: tf.range(x))
  >>> dataset.element_spec.shape
  TensorShape([None])
  >>> dataset = dataset.apply(
  ...     tf.data.experimental.dense_to_ragged_batch(batch_size=2))
  >>> for batch in dataset:
  ...   print(batch)
  <tf.RaggedTensor [[], [0]]>
  <tf.RaggedTensor [[0, 1], [0, 1, 2]]>
  <tf.RaggedTensor [[0, 1, 2, 3], [0, 1, 2, 3, 4]]>

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in the case it has fewer than
      `batch_size` elements; the default behavior is not to drop the smaller
      batch.
    row_splits_dtype: The dtype that should be used for the `row_splits` of any
      new ragged tensors.  Existing `tf.RaggedTensor` elements do not have their
      row_splits dtype changed.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
88  """
89
90  def _apply_fn(dataset):
91    ragged_dataset = _DenseToRaggedDataset(dataset, row_splits_dtype)
92    return dataset_ops.BatchDataset(
93        ragged_dataset, batch_size=batch_size, drop_remainder=drop_remainder)
94
95  return _apply_fn
96
97
98@tf_export("data.experimental.dense_to_sparse_batch")
99def dense_to_sparse_batch(batch_size, row_shape):
100  """A transformation that batches ragged elements into `tf.sparse.SparseTensor`s.
101
102  Like `Dataset.padded_batch()`, this transformation combines multiple
103  consecutive elements of the dataset, which might have different
104  shapes, into a single element. The resulting element has three
105  components (`indices`, `values`, and `dense_shape`), which
106  comprise a `tf.sparse.SparseTensor` that represents the same data. The
107  `row_shape` represents the dense shape of each row in the
108  resulting `tf.sparse.SparseTensor`, to which the effective batch size is
109  prepended. For example:
110
111  ```python
112  # NOTE: The following examples use `{ ... }` to represent the
113  # contents of a dataset.
114  a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
115
116  a.apply(tf.data.experimental.dense_to_sparse_batch(
117      batch_size=2, row_shape=[6])) ==
118  {
119      ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],  # indices
120       ['a', 'b', 'c', 'a', 'b'],                 # values
121       [2, 6]),                                   # dense_shape
122      ([[0, 0], [0, 1], [0, 2], [0, 3]],
123       ['a', 'b', 'c', 'd'],
124       [1, 6])
125  }
126  ```
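
  A minimal runnable sketch of the same idea (assuming eager execution; the
  printed values are illustrative, not doctest output):

  ```python
  dataset = tf.data.Dataset.range(1, 4).map(lambda x: tf.fill([x], x))
  dataset = dataset.apply(tf.data.experimental.dense_to_sparse_batch(
      batch_size=2, row_shape=[4]))
  for batch in dataset:  # Each `batch` is a `tf.sparse.SparseTensor`.
    print(batch.indices, batch.values, batch.dense_shape)
  ```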

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like object
      representing the equivalent dense shape of a row in the resulting
      `tf.sparse.SparseTensor`. Each element of this dataset must have the same
      rank as `row_shape`, and must have size less than or equal to `row_shape`
      in each dimension.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return _DenseToSparseBatchDataset(dataset, batch_size, row_shape)

  return _apply_fn


@deprecation.deprecated(None, "Use `tf.data.experimental.map_and_batch()`.")
@tf_export(v1=["data.experimental.map_and_batch_with_legacy_function"])
def map_and_batch_with_legacy_function(map_func,
                                       batch_size,
                                       num_parallel_batches=None,
                                       drop_remainder=False,
                                       num_parallel_calls=None):
  """Fused implementation of `map` and `batch`.

  NOTE: This is an escape hatch for existing uses of `map_and_batch` that do not
  work with V2 functions. New uses are strongly discouraged and existing uses
  should migrate to `map_and_batch` as this method will be removed in V2.

  Args:
    map_func: A function mapping a nested structure of tensors to another
      nested structure of tensors.
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
      representing the number of batches to create in parallel. On one hand,
      higher values can help mitigate the effect of stragglers. On the other
      hand, higher values can increase contention if CPU is scarce.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in case its size is smaller than
      desired; the default behavior is not to drop the smaller batch.
    num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
      representing the number of elements to process in parallel. If not
      specified, `batch_size * num_parallel_batches` elements will be processed
      in parallel. If the value `tf.data.AUTOTUNE` is used, then
      the number of parallel calls is set dynamically based on available CPU.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
      specified.
  """

  if num_parallel_batches is None and num_parallel_calls is None:
    num_parallel_calls = batch_size
  elif num_parallel_batches is not None and num_parallel_calls is None:
    num_parallel_calls = batch_size * num_parallel_batches
  elif num_parallel_batches is not None and num_parallel_calls is not None:
    raise ValueError(
        "`map_and_batch_with_legacy_function` allows only one of "
        "`num_parallel_batches` and "
        "`num_parallel_calls` to be set, but "
        f"`num_parallel_batches` was set to {num_parallel_batches} "
198        f"and `num_parallel_calls` as set to {num_parallel_calls}.")

  def _apply_fn(dataset):
    return _MapAndBatchDataset(dataset, map_func, batch_size,
                               num_parallel_calls, drop_remainder,
                               use_legacy_function=True)

  return _apply_fn


@deprecation.deprecated(
    None,
    "Use `tf.data.Dataset.map(map_func, num_parallel_calls)` followed by "
    "`tf.data.Dataset.batch(batch_size, drop_remainder)`. Static tf.data "
    "optimizations will take care of using the fused implementation.")
@tf_export("data.experimental.map_and_batch")
def map_and_batch(map_func,
                  batch_size,
                  num_parallel_batches=None,
                  drop_remainder=False,
                  num_parallel_calls=None):
  """Fused implementation of `map` and `batch`.

  Maps `map_func` across `batch_size` consecutive elements of this dataset
  and then combines them into a batch. Functionally, it is equivalent to `map`
  followed by `batch`. This API is temporary and deprecated since input pipeline
  optimization now fuses consecutive `map` and `batch` operations automatically.

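  For example, the following two pipelines are functionally equivalent (a
  sketch; the fused form differs only in performance):

  ```python
  # Fused map and batch:
  dataset = dataset.apply(tf.data.experimental.map_and_batch(
      map_func=lambda x: x * 2, batch_size=32, drop_remainder=True))

  # Unfused form (the recommended replacement):
  dataset = dataset.map(lambda x: x * 2).batch(32, drop_remainder=True)
  ```
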
  Args:
    map_func: A function mapping a nested structure of tensors to another
      nested structure of tensors.
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
      representing the number of batches to create in parallel. On one hand,
      higher values can help mitigate the effect of stragglers. On the other
      hand, higher values can increase contention if CPU is scarce.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in case its size is smaller than
      desired; the default behavior is not to drop the smaller batch.
    num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
      representing the number of elements to process in parallel. If not
      specified, `batch_size * num_parallel_batches` elements will be processed
      in parallel. If the value `tf.data.AUTOTUNE` is used, then
      the number of parallel calls is set dynamically based on available CPU.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
      specified.
  """

  if num_parallel_batches is None and num_parallel_calls is None:
    num_parallel_calls = batch_size
  elif num_parallel_batches is not None and num_parallel_calls is None:
    num_parallel_calls = batch_size * num_parallel_batches
  elif num_parallel_batches is not None and num_parallel_calls is not None:
    raise ValueError(
        "`map_and_batch` allows only one of `num_parallel_batches` and "
        "`num_parallel_calls` to be set, but "
        f"`num_parallel_batches` was set to {num_parallel_batches} "
262        f"and `num_parallel_calls` as set to {num_parallel_calls}.")

  def _apply_fn(dataset):
    return _MapAndBatchDataset(dataset, map_func, batch_size,
                               num_parallel_calls, drop_remainder)

  return _apply_fn


@deprecation.deprecated(None, "Use `tf.data.Dataset.unbatch()`.")
@tf_export("data.experimental.unbatch")
def unbatch():
  """Splits elements of a dataset into multiple elements on the batch dimension.

  For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
  where `B` may vary for each input element, then for each element in the
  dataset, the unbatched dataset will contain `B` consecutive elements
  of shape `[a0, a1, ...]`.

  ```python
  # NOTE: The following example uses `{ ... }` to represent the contents
  # of a dataset.
  a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }

  a.unbatch() == {
      'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
  ```
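
  A minimal runnable sketch (assuming eager execution; `tf.data.Dataset.unbatch`
  is the non-deprecated equivalent):

  ```python
  dataset = tf.data.Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6]])
  dataset = dataset.apply(tf.data.experimental.unbatch())
  print([x.numpy() for x in dataset])  # [1, 2, 3, 4, 5, 6]
  ```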

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return dataset.unbatch()

  return _apply_fn


class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that batches ragged dense elements into `tf.sparse.SparseTensor`s."""

  def __init__(self, input_dataset, batch_size, row_shape):
    """See `Dataset.dense_to_sparse_batch()` for more details."""
    if not isinstance(
        dataset_ops.get_legacy_output_types(input_dataset), dtypes.DType):
      raise TypeError("`dense_to_sparse_batch` requires an input dataset whose "
                      "elements have a single component, but the given dataset "
                      "has the following component types: "
                      f"{dataset_ops.get_legacy_output_types(input_dataset)}.")
    self._input_dataset = input_dataset
    self._batch_size = batch_size
    self._row_shape = row_shape
    self._element_spec = sparse_tensor.SparseTensorSpec(
        tensor_shape.TensorShape([None]).concatenate(self._row_shape),
        dataset_ops.get_legacy_output_types(input_dataset))
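    # For example (illustrative): with `row_shape=[6]` and a `tf.string` input
    # dataset, `element_spec` is `SparseTensorSpec(shape=[None, 6],
    # dtype=tf.string)`; the leading `None` is the batch dimension.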

    variant_tensor = ged_ops.dense_to_sparse_batch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._batch_size,
        row_shape=convert.partial_shape_to_tensor(self._row_shape),
        **self._flat_structure)
    super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
                                                     variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


class _MapAndBatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over a batch of elements."""

  def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
               drop_remainder, use_legacy_function=False):
    self._input_dataset = input_dataset

    self._map_func = structured_function.StructuredFunctionWrapper(
        map_func,
        "tf.data.experimental.map_and_batch()",
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    self._batch_size_t = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    self._num_parallel_calls_t = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
    self._drop_remainder_t = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")

    constant_drop_remainder = tensor_util.constant_value(self._drop_remainder_t)
    # pylint: disable=protected-access
    if constant_drop_remainder:
      # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
      # or `False` (explicitly retaining the remainder).
      # pylint: disable=g-long-lambda
      self._element_spec = nest.map_structure(
          lambda component_spec: component_spec._batch(
              tensor_util.constant_value(self._batch_size_t)),
          self._map_func.output_structure)
    else:
      self._element_spec = nest.map_structure(
          lambda component_spec: component_spec._batch(None),
          self._map_func.output_structure)
    # pylint: enable=protected-access
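    # For example (illustrative): with `batch_size=32`, a component spec
    # `TensorSpec([10])` becomes `TensorSpec([32, 10])` when `drop_remainder`
    # is statically `True`, and `TensorSpec([None, 10])` otherwise.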
    variant_tensor = ged_ops.map_and_batch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        batch_size=self._batch_size_t,
        num_parallel_calls=self._num_parallel_calls_t,
        drop_remainder=self._drop_remainder_t,
        preserve_cardinality=True,
        **self._flat_structure)
    super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def element_spec(self):
    return self._element_spec


class _DenseToRaggedDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that encodes dense inputs as ragged tensors.

  In particular:

  * Any tf.Tensor elements whose static shape is not fully defined are
    encoded as ragged tensors that are ragged up to their last
    statically-unknown axis. This allows tensors with varying shape to be
    batched together.
  * Any other elements are left as-is.
  """

  def __init__(self, input_dataset, row_splits_dtype):
    """Constructs a new _DenseToRaggedDataset.

    Args:
      input_dataset: The dataset whose tf.Tensor elements should be made ragged.
      row_splits_dtype: The dtype that should be used for the `row_splits` of
        any new ragged tensors.  Existing `tf.RaggedTensor` elements do *not*
        have their row_splits dtype changed.
    """
    # Replace each TensorSpec in the input dataset's structure with a
    # corresponding RaggedTensorSpec.
    def to_ragged_spec(spec):
      """Returns the new spec based on RaggedTensors."""
      if (not isinstance(spec, tensor_spec.TensorSpec) or
          spec.shape.rank is None or
          spec.shape.is_fully_defined()):
        return spec
      else:
        ragged_rank = max([
            axis for (axis, size) in enumerate(spec.shape.as_list())
            if size is None
        ])
        return ragged_tensor.RaggedTensorSpec(
            shape=spec.shape,
            dtype=spec.dtype,
            ragged_rank=ragged_rank,
            row_splits_dtype=row_splits_dtype)
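
    # For example (illustrative): a `TensorSpec` with shape `[None, 5, None]`
    # maps to a `RaggedTensorSpec` with `ragged_rank=2`, while a fully-defined
    # `TensorSpec` with shape `[3, 5]` is returned unchanged.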

    self._structure = nest.map_structure(to_ragged_spec,
                                         input_dataset.element_spec)

    # Replace each tf.Tensor value in the input dataset with a variant-encoded
    # RaggedTensor. Since we're updating the corresponding structure to be
    # a RaggedTensorSpec, this variant-encoded tensor will be decoded with
    # RaggedTensorSpec._from_tensor_list.
    def to_ragged_variant(value):
      """Re-encode Tensors as RaggedTensors."""
      if (not isinstance(value, ops.Tensor) or
          value.shape.rank is None or
          value.shape.is_fully_defined()):
        return value
      else:
        spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value))
        if spec._ragged_rank > 0:  # pylint: disable=protected-access
          value = ragged_tensor.RaggedTensor.from_tensor(
              value, ragged_rank=spec._ragged_rank)  # pylint: disable=protected-access
        return spec._to_tensor_list(value)[0]  # pylint: disable=protected-access

    # Tuples are automatically unpacked by `dataset.map` so we repack them.
    if structured_function._should_unpack(input_dataset.element_spec):  # pylint: disable=protected-access
      map_fn = lambda *value: nest.map_structure(to_ragged_variant, value)
    else:
      map_fn = lambda value: nest.map_structure(to_ragged_variant, value)

    self._mapped_dataset = input_dataset.map(map_fn)

    variant = self._mapped_dataset._variant_tensor  # pylint: disable=protected-access
    super(_DenseToRaggedDataset, self).__init__(input_dataset, variant)

  @property
  def element_spec(self):
    return self._structure