1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Distribution Strategy-related dataset transformations.""" 16import numpy as np 17 18from tensorflow.python.data.ops import dataset_ops 19from tensorflow.python.data.ops.options import ExternalStatePolicy 20from tensorflow.python.data.util import nest 21from tensorflow.python.framework import constant_op 22from tensorflow.python.framework import dtypes 23from tensorflow.python.framework import ops 24from tensorflow.python.framework import tensor_shape 25from tensorflow.python.framework import tensor_util 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops 28from tensorflow.python.util.tf_export import tf_export 29 30SHARD_HINT = -1 31tf_export("data.experimental.SHARD_HINT").export_constant( 32 __name__, "SHARD_HINT") 33 34 35class _AutoShardDataset(dataset_ops.UnaryDataset): 36 """A `Dataset` that shards the `Dataset` automatically. 37 38 This dataset takes in an existing dataset and tries to automatically figure 39 out how to shard the dataset in a multi-worker scenario using graph rewrites. 40 41 If the AutoShardPolicy is set to FILE, it walks up the dataset graph until 42 it finds a reader dataset, then inserts a ShardDataset op before that node 43 so that each worker only sees some files. 44 45 If the AutoShardPolicy is set to DATA, it inserts a ShardDataset op at the 46 end of the input pipeline, before any terminal PrefetchDataset if there is 47 one. Additionally, if there is a RebatchDatasetV2 in the input pipeline, it 48 is written to legacy RebatchDataset for correctness reasons, since 49 RebatchDatasetV2 is incompatible with data sharding. 50 51 If the AutoShardPolicy is set to AUTO, it tries to do file-based sharding. 52 If it cannot find a reader dataset, it falls back to doing data-based 53 sharding. 54 55 If the AutoShardPolicy is set to OFF, it does nothing. 56 57 Attributes: 58 num_workers: Total number of workers to shard this dataset across. 59 index: The current worker index (out of the total number of workers) this 60 dataset is for. 61 num_replicas: The total number of replicas across all workers. This is used 62 only when sharding by data (either DATA or AUTO) in order to rewrite 63 RebatchDatasetV2 to RebatchDataset. 64 65 Raises: 66 NotFoundError: If we cannot find a suitable reader dataset to begin 67 automatically sharding the dataset. 68 """ 69 70 def __init__(self, input_dataset, num_workers, index, num_replicas=None): 71 self._input_dataset = input_dataset 72 73 self._element_spec = input_dataset.element_spec 74 variant_tensor = ged_ops.auto_shard_dataset( 75 self._input_dataset._variant_tensor, # pylint: disable=protected-access 76 num_workers=num_workers, 77 index=index, 78 auto_shard_policy=int( 79 input_dataset.options().experimental_distribute.auto_shard_policy), 80 num_replicas=num_replicas, 81 **self._flat_structure) 82 super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor) 83 84 @property 85 def element_spec(self): 86 return self._element_spec 87 88 89def _AutoShardDatasetV1(input_dataset, num_workers, index, num_replicas=None): # pylint: disable=invalid-name 90 return dataset_ops.DatasetV1Adapter( 91 _AutoShardDataset(input_dataset, num_workers, index, num_replicas)) 92 93 94class _RebatchDataset(dataset_ops.UnaryDataset): 95 """A `Dataset` that rebatches elements from its input into new batch sizes. 96 97 `_RebatchDataset(input_dataset, batch_sizes)` is functionally equivalent to 98 `input_dataset.unbatch().batch(N)`, where the value of N cycles through the 99 `batch_sizes` input list. The elements produced by this dataset have the same 100 rank as the elements of the input dataset. 101 102 For example: 103 104 ```python 105 ds = tf.data.Dataset.range(8) 106 ds = ds.batch(4) 107 ds = _RebatchDataset(ds, batch_sizes=[2, 1, 1]) 108 for elem in ds: 109 print(elem) 110 >> [0, 1], [2], [3], [4, 5], [6], [7] 111 112 ds = tf.data.Dataset.range(16) 113 ds = ds.batch(4) 114 ds = _RebatchDataset(ds, batch_sizes=[6]) 115 for elem in ds: 116 print(elem) 117 >> [0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11], [12, 13, 14, 15] 118 ``` 119 """ 120 121 def __init__(self, input_dataset, batch_sizes, drop_remainder=False): 122 """Creates a _RebatchDataset. 123 124 Args: 125 input_dataset: `Dataset` to rebatch. 126 batch_sizes: A `tf.int64` scalar or vector, representing the size of 127 batches to produce. If this argument is a vector, these values are 128 cycled through in order. 129 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 130 whether the last batch should be dropped in the case it has fewer than 131 `batch_sizes[cycle_index] elements; the default behavior is not to drop 132 the smaller batch. 133 """ 134 self._input_dataset = input_dataset 135 self._batch_sizes = ops.convert_to_tensor( 136 batch_sizes, dtype=dtypes.int64, name="batch_sizes") 137 self._drop_remainder = ops.convert_to_tensor( 138 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 139 new_batch_dim = self._compute_static_batch_dim() 140 141 # pylint: disable=protected-access 142 self._element_spec = nest.map_structure( 143 lambda ts: ts._unbatch()._batch(new_batch_dim), 144 dataset_ops.get_structure(input_dataset)) 145 # pylint: enable=protected-access 146 147 # auto_shard rewrite assumes that there's normalize_to_dense before 148 # rebatch_dataset. 149 # LINT.IfChange 150 input_dataset = dataset_ops.normalize_to_dense(input_dataset) 151 variant_tensor = ged_ops.rebatch_dataset_v2( 152 input_dataset._variant_tensor, # pylint: disable=protected-access 153 batch_sizes=batch_sizes, 154 drop_remainder=drop_remainder, 155 **self._flat_structure) 156 # LINT.ThenChange(//tensorflow/core/grappler/optimizers/data/auto_shard.cc) 157 super(_RebatchDataset, self).__init__(input_dataset, variant_tensor) 158 159 def _compute_static_batch_dim(self): 160 """Computes the static batch dimension of a dataset if it can be determined. 161 162 Given the _RebatchDataset parameters, determines the batch dimension of this 163 dataset statically. Returns None if this cannot be determined or is 164 variable. 165 166 Returns: 167 An integer representing the batch dimension of the dataset. If it cannot 168 be determined statically, returns None. 169 170 Raises: 171 ValueError: The batch_sizes parameter is malformed, input_dataset is 172 not batched, or input_dataset batch sizes are incompatible with each 173 other. 174 """ 175 new_batch_dim = tensor_util.constant_value(self._batch_sizes) 176 if new_batch_dim is None: 177 return None 178 179 if isinstance(new_batch_dim, np.ndarray): 180 if len(new_batch_dim.shape) == 1: 181 if np.all(new_batch_dim == new_batch_dim[0]): 182 new_batch_dim = new_batch_dim[0] 183 else: 184 return None 185 elif len(new_batch_dim.shape) > 1: 186 raise ValueError( 187 f"Invalid `batch_sizes`. Expected `batch_sizes` to be a scalar or " 188 f"a vector. Received `batch_sizes` of rank " 189 f"{len(new_batch_dim.shape)}.") 190 191 if self._may_form_partial_batches(new_batch_dim): 192 return None 193 194 return new_batch_dim 195 196 def _may_form_partial_batches(self, desired_batch_size): 197 """Returns whether this dataset may form partial batches.""" 198 if tensor_util.constant_value(self._drop_remainder): 199 return False 200 201 def get_batch_dim(type_spec): 202 try: 203 shape = type_spec._to_legacy_output_shapes() # pylint: disable=protected-access 204 except NotImplementedError: 205 return None 206 if not isinstance(shape, tensor_shape.TensorShape): 207 return None 208 if shape.rank is None: 209 return None 210 if len(shape) < 1: 211 raise ValueError("Invalid `batch_sizes`. Expected dataset with " 212 "rank of >= 1 but found a dataset with " 213 "scalar elements. Fix the issue by adding the `batch` " 214 "transformation to the dataset.") 215 return shape.dims[0].value 216 217 input_batch_dims = [ 218 get_batch_dim(ts) 219 for ts in nest.flatten(dataset_ops.get_structure(self._input_dataset)) 220 ] 221 known_input_batch_dims = [d for d in input_batch_dims if d is not None] 222 223 if not known_input_batch_dims: 224 return True 225 226 known_input_batch_dims = np.asarray(known_input_batch_dims) 227 if not np.all(known_input_batch_dims == known_input_batch_dims[0]): 228 raise ValueError( 229 f"Invalid `input_dataset.` The batch dimension of component 0 " 230 f"is {known_input_batch_dims[0]}, while the batch dimension " 231 f"of component i is {known_input_batch_dims}.") 232 233 return known_input_batch_dims[0] % desired_batch_size != 0 234 235 @property 236 def element_spec(self): 237 return self._element_spec 238 239 240class _LegacyRebatchDataset(dataset_ops.UnaryDataset): 241 """A `Dataset` that divides its input batches into `num_replicas` sub-batches. 242 243 For each batch in the input dataset, _LegacyRebatchDataset will produce 244 `num_replicas` smaller batches whose sizes add up to the original batch size. 245 246 For example: 247 248 ```python 249 ds = tf.data.Dataset.range(8) 250 ds = ds.batch(4) 251 ds = _LegacyRebatchDataset(ds, num_replicas=3) 252 for elem in ds: 253 print(elem) 254 >> [0, 1], [2, 3], [], [4, 5], [6, 7], [] 255 ``` 256 """ 257 258 def __init__(self, input_dataset, num_replicas): 259 """Creates a _LegacyRebatchDataset. 260 261 Args: 262 input_dataset: `Dataset` to rebatch. 263 num_replicas: A `tf.int64` scalar, representing the number of sub-batches 264 to split each batch from `input_dataset` into. 265 """ 266 267 def recalculate_batch_size(type_spec): 268 """Recalculates the output_shape after dividing it by num_replicas.""" 269 output_shape = type_spec._to_legacy_output_shapes() # pylint: disable=protected-access 270 if not isinstance(output_shape, tensor_shape.TensorShape): 271 return None 272 273 # If the output shape is unknown, we set the batch dimension to unknown. 274 if output_shape.rank is None: 275 return None 276 277 if len(output_shape) < 1: 278 raise ValueError( 279 "Invalid `input_dataset`. Expected a dataset whose elements " 280 "have rank >= 1 but found a dataset whose elements are scalars. " 281 "Fix the issue by adding the `batch` transformation to the " 282 "dataset.") 283 output_dims = [d.value for d in output_shape.dims] 284 285 if output_dims[0] is not None and output_dims[0] % num_replicas == 0: 286 return output_dims[0] // num_replicas 287 288 # Set the batch dimension to unknown. If the global batch size does not 289 # divide num_replicas evenly, the minibatches may have different sizes. 290 return None 291 292 def rebatch(type_spec): 293 # pylint: disable=protected-access 294 batch_size = recalculate_batch_size(type_spec) 295 return type_spec._unbatch()._batch(batch_size) 296 # pylint: enable=protected-access 297 298 self._element_spec = nest.map_structure( 299 rebatch, dataset_ops.get_structure(input_dataset)) 300 301 # auto_shard rewrite assumes that there's normalize_to_dense before 302 # rebatch_dataset. 303 # LINT.IfChange 304 input_dataset = dataset_ops.normalize_to_dense(input_dataset) 305 variant_tensor = ged_ops.rebatch_dataset( 306 input_dataset._variant_tensor, # pylint: disable=protected-access 307 num_replicas=num_replicas, 308 **self._flat_structure) 309 # LINT.ThenChange(//tensorflow/core/grappler/optimizers/data/auto_shard.cc) 310 super(_LegacyRebatchDataset, self).__init__(input_dataset, variant_tensor) 311 312 @property 313 def element_spec(self): 314 return self._element_spec 315 316 317class _RemoteDataset(dataset_ops.DatasetSource): 318 """Creates a dataset on a given `device` given a graph def.""" 319 320 def __init__(self, graph_def, device, element_spec): 321 self._elem_spec = element_spec 322 with ops.device(device): 323 variant_tensor = ged_ops.dataset_from_graph(graph_def) 324 super(_RemoteDataset, self).__init__(variant_tensor) 325 326 @property 327 def element_spec(self): 328 return self._elem_spec 329 330 331def replicate(dataset, devices): 332 """A transformation that replicates `dataset` onto a list of devices. 333 334 Args: 335 dataset: A `tf.data.Dataset` object. 336 devices: A list of devices to replicate the dataset on. 337 338 Returns: 339 A dictionary mapping device name to a dataset on that device. 340 """ 341 if not isinstance(dataset, dataset_ops.DatasetV2): 342 raise TypeError( 343 f"Invalid `dataset`. Expected a `tf.data.Dataset` object but " 344 f"got {type(dataset)}.") 345 346 # pylint: disable=protected-access 347 dataset_device = dataset._variant_tensor.device 348 349 datasets = {} 350 if len(devices) == 1 and devices[0] == dataset_device: 351 datasets[devices[0]] = dataset 352 return datasets 353 354 with ops.colocate_with(dataset._variant_tensor): 355 dataset = dataset._apply_debug_options() 356 graph_def = dataset._as_serialized_graph( 357 strip_device_assignment=True, 358 external_state_policy=ExternalStatePolicy.WARN) 359 for device in devices: 360 ds = _RemoteDataset(graph_def, device, dataset.element_spec) 361 datasets[device] = ds 362 return datasets 363 364 365def batch_sizes_for_worker(global_batch_size, num_workers, 366 num_replicas_per_worker, worker_index): 367 """Determines how to rebatch a dataset for the given worker. 368 369 Given the global batch size, number of workers, number of replicas per worker, 370 and worker index, returns the correct batch sizes for rebatching a dataset 371 on worker `worker_index` of `num_workers`, such that each global step (across 372 all workers and replicas) will consume global_batch_size elements. The 373 returned value should be passed as the `batch_sizes` input parameter to 374 `tf.data.experimental.rebatch()`. The returned batch sizes meet the following 375 constraints: 376 377 Let G = global_batch_size, W = num_workers, R = num_replicas_per_worker 378 (A) for any worker, len(batch_sizes) = W * R 379 (B) for any worker, sum(batch_sizes) == G 380 (C) for any global step (i.e. R iterations on each worker), the sum of batches 381 consumed by replicas across all workers is G. 382 (D) any two batch sizes of any two replicas differs by at most one. 383 384 For example, suppose we have G = 7, W = 2, R = 2, and suppose we have two 385 files which each contain 7 elements: 386 387 ```python 388 # WORKER 0 389 batch_sizes_0 = batch_sizes_for_worker(global_batch_size=global_batch_size, 390 num_workers=2, 391 num_replicas_per_worker=2, 392 worker_index=0) 393 print(batch_sizes_0) 394 >> [2, 2, 2, 1] 395 396 dataset_0 = tf.data.Dataset.from_tensor_slices(["file_a", "file_b"]) 397 dataset_0 = dataset_0.shard(num_shards, index=0) 398 dataset_0 = dataset_0.batch(7) 399 dataset_0 = dataset_0.apply(tf.data.experimental.rebatch(batch_sizes_0)) 400 for elem in dataset_0: 401 print(elem) 402 >> [[A0, A1], [A2, A3], [A4, A5], [A6]] 403 404 # WORKER 1 405 batch_sizes_1 = batch_sizes_for_worker(global_batch_size=global_batch_size, 406 num_workers=2, 407 num_replicas_per_worker=2, 408 worker_index=1) 409 print(batch_sizes_1) 410 >> [2, 1, 2, 2] 411 412 dataset_1 = tf.data.Dataset.from_tensor_slices(["file_a", "file_b"]) 413 dataset_1 = dataset_1.shard(num_shards, index=1) 414 dataset_1 = dataset_1.batch(7) 415 dataset_1 = dataset_1.apply(tf.data.experimental.rebatch(batch_sizes_1)) 416 for elem in dataset_1: 417 print(elem) 418 >> [[B0, B1], [B2], [B3, B4], [B5, B6]] 419 ``` 420 421 The above example will produce the following elements: 422 423 Step 1: 424 Worker 0 Replica 0: [A0, A1] 425 Worker 0 Replica 1: [A2, A3] 426 Worker 1 Replica 0: [B0, B1] 427 Worker 1 Replica 1: [B2] 428 Total batch size = 7 429 430 Step 2: 431 Worker 0 Replica 0: [A4, A5] 432 Worker 0 Replica 1: [A6] 433 Worker 1 Replica 0: [B3, B4] 434 Worker 1 Replica 1: [B5, B6] 435 Total batch size = 7 436 437 Args: 438 global_batch_size: A `tf.int64` scalar, representing the global batch size. 439 num_workers: An integer representing the number of workers the dataset will 440 be distributed across. 441 num_replicas_per_worker: An integer representing the number of replicas per 442 worker. All workers are assumed to have the same number of replicas. 443 worker_index: An integer index of the worker to be rebatched. 444 445 Returns: 446 A `tf.int64` vector, representing the batch sizes to rebatch the dataset 447 into. 448 """ 449 # Constraint (A) 450 num_subbatches = num_workers * num_replicas_per_worker 451 452 offset = worker_index * num_replicas_per_worker 453 454 const_value = tensor_util.constant_value(global_batch_size) 455 if const_value is not None: 456 # Use the constant global batch size for further calculations 457 global_batch_size = const_value 458 459 # Let N = W * R. Constraint (B) and (D) jointly mean that the iterations 460 # should have batch size either floor(B/N) or ceil(B/N). Namely, of the N 461 # subbatches a batch is split into, B - N * floor(B/N) of them will have size 462 # ceil(B/N), and the rest will have size floor(B/N). 463 floor = global_batch_size // num_subbatches 464 num_ceil = global_batch_size - (num_subbatches * floor) 465 466 # For worker 0, we assign the first num_ceil subbatches to have size 467 # ceil(B/N), and the remainder to have size floor(B/N). The other workers will 468 # each be offset by R * worker_index in order to meet constraint (C). 469 if const_value is not None: 470 # If the global batch size is a known constant value, we return a constant 471 # tensor directly instead of manipulating it with TF ops. This allows for 472 # better downstream shape inference. 473 worker_0 = [floor + 1] * num_ceil + [floor] * (num_subbatches - num_ceil) 474 return ops.convert_to_tensor( 475 worker_0[offset:] + worker_0[:offset], 476 dtype=dtypes.int64, 477 name="batch_sizes") 478 479 worker_0 = array_ops.ones(num_subbatches, dtype=dtypes.int64) 480 worker_0 = floor * worker_0 + array_ops.concat([ 481 array_ops.ones(num_ceil, dtype=dtypes.int64), 482 array_ops.zeros(num_subbatches - num_ceil, dtype=dtypes.int64) 483 ], 484 axis=0) 485 486 return array_ops.concat([worker_0[offset:], worker_0[:offset]], axis=0) 487 488 489def compute_batch_size(dataset): 490 """An operation that returns the batch size of the dataset. 491 492 This op tries to infer the batch size statically by walking up the dataset 493 tree from the final dataset node and returning the batch size of the first 494 batching dataset (such as from .batch() and .padded_batch()) that it 495 encounters. This differs from using the `element_spec` of a dataset in that it 496 does not account for partial batches. 497 498 This operation may fail if it encounters contradictory batch sizes (for 499 example, if the dataset is created by zipping together two datasets with 500 different batch sizes), if there are no explicit batching transformations, or 501 if there are operations downstream from the batching transformation that may 502 modify its batch size. In these cases, it returns a -1. 503 504 Args: 505 dataset: A `tf.data.Dataset` object. 506 507 Returns: 508 A `tf.int64` Tensor representing the batch size of the dataset sans partial 509 batches. If this cannot be inferred statically, the value of this tensor 510 will be -1. 511 """ 512 513 def get_static_batch_dim(type_spec): 514 try: 515 output_shape = type_spec._to_legacy_output_shapes() # pylint: disable=protected-access 516 except NotImplementedError: 517 return None 518 if not isinstance(output_shape, tensor_shape.TensorShape): 519 return None 520 if output_shape.rank is None: 521 return None 522 return output_shape.dims[0].value 523 524 batch_dims = [ 525 get_static_batch_dim(type_spec) 526 for type_spec in nest.flatten(dataset_ops.get_structure(dataset)) 527 ] 528 529 if all(d is not None for d in batch_dims): 530 531 if all(d == batch_dims[0] for d in batch_dims): 532 # If all batch dimensions are known and equal, return that directly. 533 batch_dim = batch_dims[0] 534 else: 535 # If all batch dimensions are known but not all equal, return -1. 536 batch_dim = -1 537 538 return constant_op.constant( 539 batch_dim, dtype=dtypes.int64, name="static_batch_size") 540 541 # If any batch dimensions are unknown, use compute_batch_size op. 542 return ged_ops.compute_batch_size(dataset._variant_tensor) # pylint: disable=protected-access 543 544 545_AutoShardDatasetV1.__doc__ = _AutoShardDataset.__doc__ 546