# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Batching dataset transformations."""
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import structured_function
from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.dense_to_ragged_batch")
def dense_to_ragged_batch(batch_size,
                          drop_remainder=False,
                          row_splits_dtype=dtypes.int64):
  """A transformation that batches ragged elements into `tf.RaggedTensor`s.

  This transformation combines multiple consecutive elements of the input
  dataset into a single element.

  Like `tf.data.Dataset.batch`, the components of the resulting element will
  have an additional outer dimension, which will be `batch_size` (or
  `N % batch_size` for the last element if `batch_size` does not divide the
  number of input elements `N` evenly and `drop_remainder` is `False`). If
  your program depends on the batches having the same outer dimension, you
  should set the `drop_remainder` argument to `True` to prevent the smaller
  batch from being produced.

  Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
  different shapes:

  *  If an input element is a `tf.Tensor` whose static `tf.TensorShape` is
     fully defined, then it is batched as normal.
  *  If an input element is a `tf.Tensor` whose static `tf.TensorShape` contains
     one or more axes with unknown size (i.e., `shape[i]=None`), then the output
     will contain a `tf.RaggedTensor` that is ragged up to any of such
     dimensions.
  *  If an input element is a `tf.RaggedTensor`, a `tf.Tensor` of unknown rank,
     or any other type, then it is batched as normal.

  Example:

  >>> dataset = tf.data.Dataset.from_tensor_slices(np.arange(6))
  >>> dataset = dataset.map(lambda x: tf.range(x))
  >>> dataset.element_spec.shape
  TensorShape([None])
  >>> dataset = dataset.apply(
  ...     tf.data.experimental.dense_to_ragged_batch(batch_size=2))
  >>> for batch in dataset:
  ...   print(batch)
  <tf.RaggedTensor [[], [0]]>
  <tf.RaggedTensor [[0, 1], [0, 1, 2]]>
  <tf.RaggedTensor [[0, 1, 2, 3], [0, 1, 2, 3, 4]]>

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in the case it has fewer than
      `batch_size` elements; the default behavior is not to drop the smaller
      batch.
    row_splits_dtype: The dtype that should be used for the `row_splits` of any
      new ragged tensors.  Existing `tf.RaggedTensor` elements do not have their
      row_splits dtype changed.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # Re-encode partially-shaped dense components as ragged so that elements
    # with varying shapes can be batched together by `BatchDataset`.
    ragged_dataset = _DenseToRaggedDataset(dataset, row_splits_dtype)
    return dataset_ops.BatchDataset(
        ragged_dataset, batch_size=batch_size, drop_remainder=drop_remainder)

  return _apply_fn


@tf_export("data.experimental.dense_to_sparse_batch")
def dense_to_sparse_batch(batch_size, row_shape):
  """A transformation that batches ragged elements into `tf.sparse.SparseTensor`s.

  Like `Dataset.padded_batch()`, this transformation combines multiple
  consecutive elements of the dataset, which might have different
  shapes, into a single element. The resulting element has three
  components (`indices`, `values`, and `dense_shape`), which
  comprise a `tf.sparse.SparseTensor` that represents the same data. The
  `row_shape` represents the dense shape of each row in the
  resulting `tf.sparse.SparseTensor`, to which the effective batch size is
  prepended. For example:

  ```python
  # NOTE: The following examples use `{ ... }` to represent the
  # contents of a dataset.
  a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }

  a.apply(tf.data.experimental.dense_to_sparse_batch(
      batch_size=2, row_shape=[6])) ==
  {
      ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],  # indices
       ['a', 'b', 'c', 'a', 'b'],                 # values
       [2, 6]),                                   # dense_shape
      ([[0, 0], [0, 1], [0, 2], [0, 3]],
       ['a', 'b', 'c', 'd'],
       [1, 6])
  }
  ```

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like object
      representing the equivalent dense shape of a row in the resulting
      `tf.sparse.SparseTensor`. Each element of this dataset must have the same
      rank as `row_shape`, and must have size less than or equal to `row_shape`
      in each dimension.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return _DenseToSparseBatchDataset(dataset, batch_size, row_shape)

  return _apply_fn


def _resolve_num_parallel_calls(api_name, batch_size, num_parallel_batches,
                                num_parallel_calls):
  """Computes the effective `num_parallel_calls` for a map-and-batch API.

  Shared by `map_and_batch` and `map_and_batch_with_legacy_function`, which
  accept the same mutually exclusive parallelism arguments.

  Args:
    api_name: Name of the public API, used only in the error message.
    batch_size: The `batch_size` argument passed to the public API.
    num_parallel_batches: The (possibly `None`) `num_parallel_batches`
      argument passed to the public API.
    num_parallel_calls: The (possibly `None`) `num_parallel_calls` argument
      passed to the public API.

  Returns:
    `num_parallel_calls` if it was given; otherwise
    `batch_size * num_parallel_batches` if `num_parallel_batches` was given;
    otherwise `batch_size`.

  Raises:
    ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
      specified.
  """
  if num_parallel_batches is None and num_parallel_calls is None:
    return batch_size
  if num_parallel_calls is None:
    return batch_size * num_parallel_batches
  if num_parallel_batches is not None:
    raise ValueError(
        f"`{api_name}` allows only one of `num_parallel_batches` and "
        "`num_parallel_calls` to be set, but "
        f"`num_parallel_batches` was set to {num_parallel_batches} "
        f"and `num_parallel_calls` was set to {num_parallel_calls}.")
  return num_parallel_calls


@deprecation.deprecated(None, "Use `tf.data.experimental.map_and_batch()`")
@tf_export(v1=["data.experimental.map_and_batch_with_legacy_function"])
def map_and_batch_with_legacy_function(map_func,
                                       batch_size,
                                       num_parallel_batches=None,
                                       drop_remainder=False,
                                       num_parallel_calls=None):
  """Fused implementation of `map` and `batch`.

  NOTE: This is an escape hatch for existing uses of `map_and_batch` that do not
  work with V2 functions. New uses are strongly discouraged and existing uses
  should migrate to `map_and_batch` as this method will be removed in V2.

  Args:
    map_func: A function mapping a nested structure of tensors to another
      nested structure of tensors.
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
      representing the number of batches to create in parallel. On one hand,
      higher values can help mitigate the effect of stragglers. On the other
      hand, higher values can increase contention if CPU is scarce.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in case its size is smaller than
      desired; the default behavior is not to drop the smaller batch.
    num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`,
      representing the number of elements to process in parallel. If not
      specified, `batch_size * num_parallel_batches` elements will be processed
      in parallel (or `batch_size` elements if `num_parallel_batches` is not
      specified either). If the value `tf.data.AUTOTUNE` is used, then
      the number of parallel calls is set dynamically based on available CPU.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
      specified.
  """
  # Validate eagerly so misuse is reported when the transformation is built,
  # not when it is applied.
  num_parallel_calls = _resolve_num_parallel_calls(
      "map_and_batch_with_legacy_function", batch_size, num_parallel_batches,
      num_parallel_calls)

  def _apply_fn(dataset):
    return _MapAndBatchDataset(dataset, map_func, batch_size,
                               num_parallel_calls, drop_remainder,
                               use_legacy_function=True)

  return _apply_fn


@deprecation.deprecated(
    None,
    "Use `tf.data.Dataset.map(map_func, num_parallel_calls)` followed by "
    "`tf.data.Dataset.batch(batch_size, drop_remainder)`. Static tf.data "
    "optimizations will take care of using the fused implementation.")
@tf_export("data.experimental.map_and_batch")
def map_and_batch(map_func,
                  batch_size,
                  num_parallel_batches=None,
                  drop_remainder=False,
                  num_parallel_calls=None):
  """Fused implementation of `map` and `batch`.

  Maps `map_func` across `batch_size` consecutive elements of this dataset
  and then combines them into a batch. Functionally, it is equivalent to `map`
  followed by `batch`. This API is temporary and deprecated since input pipeline
  optimization now fuses consecutive `map` and `batch` operations automatically.

  Args:
    map_func: A function mapping a nested structure of tensors to another
      nested structure of tensors.
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
      representing the number of batches to create in parallel. On one hand,
      higher values can help mitigate the effect of stragglers. On the other
      hand, higher values can increase contention if CPU is scarce.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in case its size is smaller than
      desired; the default behavior is not to drop the smaller batch.
    num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`,
      representing the number of elements to process in parallel. If not
      specified, `batch_size * num_parallel_batches` elements will be processed
      in parallel (or `batch_size` elements if `num_parallel_batches` is not
      specified either). If the value `tf.data.AUTOTUNE` is used, then
      the number of parallel calls is set dynamically based on available CPU.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
      specified.
  """
  # Validate eagerly so misuse is reported when the transformation is built,
  # not when it is applied.
  num_parallel_calls = _resolve_num_parallel_calls(
      "map_and_batch", batch_size, num_parallel_batches, num_parallel_calls)

  def _apply_fn(dataset):
    return _MapAndBatchDataset(dataset, map_func, batch_size,
                               num_parallel_calls, drop_remainder)

  return _apply_fn


@deprecation.deprecated(None, "Use `tf.data.Dataset.unbatch()`.")
@tf_export("data.experimental.unbatch")
def unbatch():
  """Splits elements of a dataset into multiple elements on the batch dimension.

  For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
  where `B` may vary for each input element, then for each element in the
  dataset, the unbatched dataset will contain `B` consecutive elements
  of shape `[a0, a1, ...]`.

  ```python
  # NOTE: The following example uses `{ ... }` to represent the contents
  # of a dataset.
  a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }

  a.apply(tf.data.experimental.unbatch()) == {
      'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
  ```

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # Thin wrapper: delegates to the core `Dataset.unbatch()` method.
    return dataset.unbatch()

  return _apply_fn


class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that batches ragged dense elements into `tf.sparse.SparseTensor`s."""

  def __init__(self, input_dataset, batch_size, row_shape):
    """See `Dataset.dense_to_sparse_batch()` for more details."""
    input_types = dataset_ops.get_legacy_output_types(input_dataset)
    # A single-component dataset reports a bare `DType`; any nested structure
    # reports a structure of dtypes instead, which this op cannot handle.
    if not isinstance(input_types, dtypes.DType):
      raise TypeError("`dense_to_sparse_batch` requires an input dataset whose "
                      "elements have a single component, but the given dataset "
                      "has the following component types: "
                      f"{input_types}.")
    self._input_dataset = input_dataset
    self._batch_size = batch_size
    self._row_shape = row_shape
    # The batch dimension is always unknown statically; `row_shape` supplies
    # the remaining dimensions of each output `SparseTensor`.
    self._element_spec = sparse_tensor.SparseTensorSpec(
        tensor_shape.TensorShape([None]).concatenate(self._row_shape),
        input_types)

    variant_tensor = ged_ops.dense_to_sparse_batch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._batch_size,
        row_shape=convert.partial_shape_to_tensor(self._row_shape),
        **self._flat_structure)
    super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
                                                     variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


class _MapAndBatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over a batch of elements."""

  def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
               drop_remainder, use_legacy_function=False):
    """Creates a fused map-and-batch dataset.

    Args:
      input_dataset: The dataset whose elements are mapped and batched.
      map_func: The function applied to each input element.
      batch_size: Number of mapped elements per batch; converted to a
        `tf.int64` scalar.
      num_parallel_calls: Number of elements to process in parallel; converted
        to a `tf.int64` scalar.
      drop_remainder: Whether to drop a final partial batch; converted to a
        `tf.bool` scalar.
      use_legacy_function: Whether `map_func` is traced as a legacy (V1)
        function.
    """
    self._input_dataset = input_dataset

    self._map_func = structured_function.StructuredFunctionWrapper(
        map_func,
        "tf.data.experimental.map_and_batch()",
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    self._batch_size_t = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    self._num_parallel_calls_t = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
    self._drop_remainder_t = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")

    constant_drop_remainder = tensor_util.constant_value(self._drop_remainder_t)
    # pylint: disable=protected-access
    if constant_drop_remainder:
      # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
      # or `False` (explicitly retaining the remainder).
      # Only when the remainder is statically known to be dropped does every
      # batch have exactly `batch_size` rows, so only then can the static
      # batch dimension be filled in.
      static_batch_size = tensor_util.constant_value(self._batch_size_t)
      self._element_spec = nest.map_structure(
          lambda component_spec: component_spec._batch(static_batch_size),
          self._map_func.output_structure)
    else:
      self._element_spec = nest.map_structure(
          lambda component_spec: component_spec._batch(None),
          self._map_func.output_structure)
    # pylint: enable=protected-access
    variant_tensor = ged_ops.map_and_batch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        batch_size=self._batch_size_t,
        num_parallel_calls=self._num_parallel_calls_t,
        drop_remainder=self._drop_remainder_t,
        preserve_cardinality=True,
        **self._flat_structure)
    super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def element_spec(self):
    return self._element_spec


class _DenseToRaggedDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that re-encodes partially-shaped dense inputs as ragged.

  In particular:

  * Any `tf.Tensor` component whose static shape has a known rank but is not
    fully defined is encoded as a `tf.RaggedTensor` whose `ragged_rank` is the
    index of its last axis of unknown size (so a tensor whose only unknown
    axis is the outermost one gets `ragged_rank=0`). This allows tensors with
    varying shape to be batched together.
  * Any other components (fully-defined or unknown-rank `tf.Tensor`s, existing
    `tf.RaggedTensor`s, and any other types) are left as-is.
  """

  def __init__(self, input_dataset, row_splits_dtype):
    """Constructs a new _DenseToRaggedDataset.

    Args:
      input_dataset: The dataset whose tf.Tensor elements should be made ragged.
      row_splits_dtype: The dtype that should be used for the `row_splits` of
        any new ragged tensors.  Existing `tf.RaggedTensor` elements do *not*
        have their row_splits dtype changed.
    """
    # Replace each TensorSpec in the input dataset's structure with a
    # corresponding RaggedTensorSpec.
    def to_ragged_spec(spec):
      """Returns the new spec based on RaggedTensors."""
      if (not isinstance(spec, tensor_spec.TensorSpec) or
          spec.shape.rank is None or
          spec.shape.is_fully_defined()):
        return spec
      else:
        # Everything up to and including the last unknown axis must be ragged.
        ragged_rank = max([
            axis for (axis, size) in enumerate(spec.shape.as_list())
            if size is None
        ])
        return ragged_tensor.RaggedTensorSpec(
            shape=spec.shape,
            dtype=spec.dtype,
            ragged_rank=ragged_rank,
            row_splits_dtype=row_splits_dtype)

    self._structure = nest.map_structure(to_ragged_spec,
                                         input_dataset.element_spec)

    # Replace each tf.Tensor value in the input dataset with a variant-encoded
    # RaggedTensor.  Since we're updating the corresponding structure to be
    # a RaggedTensorSpec, this variant-encoded tensor will be decoded with
    # RaggedTensorSpec._from_tensor_list.
    def to_ragged_variant(value):
      """Re-encode Tensors as RaggedTensors."""
      if (not isinstance(value, ops.Tensor) or
          value.shape.rank is None or
          value.shape.is_fully_defined()):
        return value
      else:
        spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value))
        # With ragged_rank == 0 the dense value is handed to the spec as-is;
        # otherwise it is first converted to a RaggedTensor.
        if spec._ragged_rank > 0:  # pylint: disable=protected-access
          value = ragged_tensor.RaggedTensor.from_tensor(
              value, ragged_rank=spec._ragged_rank)  # pylint: disable=protected-access
        return spec._to_tensor_list(value)[0]  # pylint: disable=protected-access

    # Tuples are automatically unpacked by `dataset.map` so we repack them.
    if structured_function._should_unpack(input_dataset.element_spec):  # pylint: disable=protected-access
      map_fn = lambda *value: nest.map_structure(to_ragged_variant, value)
    else:
      map_fn = lambda value: nest.map_structure(to_ragged_variant, value)

    self._mapped_dataset = input_dataset.map(map_fn)

    variant = self._mapped_dataset._variant_tensor  # pylint: disable=protected-access
    super(_DenseToRaggedDataset, self).__init__(input_dataset, variant)

  @property
  def element_spec(self):
    return self._structure