1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Various classes representing distributed inputs.""" 16 17import functools 18import sys 19import time 20 21import six 22 23from tensorflow.python.data.experimental.ops import batching 24from tensorflow.python.data.experimental.ops import cardinality as cardinality_lib 25from tensorflow.python.data.experimental.ops import distribute 26from tensorflow.python.data.ops import dataset_ops 27from tensorflow.python.data.ops import iterator_ops 28from tensorflow.python.data.ops import multi_device_iterator_ops 29from tensorflow.python.data.ops import optional_ops 30from tensorflow.python.distribute import device_util 31from tensorflow.python.distribute import distribute_lib 32from tensorflow.python.distribute import distribute_utils 33from tensorflow.python.distribute import distribution_strategy_context 34from tensorflow.python.distribute import input_ops 35from tensorflow.python.distribute import reduce_util 36from tensorflow.python.distribute import values 37from tensorflow.python.distribute.distribute_lib import InputReplicationMode 38from tensorflow.python.eager import context 39from tensorflow.python.eager import monitoring 40from tensorflow.python.framework import composite_tensor 41from tensorflow.python.framework import device as tf_device 42from tensorflow.python.framework import dtypes 43from tensorflow.python.framework import errors 44from tensorflow.python.framework import ops 45from tensorflow.python.framework import sparse_tensor 46from tensorflow.python.framework import tensor_shape 47from tensorflow.python.framework import tensor_util 48from tensorflow.python.framework import type_spec 49from tensorflow.python.ops import array_ops 50from tensorflow.python.ops import control_flow_ops 51from tensorflow.python.ops import math_ops 52from tensorflow.python.ops.ragged import ragged_tensor 53from tensorflow.python.platform import tf_logging as logging 54from tensorflow.python.types import distribute as distribute_types 55from tensorflow.python.util import nest 56from tensorflow.python.util.compat import collections_abc 57from tensorflow.python.util.tf_export import tf_export 58from tensorflow.tools.docs import doc_controls 59 60 61_distributed_dataset_initialization_time_milliseconds = monitoring.Sampler( 62 "/tensorflow/api/distribution_strategy/" 63 "distributed_dataset_initialization_time_milliseconds", 64 monitoring.ExponentialBuckets(scale=1, growth_factor=2, bucket_count=26), 65 "Track the time (in milliseconds) to initialize distributed datasets.", 66 "strategy", "workers") 67 68_distributed_dataset_from_function_initialization_time_milliseconds = ( 69 monitoring.Sampler( 70 "/tensorflow/api/distribution_strategy/" 71 "distributed_dataset_from_function_initialization_time_milliseconds", 72 monitoring.ExponentialBuckets( 73 scale=1, growth_factor=2, bucket_count=26), 74 "Track the time 
(in milliseconds) to initialize distributed datasets " 75 "from function.", 76 "strategy", "workers")) 77 78 79def get_iterator_spec_from_dataset(strategy, dataset): 80 """Returns an iterator spec from dataset function. 81 82 This function constructs type spec for iterator obtained from 83 iter(dataset). 84 85 Args: 86 strategy: a `tf.distribute.Strategy` object, used to run all-reduce to 87 handle last partial batch. 88 dataset: A tf.data.Dataset instance. If using a function that returns a 89 tf.data.Dataset instance, pass dataset_fn.structured_outputs. 90 91 Returns: 92 A type_spec for iterator for dataset instance. 93 94 """ 95 # pylint: disable=protected-access 96 output_element_spec = dataset.element_spec 97 if isinstance(dataset._type_spec, 98 (DistributedDatasetSpec, 99 DistributedDatasetsFromFunctionSpec)): 100 iterator_type_spec = DistributedIteratorSpec( 101 strategy.extended._input_workers_with_options(), 102 output_element_spec, 103 strategy.extended._container_strategy(), 104 options=None, 105 cardinality=dataset.cardinality, 106 enable_get_next_as_optional=True) 107 else: 108 if strategy.extended._num_gpus_per_worker: 109 logging.warning( 110 f"{strategy.extended._num_gpus_per_worker} GPUs " 111 "are allocated per worker. Please use DistributedDataset by " 112 "calling strategy.experimental_distribute_dataset or strategy." 113 "distribute_datasets_from_function to make best use of GPU " 114 "resources" 115 ) 116 iterator_type_spec = iterator_ops.IteratorSpec(output_element_spec) 117 return iterator_type_spec 118 # pylint: enable=protected-access 119 120 121@tf_export("distribute.DistributedIterator", v1=[]) 122class DistributedIteratorInterface(collections_abc.Iterator, 123 distribute_types.Iterator): 124 """An iterator over `tf.distribute.DistributedDataset`. 125 126 `tf.distribute.DistributedIterator` is the primary mechanism for enumerating 127 elements of a `tf.distribute.DistributedDataset`. It supports the Python 128 Iterator protocol, which means it can be iterated over using a for-loop or by 129 fetching individual elements explicitly via `get_next()`. 130 131 You can create a `tf.distribute.DistributedIterator` by calling `iter` on 132 a `tf.distribute.DistributedDataset` or creating a python loop over a 133 `tf.distribute.DistributedDataset`. 134 135 Visit the [tutorial](https://www.tensorflow.org/tutorials/distribute/input) 136 on distributed input for more examples and caveats. 137 """ 138 139 def get_next(self): 140 """Returns the next input from the iterator for all replicas. 141 142 Example use: 143 144 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 145 >>> dataset = tf.data.Dataset.range(100).batch(2) 146 >>> dist_dataset = strategy.experimental_distribute_dataset(dataset) 147 >>> dist_dataset_iterator = iter(dist_dataset) 148 >>> @tf.function 149 ... def one_step(input): 150 ... return input 151 >>> step_num = 5 152 >>> for _ in range(step_num): 153 ... strategy.run(one_step, args=(dist_dataset_iterator.get_next(),)) 154 >>> strategy.experimental_local_results(dist_dataset_iterator.get_next()) 155 (<tf.Tensor: shape=(1,), dtype=int64, numpy=array([10])>, 156 <tf.Tensor: shape=(1,), dtype=int64, numpy=array([11])>) 157 158 Returns: 159 A single `tf.Tensor` or a `tf.distribute.DistributedValues` which contains 160 the next input for all replicas. 161 162 Raises: 163 `tf.errors.OutOfRangeError`: If the end of the iterator has been reached. 
164 """ 165 raise NotImplementedError( 166 "DistributedIterator.get_next() must be implemented in descendants.") 167 168 @property 169 def element_spec(self): 170 # pylint: disable=line-too-long 171 """The type specification of an element of `tf.distribute.DistributedIterator`. 172 173 Example usage: 174 175 >>> global_batch_size = 16 176 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 177 >>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size) 178 >>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset)) 179 >>> distributed_iterator.element_spec 180 (PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None), 181 TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)), 182 PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None), 183 TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))) 184 185 Returns: 186 A nested structure of `tf.TypeSpec` objects matching the structure of an 187 element of this `tf.distribute.DistributedIterator`. This returned value 188 is typically a `tf.distribute.DistributedValues` object and specifies the 189 `tf.TensorSpec` of individual components. 190 """ 191 raise NotImplementedError( 192 "DistributedIterator.element_spec() must be implemented in descendants") 193 194 def get_next_as_optional(self): 195 # pylint: disable=line-too-long 196 """Returns a `tf.experimental.Optional` that contains the next value for all replicas. 197 198 If the `tf.distribute.DistributedIterator` has reached the end of the 199 sequence, the returned `tf.experimental.Optional` will have no value. 200 201 Example usage: 202 203 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 204 >>> global_batch_size = 2 205 >>> steps_per_loop = 2 206 >>> dataset = tf.data.Dataset.range(10).batch(global_batch_size) 207 >>> distributed_iterator = iter( 208 ... strategy.experimental_distribute_dataset(dataset)) 209 >>> def step_fn(x): 210 ... # train the model with inputs 211 ... return x 212 >>> @tf.function 213 ... def train_fn(distributed_iterator): 214 ... for _ in tf.range(steps_per_loop): 215 ... optional_data = distributed_iterator.get_next_as_optional() 216 ... if not optional_data.has_value(): 217 ... break 218 ... per_replica_results = strategy.run(step_fn, args=(optional_data.get_value(),)) 219 ... tf.print(strategy.experimental_local_results(per_replica_results)) 220 >>> train_fn(distributed_iterator) 221 ... # ([0 1], [2 3]) 222 ... # ([4], []) 223 224 Returns: 225 An `tf.experimental.Optional` object representing the next value from the 226 `tf.distribute.DistributedIterator` (if it has one) or no value. 227 """ 228 # pylint: enable=line-too-long 229 raise NotImplementedError( 230 "get_next_as_optional() not implemented in descendants") 231 232 233@tf_export("distribute.DistributedDataset", v1=[]) 234class DistributedDatasetInterface(collections_abc.Iterable, 235 distribute_types.Iterable): 236 # pylint: disable=line-too-long 237 """Represents a dataset distributed among devices and machines. 238 239 A `tf.distribute.DistributedDataset` could be thought of as a "distributed" 240 dataset. When you use `tf.distribute` API to scale training to multiple 241 devices or machines, you also need to distribute the input data, which leads 242 to a `tf.distribute.DistributedDataset` instance, instead of a 243 `tf.data.Dataset` instance in the non-distributed case. In TF 2.x, 244 `tf.distribute.DistributedDataset` objects are Python iterables. 
245 246 Note: `tf.distribute.DistributedDataset` instances are *not* of type 247 `tf.data.Dataset`. It only supports two usages we will mention below: 248 iteration and `element_spec`. We don't support any other APIs to transform or 249 inspect the dataset. 250 251 There are two APIs to create a `tf.distribute.DistributedDataset` object: 252 `tf.distribute.Strategy.experimental_distribute_dataset(dataset)`and 253 `tf.distribute.Strategy.distribute_datasets_from_function(dataset_fn)`. 254 *When to use which?* When you have a `tf.data.Dataset` instance, and the 255 regular batch splitting (i.e. re-batch the input `tf.data.Dataset` instance 256 with a new batch size that is equal to the global batch size divided by the 257 number of replicas in sync) and autosharding (i.e. the 258 `tf.data.experimental.AutoShardPolicy` options) work for you, use the former 259 API. Otherwise, if you are *not* using a canonical `tf.data.Dataset` instance, 260 or you would like to customize the batch splitting or sharding, you can wrap 261 these logic in a `dataset_fn` and use the latter API. Both API handles 262 prefetch to device for the user. For more details and examples, follow the 263 links to the APIs. 264 265 266 There are two main usages of a `DistributedDataset` object: 267 268 1. Iterate over it to generate the input for a single device or multiple 269 devices, which is a `tf.distribute.DistributedValues` instance. To do this, 270 you can: 271 272 * use a pythonic for-loop construct: 273 274 >>> global_batch_size = 4 275 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 276 >>> dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(4).batch(global_batch_size) 277 >>> dist_dataset = strategy.experimental_distribute_dataset(dataset) 278 >>> @tf.function 279 ... def train_step(input): 280 ... features, labels = input 281 ... return labels - 0.3 * features 282 >>> for x in dist_dataset: 283 ... # train_step trains the model using the dataset elements 284 ... loss = strategy.run(train_step, args=(x,)) 285 ... print("Loss is", loss) 286 Loss is PerReplica:{ 287 0: tf.Tensor( 288 [[0.7] 289 [0.7]], shape=(2, 1), dtype=float32), 290 1: tf.Tensor( 291 [[0.7] 292 [0.7]], shape=(2, 1), dtype=float32) 293 } 294 295 Placing the loop inside a `tf.function` will give a performance boost. 296 However `break` and `return` are currently not supported if the loop is 297 placed inside a `tf.function`. We also don't support placing the loop 298 inside a `tf.function` when using 299 `tf.distribute.experimental.MultiWorkerMirroredStrategy` or 300 `tf.distribute.experimental.TPUStrategy` with multiple workers. 301 302 * use `__iter__` to create an explicit iterator, which is of type 303 `tf.distribute.DistributedIterator` 304 305 >>> global_batch_size = 4 306 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 307 >>> train_dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(50).batch(global_batch_size) 308 >>> train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset) 309 >>> @tf.function 310 ... def distributed_train_step(dataset_inputs): 311 ... def train_step(input): 312 ... loss = tf.constant(0.1) 313 ... return loss 314 ... per_replica_losses = strategy.run(train_step, args=(dataset_inputs,)) 315 ... return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,axis=None) 316 >>> EPOCHS = 2 317 >>> STEPS = 3 318 >>> for epoch in range(EPOCHS): 319 ... total_loss = 0.0 320 ... num_batches = 0 321 ... 
dist_dataset_iterator = iter(train_dist_dataset) 322 ... for _ in range(STEPS): 323 ... total_loss += distributed_train_step(next(dist_dataset_iterator)) 324 ... num_batches += 1 325 ... average_train_loss = total_loss / num_batches 326 ... template = ("Epoch {}, Loss: {:.4f}") 327 ... print (template.format(epoch+1, average_train_loss)) 328 Epoch 1, Loss: 0.2000 329 Epoch 2, Loss: 0.2000 330 331 332 To achieve a performance improvement, you can also wrap the `strategy.run` 333 call with a `tf.range` inside a `tf.function`. This runs multiple steps in a 334 `tf.function`. Autograph will convert it to a `tf.while_loop` on the worker. 335 However, it is less flexible comparing with running a single step inside 336 `tf.function`. For example, you cannot run things eagerly or arbitrary 337 python code within the steps. 338 339 340 2. Inspect the `tf.TypeSpec` of the data generated by `DistributedDataset`. 341 342 `tf.distribute.DistributedDataset` generates 343 `tf.distribute.DistributedValues` as input to the devices. If you pass the 344 input to a `tf.function` and would like to specify the shape and type of 345 each Tensor argument to the function, you can pass a `tf.TypeSpec` object to 346 the `input_signature` argument of the `tf.function`. To get the 347 `tf.TypeSpec` of the input, you can use the `element_spec` property of the 348 `tf.distribute.DistributedDataset` or `tf.distribute.DistributedIterator` 349 object. 350 351 For example: 352 353 >>> global_batch_size = 4 354 >>> epochs = 1 355 >>> steps_per_epoch = 1 356 >>> mirrored_strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 357 >>> dataset = tf.data.Dataset.from_tensors(([2.])).repeat(100).batch(global_batch_size) 358 >>> dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset) 359 >>> @tf.function(input_signature=[dist_dataset.element_spec]) 360 ... def train_step(per_replica_inputs): 361 ... def step_fn(inputs): 362 ... return tf.square(inputs) 363 ... return mirrored_strategy.run(step_fn, args=(per_replica_inputs,)) 364 >>> for _ in range(epochs): 365 ... iterator = iter(dist_dataset) 366 ... for _ in range(steps_per_epoch): 367 ... output = train_step(next(iterator)) 368 ... print(output) 369 PerReplica:{ 370 0: tf.Tensor( 371 [[4.] 372 [4.]], shape=(2, 1), dtype=float32), 373 1: tf.Tensor( 374 [[4.] 375 [4.]], shape=(2, 1), dtype=float32) 376 } 377 378 379 Visit the [tutorial](https://www.tensorflow.org/tutorials/distribute/input) 380 on distributed input for more examples and caveats. 381 """ 382 383 def __iter__(self): 384 """Creates an iterator for the `tf.distribute.DistributedDataset`. 385 386 The returned iterator implements the Python Iterator protocol. 387 388 Example usage: 389 390 >>> global_batch_size = 4 391 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 392 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).repeat().batch(global_batch_size) 393 >>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset)) 394 >>> print(next(distributed_iterator)) 395 PerReplica:{ 396 0: tf.Tensor([1 2], shape=(2,), dtype=int32), 397 1: tf.Tensor([3 4], shape=(2,), dtype=int32) 398 } 399 400 Returns: 401 An `tf.distribute.DistributedIterator` instance for the given 402 `tf.distribute.DistributedDataset` object to enumerate over the 403 distributed data. 
404 """ 405 raise NotImplementedError("Must be implemented in descendants") 406 407 @property 408 def element_spec(self): 409 """The type specification of an element of this `tf.distribute.DistributedDataset`. 410 411 Example usage: 412 413 >>> global_batch_size = 16 414 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 415 >>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size) 416 >>> dist_dataset = strategy.experimental_distribute_dataset(dataset) 417 >>> dist_dataset.element_spec 418 (PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None), 419 TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)), 420 PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None), 421 TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))) 422 423 Returns: 424 A nested structure of `tf.TypeSpec` objects matching the structure of an 425 element of this `tf.distribute.DistributedDataset`. This returned value is 426 typically a `tf.distribute.DistributedValues` object and specifies the 427 `tf.TensorSpec` of individual components. 428 """ 429 raise NotImplementedError( 430 "DistributedDataset.element_spec must be implemented in descendants.") 431 432 @doc_controls.do_not_generate_docs 433 def reduce(self, initial_state, reduce_func): 434 raise NotImplementedError( 435 "DistributedDataset.reduce must be implemented in descendants.") 436 437 438class InputWorkers(object): 439 """A 1-to-many mapping from input worker devices to compute devices.""" 440 441 # TODO(ishark): Remove option canonicalize_devices and make all the callers 442 # pass canonicalized or raw device strings as relevant from strategy. 443 def __init__(self, 444 worker_device_pairs, 445 canonicalize_devices=True): 446 """Initialize an `InputWorkers` object. 447 448 Args: 449 worker_device_pairs: A sequence of pairs: `(input device, a tuple of 450 compute devices fed by that input device)`. 451 canonicalize_devices: Whether to canonicalize devices for workers fully or 452 partially. If False, it will partially canonicalize devices by removing 453 job and task. 454 """ 455 self._worker_device_pairs = worker_device_pairs 456 self._input_worker_devices = tuple(d for d, _ in self._worker_device_pairs) 457 self._canonicalize_devices = canonicalize_devices 458 if canonicalize_devices: 459 self._fed_devices = tuple( 460 tuple(device_util.canonicalize(d) 461 for d in f) 462 for _, f in self._worker_device_pairs) 463 else: 464 self._fed_devices = tuple( 465 tuple(device_util.canonicalize_without_job_and_task(d) 466 for d in f) 467 for _, f in self._worker_device_pairs) 468 469 @property 470 def num_workers(self): 471 return len(self._input_worker_devices) 472 473 @property 474 def worker_devices(self): 475 return self._input_worker_devices 476 477 def compute_devices_for_worker(self, worker_index): 478 return self._fed_devices[worker_index] 479 480 def __repr__(self): 481 devices = self.worker_devices 482 debug_repr = ",\n".join(" %d %s: %s" % 483 (i, devices[i], self._fed_devices[i]) 484 for i in range(len(devices))) 485 return "%s:{\n%s}" % (self.__class__.__name__, debug_repr) 486 487 def serialize(self): 488 return (self._worker_device_pairs, self._canonicalize_devices) 489 490 def deserialize(self, serialized): 491 return InputWorkers(serialized) 492 493 494def _calculate_replicas_with_values(strategy, input_workers, optional_list): 495 """Calcualates the number of replicas that have values. 496 497 Args: 498 strategy: the `tf.distribute.Strategy`. 
499 input_workers: the `InputWorkers`. 500 optional_list: a list of lists `tf.experimental.Optional`. The values from 501 each compute device grouped by the input device. 502 503 Returns: 504 A scalar Tensor. 505 """ 506 worker_has_values = [] 507 for worker, optionals in zip(input_workers.worker_devices, optional_list): 508 with ops.device(worker): 509 device_has_values = [ 510 math_ops.cast(v.has_value(), dtypes.int64) for v in optionals 511 ] 512 worker_has_values.append( 513 math_ops.reduce_sum(device_has_values, keepdims=True)) 514 client_has_values = math_ops.reduce_sum(worker_has_values, keepdims=True) 515 if strategy.extended._in_multi_worker_mode(): # pylint: disable=protected-access 516 global_has_values = strategy.reduce( 517 reduce_util.ReduceOp.SUM, client_has_values, axis=None) 518 return array_ops.reshape(global_has_values, []) 519 else: 520 return array_ops.reshape(client_has_values, []) 521 522 523def _is_statically_shaped(element_spec): 524 """Test if an iterator output is statically shaped. 525 526 For sparse and ragged tensors this only tests the batch dimension. 527 528 Args: 529 element_spec: a nest structure of `tf.TypeSpec`. The element spec of the 530 dataset of the iterator. 531 532 Returns: 533 True if the shape is static, false otherwise. 534 """ 535 536 for spec in nest.flatten(element_spec): 537 if isinstance( 538 spec, (sparse_tensor.SparseTensorSpec, ragged_tensor.RaggedTensorSpec)): 539 # For sparse or ragged tensor, we should only check the first 540 # dimension in order to get_next_as_optional. This is because 541 # when these tensors get batched by dataset only the batch dimension 542 # is set. 543 if spec.shape.rank > 0 and spec.shape.as_list()[0] is None: 544 return False 545 else: 546 for component in spec._flat_tensor_specs: # pylint: disable=protected-access 547 if not component.shape.is_fully_defined(): 548 return False 549 return True 550 551 552class DistributedIteratorBase(DistributedIteratorInterface): 553 """Common implementation for all input iterators.""" 554 555 # pylint: disable=super-init-not-called 556 def __init__(self, input_workers, iterators, strategy, cardinality, 557 enable_get_next_as_optional): 558 assert isinstance(input_workers, InputWorkers) 559 if not input_workers.worker_devices: 560 raise ValueError("Should have at least one worker for input iterator.") 561 562 self._iterators = iterators 563 self._input_workers = input_workers 564 self._strategy = strategy 565 self._cardinality = cardinality 566 self._enable_get_next_as_optional = enable_get_next_as_optional 567 568 def next(self): 569 return self.__next__() 570 571 def __next__(self): 572 try: 573 return self.get_next() 574 except errors.OutOfRangeError: 575 raise StopIteration 576 577 def __iter__(self): 578 return self 579 580 def get_next_as_optional(self): 581 # Ideally get_next_as_optional() should be consistent with get_next(), but 582 # we used to always do partial batch handling in get_next_as_optional(). We 583 # are keeping this behavior for now until we understantd the impact. 584 585 # Skip partial batch handling when the dataset is infinite or empty, as 586 # there won't be any partial batches in those cases. This gives the user 587 # more static shapes as it avoids the tf.cond. Note that for empty datasets, 588 # we can only skip in single client mode, as the dataset can be non-empty on 589 # other workers. 
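    # In summary, the code below takes one of three paths: an INFINITE dataset
    # wraps get_next() directly in an Optional; a dataset known to be empty in
    # single-client mode returns an empty Optional immediately; otherwise a
    # cond substitutes dummy values on replicas that have run out of data, as
    # long as at least one replica (possibly on another worker) still has one.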
590 if self._cardinality == cardinality_lib.INFINITE: 591 return optional_ops.Optional.from_value( 592 self._get_next_no_partial_batch_handling()) 593 if (self._cardinality == 0 and 594 not self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access 595 return optional_ops.Optional.empty(self._element_spec) 596 597 optional_list = [] 598 for i, worker in enumerate(self._input_workers.worker_devices): 599 with ops.device(worker): 600 optional_list.append(self._iterators[i].get_next_as_optional_list()) 601 602 def _create_optional_with_dummy(): 603 value_list = _get_value_or_dummy( 604 self._input_workers, optional_list, produce_dummy=True) 605 per_replica = _create_per_replica(value_list, self._strategy) 606 return optional_ops.Optional.from_value(per_replica) 607 608 def _create_empty_optional(): 609 return optional_ops.Optional.empty(self._element_spec) 610 611 num_replicas_with_values = _calculate_replicas_with_values( 612 self._strategy, self._input_workers, optional_list) 613 614 return control_flow_ops.cond( 615 num_replicas_with_values > 0, 616 _create_optional_with_dummy, 617 _create_empty_optional, 618 strict=True) 619 620 def get_next(self, name=None): 621 """Returns the next input from the iterator for all replicas.""" 622 with distribution_strategy_context.enter_or_assert_strategy( 623 self._strategy): 624 if distribution_strategy_context.get_replica_context() is not None: 625 raise ValueError("next(iterator) should be called from outside of " 626 "replica_fn. e.g. strategy.run(replica_fn, " 627 "args=(next(iterator),))") 628 629 if not self._enable_get_next_as_optional: 630 return self._get_next_no_partial_batch_handling(name) 631 632 optional_list = [] 633 for i, worker in enumerate(self._input_workers.worker_devices): 634 with ops.device(worker): 635 optional_list.append(self._iterators[i].get_next_as_optional_list()) 636 num_replicas_with_values = _calculate_replicas_with_values( 637 self._strategy, self._input_workers, optional_list) 638 639 def _value_or_dummy(): 640 value_list = _get_value_or_dummy( 641 self._input_workers, optional_list, produce_dummy=True) 642 return _create_per_replica(value_list, self._strategy) 643 644 def _eof(): 645 # Optional.get_value raises InvalidArgumentError when there's no value, 646 # so we need to call GetNext to raise EOFError. 647 return self._get_next_no_partial_batch_handling() 648 649 return control_flow_ops.cond( 650 num_replicas_with_values > 0, _value_or_dummy, _eof, strict=True) 651 652 def _get_next_no_partial_batch_handling(self, name=None): 653 replicas = [] 654 for i, worker in enumerate(self._input_workers.worker_devices): 655 if name is not None: 656 d = tf_device.DeviceSpec.from_string(worker) 657 new_name = "%s_%s_%d" % (name, d.job, d.task) 658 else: 659 new_name = None 660 with ops.device(worker): 661 # Make `replicas` a flat list of values across all replicas. 
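        # Each per-worker iterator yields one value per compute device on that
        # worker, so the flat list ends up ordered by input worker and then by
        # device; _create_per_replica below relies on this ordering.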
662 replicas.extend(self._iterators[i].get_next_as_list(new_name)) 663 return _create_per_replica(replicas, self._strategy) 664 665 666class DistributedDatasetAndIteratorSpec(type_spec.TypeSpec): 667 """Common Type specification for `DistributedDataset and DistributedDatasetsFromFunction.""" 668 669 __slots__ = [ 670 "_input_workers", "_element_spec", "_strategy", "_cardinality", 671 "_enable_get_next_as_optional", "_options", "_canonicalize_devices" 672 ] 673 674 def __init__(self, 675 input_workers, 676 element_spec, 677 strategy, 678 options, 679 cardinality=cardinality_lib.UNKNOWN, 680 enable_get_next_as_optional=None): 681 # We don't want to allow deserialization of this class because we don't 682 # serialize the strategy object. Currently the only places where 683 # _deserialize is called is when we save/restore using SavedModels. 684 if isinstance(input_workers, tuple): 685 raise NotImplementedError("DistributedIteratorSpec does not have support " 686 "for deserialization.") 687 else: 688 self._input_workers = input_workers 689 self._element_spec = element_spec 690 self._strategy = strategy 691 self._cardinality = cardinality 692 self._enable_get_next_as_optional = enable_get_next_as_optional 693 self._options = options 694 if self._strategy: 695 self._canonicalize_devices = getattr(self._strategy, 696 "_canonicalize_devices", True) 697 else: 698 self._canonicalize_devices = True 699 700 def _serialize(self): 701 # We cannot serialize the strategy object so we convert it to an id that we 702 # can use for comparison. 703 return (self._input_workers.serialize(), self._element_spec, 704 id(self._strategy), id(self._options)) 705 706 def _deserialize(self): 707 raise ValueError( 708 f"Deserialization is currently unsupported for {type(self)}.") 709 710 def sanity_check_type(self, other): 711 """Returns the most specific TypeSpec compatible with `self` and `other`. 712 713 Args: 714 other: A `TypeSpec`. 715 716 Raises: 717 ValueError: If there is no TypeSpec that is compatible with both `self` 718 and `other`. 719 """ 720 # pylint: disable=protected-access 721 if type(self) is not type(other): 722 raise ValueError("No TypeSpec is compatible with both %s and %s" % 723 (self, other)) 724 if self._input_workers.serialize() != other._input_workers.serialize(): 725 raise ValueError("_input_workers is not compatible with both %s " 726 "and %s" % (self, other)) 727 if self._strategy is not other._strategy: 728 raise ValueError("tf.distribute strategy is not compatible with both %s " 729 "and %s" % (self, other)) 730 731 def is_subtype_of(self, other): 732 """Returns True if `self` is subtype of `other`. 733 734 Args: 735 other: A `TypeSpec`. 736 """ 737 try: 738 self.sanity_check_type(other) 739 nest.assert_same_structure(self._element_spec, other._element_spec) # pylint: disable=protected-access 740 except (TypeError, ValueError): 741 return False 742 743 self_elements = nest.flatten(self._element_spec) 744 other_elements = nest.flatten(other._element_spec) # pylint: disable=protected-access 745 746 return all( 747 self_element.is_subtype_of(other_element) 748 for (self_element, other_element) in zip(self_elements, other_elements)) 749 750 def most_specific_common_supertype(self, others): 751 """Returns the most specific supertype of `self` and `others`. 752 753 Args: 754 others: A Sequence of `TypeSpec`. 755 756 Returns `None` if a supertype does not exist. 
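    For example (specs are illustrative): two otherwise identical specs whose
    element specs are `TensorSpec([2, 1])` and `TensorSpec([4, 1])` have a
    common supertype carrying `TensorSpec([None, 1])`, provided both share the
    same input workers and strategy.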
757 """ 758 try: 759 for other in others: 760 self.sanity_check_type(other) 761 nest.assert_same_structure(self._element_spec, other._element_spec) # pylint: disable=protected-access 762 except (TypeError, ValueError): 763 return None 764 765 self_elements = nest.flatten(self._element_spec) 766 others_elements = [nest.flatten(other._element_spec) for other in others] # pylint: disable=protected-access 767 common_elements = [None] * len(self_elements) 768 769 for i, self_element in enumerate(self_elements): 770 common_elements[i] = self_element.most_specific_common_supertype( 771 [other_elements[i] for other_elements in others_elements]) 772 if common_elements[i] is None: 773 return None 774 common_element_spec = nest.pack_sequence_as(self._element_spec, 775 common_elements) 776 return type(self)( 777 self._input_workers, 778 common_element_spec, 779 self._strategy, 780 self._options, 781 cardinality=self._cardinality, 782 enable_get_next_as_optional=self._enable_get_next_as_optional) 783 784 def _with_tensor_ranks_only(self): 785 element_spec = nest.map_structure( 786 lambda s: s._with_tensor_ranks_only(), # pylint: disable=protected-access 787 self._element_spec) 788 return type(self)( 789 self._input_workers, 790 element_spec, 791 self._strategy, 792 self._options, 793 cardinality=self._cardinality, 794 enable_get_next_as_optional=self._enable_get_next_as_optional) 795 796 # TODO(b/206014848): Remove once names are not used. 797 def _without_tensor_names(self): 798 element_spec = nest.map_structure( 799 lambda s: s._without_tensor_names(), # pylint: disable=protected-access 800 self._element_spec) 801 return type(self)( 802 self._input_workers, 803 element_spec, 804 self._strategy, 805 self._options, 806 cardinality=self._cardinality, 807 enable_get_next_as_optional=self._enable_get_next_as_optional) 808 809 810class DistributedIteratorSpec(DistributedDatasetAndIteratorSpec): 811 """Type specification for `DistributedIterator`.""" 812 813 @property 814 def value_type(self): 815 return DistributedIterator 816 817 @property 818 def _component_specs(self): 819 specs = [] 820 worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access 821 822 for i, (input_device, compute_devices) in enumerate(worker_device_pairs): 823 element_spec = nest.map_structure( 824 functools.partial(_replace_per_replica_spec, i=i), self._element_spec) 825 specs.append( 826 _SingleWorkerDatasetIteratorSpec(input_device, compute_devices, 827 element_spec, self._options, 828 self._canonicalize_devices)) 829 return specs 830 831 def _to_components(self, value): 832 return value._iterators # pylint: disable=protected-access 833 834 def _from_components(self, components): 835 return DistributedIterator( 836 input_workers=self._input_workers, 837 iterators=None, 838 components=components, 839 element_spec=self._element_spec, 840 strategy=self._strategy, 841 cardinality=self._cardinality, 842 enable_get_next_as_optional=self._enable_get_next_as_optional, 843 options=self._options) 844 845 @staticmethod 846 def from_value(value): 847 # pylint: disable=protected-access 848 return DistributedIteratorSpec( 849 value._input_workers, 850 value._element_spec, 851 value._strategy, 852 value._options, 853 cardinality=value._cardinality, 854 enable_get_next_as_optional=value._enable_get_next_as_optional) 855 856 857class DistributedIterator(DistributedIteratorBase, 858 composite_tensor.CompositeTensor): 859 """Input Iterator for a distributed dataset.""" 860 861 def __init__(self, 862 
input_workers=None, 863 iterators=None, 864 strategy=None, 865 components=None, 866 element_spec=None, 867 cardinality=cardinality_lib.UNKNOWN, 868 enable_get_next_as_optional=False, 869 options=None): 870 if input_workers is None: 871 raise ValueError("`input_workers` should be " 872 "provided.") 873 874 error_message = ("Either `input_workers` or " 875 "both `components` and `element_spec` need to be " 876 "provided.") 877 self._options = options 878 879 if iterators is None: 880 if (components is None or element_spec is None): 881 raise ValueError(error_message) 882 self._element_spec = element_spec 883 self._input_workers = input_workers 884 self._iterators = components 885 self._strategy = strategy 886 self._cardinality = cardinality 887 self._enable_get_next_as_optional = enable_get_next_as_optional 888 else: 889 if (components is not None and element_spec is not None): 890 raise ValueError(error_message) 891 892 super(DistributedIterator, 893 self).__init__(input_workers, iterators, strategy, cardinality, 894 enable_get_next_as_optional) 895 896 @property 897 def element_spec(self): 898 # When partial batch handling is enabled, always set the batch dimension to 899 # None, otherwise we just follow element_spec of the underlying dataset 900 # (whose batch dimension may also be None). This is because with partial 901 # batching handling we could always produce empty batches. 902 if (self._enable_get_next_as_optional and 903 self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access 904 return nest.map_structure( 905 _rebatch_as_dynamic, self._element_spec, expand_composites=False) 906 return self._element_spec 907 908 @property 909 def _type_spec(self): 910 # Note that we use actual element_spec instead of the rebatched-as-dynamic 911 # one to create DistributedIteratorSpec, to be consistent with the 912 # underlying iterators' specs. 913 return DistributedIteratorSpec(self._input_workers, self._element_spec, 914 self._strategy, 915 self._options, 916 self._cardinality, 917 self._enable_get_next_as_optional) 918 919 920class _IterableInput(DistributedDatasetInterface): 921 """Base class for iterable inputs for distribution strategies.""" 922 923 # pylint: disable=super-init-not-called 924 def __init__(self, input_workers): 925 assert isinstance(input_workers, InputWorkers) 926 self._input_workers = input_workers 927 928 def __iter__(self): 929 raise NotImplementedError("must be implemented in descendants") 930 931 def reduce(self, initial_state, reduce_fn): 932 """Execute a `reduce_fn` over all the elements of the input.""" 933 iterator = iter(self) 934 optional_data = iterator.get_next_as_optional() 935 936 def cond(optional_data, state): 937 del state # Unused. 
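      # Keep looping while the iterator still has a value; `state` is threaded
      # through the while_loop unchanged here and only updated in `loop_body`.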
938 return optional_data.has_value() 939 940 def loop_body(optional_data, state): 941 """Executes `reduce_fn` in a loop till the dataset is empty.""" 942 state = reduce_fn(state, optional_data.get_value()) 943 optional_data = iterator.get_next_as_optional() 944 return optional_data, state 945 946 optional_data, final_state = control_flow_ops.while_loop( 947 cond, 948 loop_body, [optional_data, initial_state], 949 parallel_iterations=1, 950 return_same_structure=True) 951 return final_state 952 953 954class DistributedDatasetSpec(DistributedDatasetAndIteratorSpec): 955 """Type specification for `DistributedDataset.""" 956 957 @property 958 def value_type(self): 959 return DistributedDataset 960 961 @property 962 def _component_specs(self): 963 specs = [] 964 worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access 965 966 for i, _ in enumerate(worker_device_pairs): 967 element_spec = nest.map_structure( 968 functools.partial(_replace_per_replica_spec, i=i), self._element_spec) 969 specs.append(dataset_ops.DatasetSpec(element_spec)) 970 return specs 971 972 def _to_components(self, value): 973 return value._cloned_datasets # pylint: disable=protected-access 974 975 def _from_components(self, components): 976 return DistributedDataset( 977 input_workers=self._input_workers, 978 strategy=self._strategy, 979 components=components, 980 element_spec=self._element_spec, 981 enable_get_next_as_optional=self._enable_get_next_as_optional, 982 options=self._options) 983 984 @staticmethod 985 def from_value(value): 986 # pylint: disable=protected-access 987 return DistributedDatasetSpec( 988 value._input_workers, 989 value._element_spec, 990 value._strategy, 991 value._options, 992 enable_get_next_as_optional=value._enable_get_next_as_optional) 993 # pylint: enable=protected-access 994 995 996class DistributedDataset(_IterableInput, composite_tensor.CompositeTensor): 997 """Distributed dataset that supports prefetching to multiple devices.""" 998 999 def __init__(self, 1000 input_workers, 1001 strategy, 1002 dataset=None, 1003 num_replicas_in_sync=None, 1004 input_context=None, 1005 components=None, 1006 element_spec=None, 1007 enable_get_next_as_optional=None, 1008 build=True, 1009 options=None): 1010 """Distribute the dataset on all workers. 1011 1012 If `num_replicas_in_sync` is not None, we split each batch of the dataset 1013 into `num_replicas_in_sync` smaller batches, to be distributed among that 1014 worker's replicas, so that the batch size for a global step (across all 1015 workers and replicas) is as expected. 1016 1017 Args: 1018 input_workers: an `InputWorkers` object. 1019 strategy: a `tf.distribute.Strategy` object, used to run all-reduce to 1020 handle last partial batch. 1021 dataset: `tf.data.Dataset` that will be used as the input source. Either 1022 dataset or components field should be passed when constructing 1023 DistributedDataset. Use this when contructing DistributedDataset from a 1024 new `tf.data.Dataset`. Use components when constructing using 1025 DistributedDatasetSpec. 1026 num_replicas_in_sync: Optional integer. If this is not None, the value 1027 is used to decide how to rebatch datasets into smaller batches so that 1028 the total batch size for each step (across all workers and replicas) 1029 adds up to `dataset`'s batch size. 1030 input_context: `InputContext` for sharding. Only pass this in for between 1031 graph multi-worker cases where there is only one `input_worker`. 
In 1032 these cases, we will shard based on the `input_pipeline_id` and 1033 `num_input_pipelines` in the `InputContext`. 1034 components: datasets when DistributedDataset is constructed from 1035 DistributedDatasetSpec. Either field dataset or components should be 1036 passed. 1037 element_spec: element spec for DistributedDataset when constructing from 1038 DistributedDatasetSpec. This will be used to set the element_spec for 1039 DistributedDataset and verified against element_spec from components. 1040 enable_get_next_as_optional: this is required when components is passed 1041 instead of dataset. 1042 build: whether to build underlying datasets when this object is created. 1043 This is only useful for `ParameterServerStrategy` now. 1044 options: `tf.distribute.InputOptions` used to control options on how this 1045 dataset is distributed. 1046 """ 1047 super(DistributedDataset, self).__init__(input_workers=input_workers) 1048 if input_workers is None or strategy is None: 1049 raise ValueError("input_workers and strategy are required arguments") 1050 if dataset is not None and components is not None: 1051 raise ValueError("Only one of dataset or components should be present") 1052 if dataset is None and components is None: 1053 raise ValueError("At least one of dataset or components should be passed") 1054 1055 self._input_workers = input_workers 1056 self._strategy = strategy 1057 self._options = options 1058 self._input_context = input_context 1059 self._num_replicas_in_sync = num_replicas_in_sync 1060 1061 if dataset is not None: 1062 self._original_dataset = dataset 1063 self._built = False 1064 if build: 1065 self.build() 1066 else: 1067 if not build: 1068 raise ValueError( 1069 "When constructing DistributedDataset with components, build " 1070 "should not be False. This is an internal error. Please file a " 1071 "bug.") 1072 if enable_get_next_as_optional is None: 1073 raise ValueError( 1074 "When constructing DistributedDataset with components, " + 1075 "enable_get_next_as_optional should also be passed") 1076 self._cloned_datasets = components 1077 self._cardinality = _cardinality(self._cloned_datasets[0]) 1078 self._enable_get_next_as_optional = enable_get_next_as_optional 1079 1080 assert element_spec is not None 1081 if element_spec != _create_distributed_tensor_spec( 1082 self._strategy, self._cloned_datasets[0].element_spec): 1083 raise ValueError("Mismatched element_spec from the passed components") 1084 self._element_spec = element_spec 1085 1086 self._built = True 1087 1088 def build(self, dataset_to_replace=None): 1089 assert not self._built 1090 dataset = dataset_to_replace or self._original_dataset 1091 self._cardinality = _cardinality(dataset) 1092 self._enable_get_next_as_optional = _enable_get_next_as_optional( 1093 self._strategy, dataset, self._cardinality) 1094 distribute_start_time_ns = time.time_ns() 1095 self._create_cloned_datasets_from_dataset(dataset, self._input_context, 1096 self._input_workers, 1097 self._strategy, 1098 self._num_replicas_in_sync) 1099 if context.executing_eagerly(): 1100 # Records the time to initialize the distributed dataset. 
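      # async_wait() blocks until pending async eager ops have finished, so the
      # duration recorded below covers the actual cloning/sharding work rather
      # than just the time taken to enqueue the ops.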
1101 context.async_wait() 1102 distribute_duration_ms = (time.time_ns() - 1103 distribute_start_time_ns) // 1_000_000 1104 _distributed_dataset_initialization_time_milliseconds.get_cell( 1105 self._strategy.__class__.__name__, 1106 str(self._input_workers.num_workers)).add(distribute_duration_ms) 1107 self._element_spec = _create_distributed_tensor_spec( 1108 self._strategy, self._cloned_datasets[0].element_spec) 1109 self._built = True 1110 1111 @property 1112 def cardinality(self): 1113 if not self._built: 1114 raise ValueError( 1115 "Cannot get the cardinality of a dataset that is not built") 1116 return self._cardinality 1117 1118 def _create_cloned_datasets_from_dataset(self, dataset, input_context, 1119 input_workers, strategy, 1120 num_replicas_in_sync): 1121 # We clone and shard the dataset on each worker. The current setup tries to 1122 # shard the dataset by files if possible so that each worker sees a 1123 # different subset of files. If that is not possible, will attempt to shard 1124 # the final input such that each worker will run the entire preprocessing 1125 # pipeline and only receive its own shard of the dataset. 1126 1127 # Additionally, we rebatch the dataset on each worker into 1128 # `num_replicas_in_sync` smaller batches to be distributed among that 1129 # worker's replicas, so that the batch size for a global step (across all 1130 # workers and replicas) adds up to the original dataset's batch size. 1131 if num_replicas_in_sync is not None: 1132 num_workers = input_context.num_input_pipelines if input_context else len( 1133 input_workers.worker_devices) 1134 rebatch_fn = self._make_rebatch_fn(dataset, num_workers, 1135 num_replicas_in_sync) 1136 else: 1137 rebatch_fn = None 1138 self._cloned_datasets = [] 1139 if input_context: 1140 # Between-graph where we rely on the input_context for sharding 1141 assert input_workers.num_workers == 1 1142 if rebatch_fn is not None: 1143 dataset = rebatch_fn(dataset, input_context.input_pipeline_id) 1144 dataset = input_ops.auto_shard_dataset(dataset, 1145 input_context.num_input_pipelines, 1146 input_context.input_pipeline_id, 1147 num_replicas_in_sync) 1148 self._cloned_datasets.append(dataset) 1149 else: 1150 replicated_ds = distribute.replicate(dataset, 1151 input_workers.worker_devices) 1152 for i, worker in enumerate(input_workers.worker_devices): 1153 with ops.device(worker): 1154 cloned_dataset = replicated_ds[worker] 1155 if rebatch_fn is not None: 1156 cloned_dataset = rebatch_fn(cloned_dataset, i) 1157 cloned_dataset = input_ops.auto_shard_dataset( 1158 cloned_dataset, len(input_workers.worker_devices), i, 1159 num_replicas_in_sync) 1160 self._cloned_datasets.append(cloned_dataset) 1161 1162 def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync): 1163 """Returns a callable that rebatches the input dataset. 1164 1165 Args: 1166 dataset: A `tf.data.Dataset` representing the dataset to be distributed. 1167 num_workers: An integer representing the number of workers to distribute 1168 `dataset` among. 1169 num_replicas_in_sync: An integer representing the number of replicas in 1170 sync across all workers. 1171 """ 1172 if num_replicas_in_sync % num_workers: 1173 raise ValueError( 1174 "tf.distribute expects every worker to have the same number of " 1175 "replicas. 
However, encountered `num_replicas_in_sync` ({}) that " 1176 "cannot be divided by `num_workers` ({})".format( 1177 num_replicas_in_sync, num_workers)) 1178 1179 num_replicas_per_worker = num_replicas_in_sync // num_workers 1180 with ops.colocate_with(dataset._variant_tensor): # pylint: disable=protected-access 1181 batch_size = distribute.compute_batch_size(dataset) 1182 1183 def rebatch_fn(dataset, worker_index): 1184 try: 1185 # pylint: disable=protected-access 1186 def apply_rebatch(): 1187 batch_sizes = distribute.batch_sizes_for_worker( 1188 batch_size, num_workers, num_replicas_per_worker, worker_index) 1189 return distribute._RebatchDataset( 1190 dataset, batch_sizes).prefetch(num_replicas_per_worker) 1191 1192 def apply_legacy_rebatch(): 1193 return distribute._LegacyRebatchDataset( 1194 dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker) 1195 1196 with ops.colocate_with(dataset._variant_tensor): 1197 return control_flow_ops.cond( 1198 math_ops.not_equal(batch_size, -1), 1199 true_fn=apply_rebatch, 1200 false_fn=apply_legacy_rebatch) 1201 except errors.InvalidArgumentError as e: 1202 if "without encountering a batch" in str(e): 1203 six.reraise( 1204 ValueError, 1205 ValueError( 1206 "Call the `batch` method on the input Dataset in order to be " 1207 "able to split your input across {} replicas.\n Please see " 1208 "the tf.distribute.Strategy guide. {}".format( 1209 num_replicas_in_sync, e)), 1210 sys.exc_info()[2]) 1211 else: 1212 raise 1213 1214 return rebatch_fn 1215 1216 def __iter__(self): 1217 if not (context.executing_eagerly() or 1218 ops.get_default_graph().building_function): 1219 raise RuntimeError("__iter__() is only supported inside of tf.function " 1220 "or when eager execution is enabled.") 1221 if not self._built: 1222 raise ValueError("To use this dataset, you need to pass this dataset to " 1223 "ClusterCoordinator.create_per_worker_dataset.") 1224 1225 canonicalize_devices = getattr(self._strategy, "_canonicalize_devices", 1226 True) 1227 1228 worker_iterators = _create_iterators_per_worker( 1229 self._cloned_datasets, 1230 self._input_workers, 1231 options=self._options, 1232 canonicalize_devices=canonicalize_devices) 1233 iterator = DistributedIterator( 1234 self._input_workers, 1235 worker_iterators, 1236 self._strategy, 1237 cardinality=self._cardinality, 1238 enable_get_next_as_optional=self._enable_get_next_as_optional, 1239 options=self._options) 1240 iterator._element_spec = self._element_spec # pylint: disable=protected-access 1241 1242 # When async eager is enabled, sometimes the iterator may not finish 1243 # initialization before passing to a multi device function, add a sync point 1244 # here to make sure all underlying iterators are initialized. 1245 if context.executing_eagerly(): 1246 context.async_wait() 1247 1248 return iterator 1249 1250 @property 1251 def element_spec(self): 1252 """The type specification of an element of this dataset.""" 1253 # When partial batch handling is enabled, always set the batch dimension to 1254 # None, otherwise we just follow element_spec of the underlying dataset 1255 # (whose batch dimension may also be None). This is because with partial 1256 # batching handling we could always produce empty batches. 
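    # Illustrative example (spec values are hypothetical): with partial batch
    # handling enabled in multi-worker mode, a component such as
    # TensorSpec(shape=(16, 1)) is reported as TensorSpec(shape=(None, 1)),
    # since the final global step may deliver smaller or empty per-replica
    # batches.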
1257 if (self._enable_get_next_as_optional and 1258 self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access 1259 return nest.map_structure( 1260 _rebatch_as_dynamic, self._element_spec, expand_composites=False) 1261 return self._element_spec 1262 1263 @property 1264 def _type_spec(self): 1265 return DistributedDatasetSpec( 1266 self._input_workers, 1267 self._element_spec, 1268 self._strategy, 1269 self._options, 1270 enable_get_next_as_optional=self._enable_get_next_as_optional) 1271 1272 1273class DistributedDatasetsFromFunctionSpec(DistributedDatasetAndIteratorSpec): 1274 """Type specification for `DistributedDatasetsFromFunction.""" 1275 1276 @property 1277 def value_type(self): 1278 return DistributedDatasetsFromFunction 1279 1280 @property 1281 def _component_specs(self): 1282 specs = [] 1283 worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access 1284 1285 for i, _ in enumerate(worker_device_pairs): 1286 element_spec = nest.map_structure( 1287 functools.partial(_replace_per_replica_spec, i=i), self._element_spec) 1288 specs.append(dataset_ops.DatasetSpec(element_spec)) 1289 return specs 1290 1291 def _to_components(self, value): 1292 return value._datasets # pylint: disable=protected-access 1293 1294 def _from_components(self, components): 1295 return DistributedDatasetsFromFunction( 1296 input_workers=self._input_workers, 1297 strategy=self._strategy, 1298 components=components, 1299 element_spec=self._element_spec, 1300 options=self._options) 1301 1302 @staticmethod 1303 def from_value(value): 1304 # pylint: disable=protected-access 1305 return DistributedDatasetsFromFunctionSpec( 1306 input_workers=value._input_workers, 1307 element_spec=value._element_spec, 1308 strategy=value._strategy, 1309 options=value._options) 1310 1311 1312# TODO(priyag): Add other replication modes. 1313class DistributedDatasetsFromFunction(_IterableInput, 1314 composite_tensor.CompositeTensor): 1315 """Inputs created from dataset function.""" 1316 1317 def __init__(self, 1318 input_workers, 1319 strategy, 1320 input_contexts=None, 1321 dataset_fn=None, 1322 options=None, 1323 components=None, 1324 element_spec=None, 1325 build=True): 1326 """Makes an iterable from datasets created by the given function. 1327 1328 Args: 1329 input_workers: an `InputWorkers` object. 1330 strategy: a `tf.distribute.Strategy` object, used to run all-reduce to 1331 handle last partial batch. 1332 input_contexts: A list of `InputContext` instances to be passed to call(s) 1333 to `dataset_fn`. Length and order should match worker order in 1334 `worker_device_pairs`. 1335 dataset_fn: A function that returns a `Dataset` given an `InputContext`. 1336 Either dataset_fn or components should be passed to construct 1337 DistributedDatasetsFromFunction. Use this when constructing 1338 DistributedDataset using a function. Use components when constructing 1339 using DistributedDatasetsFromFunctionSpec. 1340 options: `tf.distribute.InputOptions` used to control options on how this 1341 dataset is distributed. 1342 components: datasets when DistributedDatasetsFromFunction is constructed 1343 from DistributedDatasetsFromFunctionSpec. Only one of dataset or 1344 components should be passed. 1345 element_spec: element spec for DistributedDataset when constructing from 1346 DistributedDatasetSpec. This will be used to set the element_spec for 1347 DistributedDatasetsFromFunctionSpec and verified against element_spec 1348 from components. 
1349 build: whether to build underlying datasets when this object is created. 1350 This is only useful for `ParameterServerStrategy` now. 1351 """ 1352 super(DistributedDatasetsFromFunction, self).__init__( 1353 input_workers=input_workers) 1354 self._input_workers = input_workers 1355 self._strategy = strategy 1356 self._options = options 1357 if dataset_fn is not None and components is not None: 1358 raise ValueError("Only one of dataset_fn or components should be set") 1359 if dataset_fn is None and components is None: 1360 raise ValueError("At least one of dataset_fn or components should be set") 1361 1362 if dataset_fn is not None: 1363 if input_workers.num_workers != len(input_contexts): 1364 raise ValueError( 1365 "Number of input workers (%d) is not same as number of " 1366 "input_contexts (%d)" % 1367 (input_workers.num_workers, len(input_contexts))) 1368 self._input_contexts = input_contexts 1369 self._dataset_fn = dataset_fn 1370 self._built = False 1371 if build: 1372 self.build() 1373 else: 1374 if element_spec is None: 1375 raise ValueError( 1376 "element_spec should also be passed when passing components") 1377 if not build: 1378 raise ValueError( 1379 "When constructing DistributedDatasetFromFunction with components, " 1380 "build should not be False. This is an internal error. Please file " 1381 "a bug.") 1382 self._element_spec = element_spec 1383 self._datasets = components 1384 self._built = True 1385 self._cardinality = _cardinality(self._datasets[0]) 1386 self._enable_get_next_as_optional = _enable_get_next_as_optional( 1387 self._strategy, self._datasets[0], self._cardinality) 1388 1389 def build(self): 1390 assert not self._built 1391 distribute_start_time_ns = time.time_ns() 1392 self._datasets, element_spec = ( 1393 _create_datasets_from_function_with_input_context( 1394 self._input_contexts, self._input_workers, self._dataset_fn)) 1395 if context.executing_eagerly(): 1396 # Records the time to initialize the distributed dataset. 
1397 context.async_wait() 1398 distribute_duration_ms = (time.time_ns() - 1399 distribute_start_time_ns) // 1_000_000 1400 _distributed_dataset_from_function_initialization_time_milliseconds.get_cell( 1401 self._strategy.__class__.__name__, 1402 str(self._input_workers.num_workers)).add(distribute_duration_ms) 1403 1404 self._element_spec = _create_distributed_tensor_spec( 1405 self._strategy, element_spec) 1406 self._cardinality = _cardinality(self._datasets[0]) 1407 self._enable_get_next_as_optional = _enable_get_next_as_optional( 1408 self._strategy, self._datasets[0], self._cardinality) 1409 self._built = True 1410 1411 @property 1412 def cardinality(self): 1413 if not self._built: 1414 raise ValueError( 1415 "Cannot get the cardinality of a dataset that is not built") 1416 return self._cardinality 1417 1418 def __iter__(self): 1419 if not (ops.executing_eagerly_outside_functions() or 1420 ops.get_default_graph().building_function): 1421 raise RuntimeError("__iter__() is only supported inside of tf.function " 1422 "or when eager execution is enabled.") 1423 1424 if not self._built: 1425 raise ValueError("You need to use this dataset in " 1426 "ClusterCoordinator.create_per_worker_dataset.") 1427 1428 canonicalize_devices = getattr(self._strategy, "_canonicalize_devices", 1429 True) 1430 1431 iterators = _create_iterators_per_worker( 1432 self._datasets, 1433 self._input_workers, 1434 options=self._options, 1435 canonicalize_devices=canonicalize_devices) 1436 iterator = DistributedIterator( 1437 input_workers=self._input_workers, 1438 iterators=iterators, 1439 strategy=self._strategy, 1440 cardinality=self._cardinality, 1441 enable_get_next_as_optional=self._enable_get_next_as_optional, 1442 options=self._options) 1443 iterator._element_spec = self._element_spec # pylint: disable=protected-access 1444 1445 # When async eager is enabled, sometimes the iterator may not finish 1446 # initialization before passing to a multi device function, add a sync 1447 # point here to make sure all underlying iterators are initialized. 1448 if context.executing_eagerly(): 1449 context.async_wait() 1450 1451 return iterator 1452 1453 @property 1454 def element_spec(self): 1455 """The type specification of an element of this dataset.""" 1456 # When partial batch handling is enabled, always set the batch dimension to 1457 # None, otherwise we just follow element_spec of the underlying dataset 1458 # (whose batch dimension may also be None). This is because with partial 1459 # batching handling we could always produce empty batches. 1460 if (self._enable_get_next_as_optional and 1461 self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access 1462 return nest.map_structure( 1463 _rebatch_as_dynamic, self._element_spec, expand_composites=False) 1464 return self._element_spec 1465 1466 @property 1467 def _type_spec(self): 1468 return DistributedDatasetsFromFunctionSpec(self._input_workers, 1469 self._element_spec, 1470 self._strategy, self._options) 1471 1472 1473def _dummy_tensor_fn(value_structure): 1474 """A function to create dummy tensors from `value_structure`.""" 1475 1476 def create_dummy_tensor(spec): 1477 """Create a dummy tensor with possible batch dimensions set to 0.""" 1478 if hasattr(spec, "_create_empty_value"): 1479 # Type spec may overwrite default dummy values behavior by declaring the 1480 # `_create_empty_value(self)` method. 
      # compatible with the type spec with batch dimensions set to 0, or fail
      # if such a value does not exist. This allows a composite tensor to
      # customize dummy value creation since, in general, its dummy value is
      # not composed from dummy components (e.g. the `row_splits` tensor of a
      # RaggedTensor is never allowed to be empty). See b/183969859 for more
      # discussion.
      # TODO(b/186079336): reconsider CompositeTensor support.
      return spec._create_empty_value()  # pylint: disable=protected-access

    if isinstance(spec, ragged_tensor.RaggedTensorSpec):
      # Splice out the ragged dimensions.
      # pylint: disable=protected-access
      feature_shape = spec._shape[:1].concatenate(
          spec._shape[(1 + spec._ragged_rank):])
      feature_type = spec._dtype
      # pylint: enable=protected-access
    else:
      feature_shape = spec.shape
      feature_type = spec.dtype
    # Ideally we should set the batch dimension to 0; however, since
    # DistributionStrategy does not know the batch dimension, we make a best
    # guess. If the feature has unknown dimensions, we set them to 0. If the
    # feature shape is already static, we treat the first dimension as the
    # batch dimension and set it to 0.
    dims = ([dim if dim is not None else 0 for dim in feature_shape.as_list()]
            if feature_shape else [])
    if dims and (isinstance(spec, ragged_tensor.RaggedTensorSpec) or
                 feature_shape.is_fully_defined()):
      dims[0] = tensor_shape.Dimension(0)

    if isinstance(spec, sparse_tensor.SparseTensorSpec):
      return sparse_tensor.SparseTensor(
          values=array_ops.zeros(0, feature_type),
          indices=array_ops.zeros((0, len(dims)), dtypes.int64),
          dense_shape=dims)

    # Create the dummy tensor.
    dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
    if isinstance(spec, ragged_tensor.RaggedTensorSpec):
      # Reinsert the ragged dimensions with size 0.
      # pylint: disable=protected-access
      row_splits = array_ops.zeros(1, spec._row_splits_dtype)
      dummy_tensor = ragged_tensor.RaggedTensor.from_nested_row_splits(
          dummy_tensor, (row_splits,) * spec._ragged_rank, validate=False)
      # pylint: enable=protected-access
    return dummy_tensor

  return nest.map_structure(create_dummy_tensor, value_structure)


def _get_value_or_dummy(input_workers, optional_list, produce_dummy):
  """Returns the values of the optionals or dummy values.

  Args:
    input_workers: the `InputWorkers`.
    optional_list: a list of lists of `tf.experimental.Optional`. The values
      from each compute device, grouped by the input device.
    produce_dummy: a bool. Whether to produce dummy tensors when an optional
      doesn't have a value.

  Returns:
    A flattened list of Tensors.

  """
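  # Illustrative sketch of the structure involved (hypothetical two-worker,
  # one-device-per-worker setup): `optional_list` looks like
  # [[opt_w0_d0], [opt_w1_d0]] and the result is the flat list
  # [value_w0_d0, value_w1_d0]. With `produce_dummy` set, an exhausted
  # optional contributes a zero-sized dummy from `_dummy_tensor_fn` instead.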
  value_list = []
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      devices = input_workers.compute_devices_for_worker(i)
      for j, device in enumerate(devices):
        with ops.device(device):
          if produce_dummy:
            # pylint: disable=cell-var-from-loop
            value_list.append(
                control_flow_ops.cond(
                    optional_list[i][j].has_value(),
                    lambda: optional_list[i][j].get_value(),  # pylint: disable=unnecessary-lambda
                    lambda: _dummy_tensor_fn(optional_list[i][j].element_spec),
                    strict=True,
                ))
            # pylint: enable=cell-var-from-loop
          else:
            value_list.append(optional_list[i][j].get_value())
  return value_list


class _SingleWorkerDatasetIteratorBase(object):
  """Iterator for a single `tf.data.Dataset`."""

  def __init__(self, dataset, worker, devices, options=None):
    """Create iterator for the `dataset` to fetch data to worker's `devices`.

    A `MultiDeviceIterator` or `OwnedMultiDeviceIterator` is used to prefetch
    input to the devices on the given worker.

    Args:
      dataset: A `tf.data.Dataset` instance.
      worker: Worker on which ops should be created.
      devices: Distribute data from `dataset` to these devices.
      options: `tf.distribute.InputOptions` used to control how this dataset
        is distributed.
    """
    self._dataset = dataset
    self._worker = worker
    self._devices = devices
    self._element_spec = dataset.element_spec
    self._options = options
    self._make_iterator()

  def _make_iterator(self):
    raise NotImplementedError("must be implemented in descendants")

  def _format_data_list_with_options(self, data_list):
    """Change the data into a list type if required.

    The OwnedMultiDeviceIterator returns the list data type, while the
    PER_REPLICA iterator (when used with prefetch disabled) returns the data
    without the enclosing list. This fixes that inconsistency.

    Args:
      data_list: The data returned by the underlying iterator.

    Returns:
      The data as a list.
    """
    if (self._options and self._options.experimental_replication_mode ==
        InputReplicationMode.PER_REPLICA and
        not self._options.experimental_fetch_to_device):
      return [data_list]
    else:
      return data_list

  def get_next(self, device, name=None):
    """Get next element for the given device."""
    del name
    with ops.device(self._worker):
      if _should_use_multi_device_iterator(self._options):
        return self._iterator.get_next(device)
      else:
        return self._iterator.get_next()

  def get_next_as_list(self, name=None):
    """Get the next element from the underlying iterator.

    Runs the iterator get_next() within a device scope. Since this doesn't use
    get_next_as_optional(), it is considerably faster than
    get_next_as_optional_list(), but it raises `tf.errors.OutOfRangeError` if
    any of the devices doesn't get any data.

    Args:
      name: not used.

    Returns:
      A list consisting of the next data from each device.
    """
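    # Illustrative sketch: for a worker feeding two devices, the underlying
    # multi-device iterator returns one element per device, e.g.
    # [batch_for_device_0, batch_for_device_1]. In the PER_REPLICA,
    # no-prefetch case the single fetched element is wrapped into a list by
    # _format_data_list_with_options so callers always see a list.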
1629 """ 1630 del name 1631 with ops.device(self._worker): 1632 return self._format_data_list_with_options(self._iterator.get_next()) 1633 1634 def get_next_as_optional_list(self): 1635 with ops.device(self._worker): 1636 return self._format_data_list_with_options( 1637 self._iterator.get_next_as_optional()) 1638 1639 1640class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec): 1641 """Type specification for `_SingleWorkerOwnedDatasetIterator`.""" 1642 1643 __slots__ = [ 1644 "_worker", "_devices", "_element_spec", "_options", 1645 "_canonicalize_devices" 1646 ] 1647 1648 def __init__(self, worker, devices, element_spec, options, 1649 canonicalize_devices=True): 1650 self._worker = worker 1651 if canonicalize_devices: 1652 self._devices = tuple(device_util.canonicalize(d) for d in devices) 1653 else: 1654 self._devices = tuple( 1655 device_util.canonicalize_without_job_and_task(d) for d in devices) 1656 self._element_spec = element_spec 1657 # `self._options` intentionally made not `None` for proper serialization. 1658 self._options = (options if options is not None else 1659 distribute_lib.InputOptions()) 1660 self._canonicalize_devices = canonicalize_devices 1661 1662 @property 1663 def value_type(self): 1664 return _SingleWorkerOwnedDatasetIterator 1665 1666 def _serialize(self): 1667 return (self._worker, self._devices, self._element_spec, self._options, 1668 self._canonicalize_devices) 1669 1670 def _get_multi_device_iterator_spec(self, specs): 1671 device_scope = device_util.canonicalize(self._worker, device_util.current()) 1672 host_device = device_util.get_host_for_device(device_scope) 1673 # source_device while creating iterator governs the worker device in 1674 # iterator spec. 1675 worker = host_device 1676 specs.append( 1677 multi_device_iterator_ops.MultiDeviceIteratorSpec( 1678 self._devices, worker, element_spec=self._element_spec)) 1679 1680 @property 1681 def _component_specs(self): 1682 specs = [] 1683 if _should_use_multi_device_iterator(self._options): 1684 self._get_multi_device_iterator_spec(specs) 1685 else: 1686 specs.append(iterator_ops.IteratorSpec(element_spec=self._element_spec)) 1687 return specs 1688 1689 def _to_components(self, value): 1690 return [value._iterator] # pylint: disable=protected-access 1691 1692 def _from_components(self, components): 1693 return _SingleWorkerOwnedDatasetIterator( 1694 dataset=None, 1695 worker=self._worker, 1696 devices=self._devices, 1697 components=components, 1698 element_spec=self._element_spec, 1699 options=self._options, 1700 canonicalize_devices=self._canonicalize_devices) 1701 1702 @staticmethod 1703 def from_value(value): 1704 # pylint: disable=protected-access 1705 return _SingleWorkerDatasetIteratorSpec(value._worker, value._devices, 1706 value._element_spec, value._options, 1707 value._canonicalize_devices) 1708 1709 1710class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase, 1711 composite_tensor.CompositeTensor): 1712 """Iterator for a DistributedDataset instance.""" 1713 1714 def __init__(self, 1715 dataset=None, 1716 worker=None, 1717 devices=None, 1718 components=None, 1719 element_spec=None, 1720 options=None, 1721 canonicalize_devices=None): 1722 """Create iterator for the `dataset` to fetch data to worker's `devices` . 1723 1724 `OwnedMultiDeviceIterator` is used to prefetch input to the devices on the 1725 given worker. The lifetime of this iterator is tied to the encompassing 1726 python object. 
    Once the Python object goes out of scope or a tf.function returns, the
    underlying iterator resource is deleted.

    Args:
      dataset: A `tf.data.Dataset` instance.
      worker: Worker on which ops should be created.
      devices: Distribute data from `dataset` to these devices.
      components: Tensor components to construct the
        _SingleWorkerOwnedDatasetIterator from.
      element_spec: A nested structure of `TypeSpec` objects that represents
        the type specification of elements of the iterator.
      options: `tf.distribute.InputOptions` used to control how this dataset
        is distributed.
      canonicalize_devices: Whether to canonicalize devices for workers fully
        or partially. If False, devices are only partially canonicalized by
        removing the job and task.
    """
    if worker is None or devices is None:
      raise ValueError("Both `worker` and `devices` should be provided")

    error_message = ("Either `dataset` or both `components` and `element_spec` "
                     "need to be provided.")

    self._options = options
    self._canonicalize_devices = canonicalize_devices
    if dataset is None:
      if (components is None or element_spec is None):
        raise ValueError(error_message)
      self._element_spec = element_spec
      self._worker = worker
      self._devices = devices
      self._iterator = components[0]
    else:
      if (components is not None or element_spec is not None):
        raise ValueError(error_message)
      super(_SingleWorkerOwnedDatasetIterator,
            self).__init__(dataset, worker, devices, self._options)

  def _create_owned_multi_device_iterator(self):
    # If the worker devices are already canonicalized, canonicalizing again
    # would have no impact.
    # For strategies running on remote workers such as PS Strategy, the device
    # scope will be derived from the current worker if used under init_scope().
    device_scope = device_util.canonicalize(self._worker,
                                            device_util.current())
    host_device = device_util.get_host_for_device(device_scope)
    with ops.device(device_scope):
      if self._options is not None:
        self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
            self._dataset,
            self._devices,
            source_device=host_device,
            max_buffer_size=self._options
            .experimental_per_replica_buffer_size,
            prefetch_buffer_size=self._options
            .experimental_per_replica_buffer_size)
      else:
        self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
            self._dataset, self._devices, source_device=host_device)

  def _make_iterator(self):
    """Make the appropriate iterator on the dataset."""
    if not self._worker:
      raise ValueError("Worker device must be specified when creating an "
                       "owned iterator.")
    if _should_use_multi_device_iterator(self._options):
      self._create_owned_multi_device_iterator()
    else:
      with ops.device(self._worker):
        self._iterator = iter(self._dataset)

  @property
  def element_spec(self):
    return self._element_spec

  @property
  def _type_spec(self):
    return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices,
                                            self._element_spec, self._options,
                                            self._canonicalize_devices)

  @property
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.SparseTensor`.

    Returns:
      A nested structure of Python `type` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.TensorShape` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.DType` objects corresponding to each component
      of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._element_spec)


def _create_iterators_per_worker(worker_datasets,
                                 input_workers,
                                 options=None,
                                 canonicalize_devices=False):
  """Create a multi-device iterator on each of the workers."""
  assert isinstance(input_workers, InputWorkers)
  assert len(worker_datasets) == len(input_workers.worker_devices)
  iterators = []
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      worker_devices = input_workers.compute_devices_for_worker(i)
      iterator = _SingleWorkerOwnedDatasetIterator(
          dataset=worker_datasets[i],
          worker=worker,
          devices=worker_devices,
          options=options,
          canonicalize_devices=canonicalize_devices)
      iterators.append(iterator)
  return iterators


def _create_datasets_from_function_with_input_context(input_contexts,
                                                      input_workers,
                                                      dataset_fn):
  """Create device datasets per worker given a dataset function."""
  datasets = []
  for i, ctx in enumerate(input_contexts):
    worker = input_workers.worker_devices[i]
    with ops.device(worker):
      dataset = dataset_fn(ctx)
      datasets.append(dataset)
  return datasets, dataset.element_spec


# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_batched_dataset(d):
  """Get the batched dataset from `d`."""
  # pylint: disable=protected-access
  if isinstance(d, dataset_ops.DatasetV1Adapter):
    d = d._dataset

  if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
    return d
  elif isinstance(d, (dataset_ops.PrefetchDataset,
                      dataset_ops._OptionsDataset)):
    return _get_batched_dataset(d._input_dataset)

  raise ValueError(
      "Unable to get batched dataset from the input dataset. `batch` or "
      "`map_and_batch` needs to be the last operation on the dataset. "
      "The batch operation can be followed by a prefetch.")
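
# Illustrative sketch of how _get_batched_dataset walks a pipeline
# (hypothetical dataset): for `tf.data.Dataset.range(8).batch(4).prefetch(2)`
# the outer PrefetchDataset is unwrapped to the underlying BatchDataset, from
# which _get_batched_dataset_attributes below can recover batch_size=4 and
# drop_remainder=False.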
" 1896 "The batch operations can be followed by a prefetch.") 1897 1898 1899def _get_batched_dataset_attributes(d): 1900 """Get `batch_size`, `drop_remainder` of dataset.""" 1901 # pylint: disable=protected-access 1902 assert isinstance(d, 1903 (dataset_ops.BatchDataset, batching._MapAndBatchDataset)) 1904 if isinstance(d, dataset_ops.BatchDataset): 1905 batch_size = d._batch_size 1906 drop_remainder = d._drop_remainder 1907 elif isinstance(d, batching._MapAndBatchDataset): 1908 batch_size = d._batch_size_t 1909 drop_remainder = d._drop_remainder_t 1910 # pylint: enable=protected-access 1911 1912 if tensor_util.is_tf_type(batch_size): 1913 batch_size = tensor_util.constant_value(batch_size) 1914 1915 if tensor_util.is_tf_type(drop_remainder): 1916 drop_remainder = tensor_util.constant_value(drop_remainder) 1917 1918 return batch_size, drop_remainder 1919 1920 1921# TODO(sourabhbajaj): Remove this in lieu of distributed datasets 1922def _get_dataset_attributes(dataset): 1923 """Get the underlying attributes from the dataset object.""" 1924 # pylint: disable=protected-access 1925 1926 # First, get batch_size and drop_remainder from the dataset. We need 1927 # to walk back the dataset creation process and find the batched version in 1928 # order to get the attributes. 1929 batched_dataset = _get_batched_dataset(dataset) 1930 batch_size, drop_remainder = _get_batched_dataset_attributes(batched_dataset) 1931 1932 # Second, prefetch buffer should be get from the original dataset. 1933 prefetch_buffer = None 1934 if isinstance(dataset, dataset_ops.PrefetchDataset): 1935 prefetch_buffer = dataset._buffer_size 1936 elif (isinstance(dataset, dataset_ops.DatasetV1Adapter) 1937 and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)): 1938 prefetch_buffer = dataset._dataset._buffer_size 1939 1940 return batch_size, drop_remainder, prefetch_buffer 1941 1942 1943def _should_use_multi_device_iterator(options): 1944 """Determine whether to use multi_device_iterator_ops.""" 1945 if (options is None or 1946 options.experimental_replication_mode == InputReplicationMode.PER_WORKER 1947 or 1948 (options.experimental_replication_mode == InputReplicationMode.PER_REPLICA 1949 and options.experimental_fetch_to_device)): 1950 return True 1951 return False 1952 1953 1954class MultiStepContext(object): 1955 """A context object that can be used to capture things when running steps. 1956 1957 This context object is useful when running multiple steps at a time using the 1958 `experimental_run_steps_on_iterator` API. For e.g. it allows the user's step 1959 function to specify which outputs to emit at what frequency. Currently it 1960 supports capturing output from the last step, as well as capturing non tensor 1961 outputs. In the future it will be augmented to support other use cases such 1962 as output each N steps. 1963 """ 1964 1965 def __init__(self): 1966 """Initialize an output context. 1967 1968 Returns: 1969 A context object. 1970 """ 1971 self._last_step_outputs = {} 1972 self._last_step_outputs_reduce_ops = {} 1973 self._non_tensor_outputs = {} 1974 1975 @property 1976 def last_step_outputs(self): 1977 """A dictionary consisting of outputs to be captured on last step. 1978 1979 Keys in the dictionary are names of tensors to be captured, as specified 1980 when `set_last_step_output` is called. 1981 Values in the dictionary are the tensors themselves. If 1982 `set_last_step_output` was called with a `reduce_op` for this output, 1983 then the value is the reduced value. 

  def _set_last_step_outputs(self, outputs):
    """Replace the entire dictionary of last step outputs."""
    if not isinstance(outputs, dict):
      raise ValueError("Need a dictionary to set last_step_outputs.")
    self._last_step_outputs = outputs

  def set_last_step_output(self, name, output, reduce_op=None):
    """Set `output` with `name` to be output from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be output with `name`. See below for
        actual types supported.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context. Optional in cross_replica_context.
        When present, the outputs from all the replicas are reduced using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For example, if using MirroredStrategy and reduction is set, output
        must be a `PerReplica` value.
        The reduce method is also recorded in the
        `_last_step_outputs_reduce_ops` dictionary so that the outputs can
        later be interpreted as reduced or not.
    """
    if distribution_strategy_context.in_cross_replica_context():
      self._last_step_outputs_reduce_ops[name] = reduce_op
      if reduce_op is None:
        self._last_step_outputs[name] = output
      else:
        distribution = distribution_strategy_context.get_strategy()
        self._last_step_outputs[name] = distribution.reduce(reduce_op, output,
                                                            axis=None)
    else:
      assert reduce_op is not None
      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(reduce_op, value,
                                                            axis=None)
        # Setting this inside the `merge_fn` because all replicas share the
        # same context object, so it's more robust to set it only once (even
        # if all the replicas are trying to set the same value).
        self._last_step_outputs_reduce_ops[name] = reduce_op

      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))

  @property
  def non_tensor_outputs(self):
    """A dictionary consisting of any non-tensor outputs to be captured."""
    return self._non_tensor_outputs

  def set_non_tensor_output(self, name, output):
    """Set `output` with `name` to be captured as a non-tensor output."""
    if distribution_strategy_context.in_cross_replica_context():
      self._non_tensor_outputs[name] = output
    else:
      def merge_fn(distribution, value):
        # NOTE(priyag): For non-tensor outputs, we simply return all the
        # values in a list as reduction doesn't make sense on non-tensors.
        self._non_tensor_outputs[name] = (
            distribution.experimental_local_results(value))
      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))


def _create_distributed_tensor_spec(strategy, tensor_spec):
  """Create a `tf.TypeSpec` for a given strategy and input `tensor_spec`.

  Args:
    strategy: The given `tf.distribute` strategy.
    tensor_spec: `tf.TensorSpec` of a given value. The batch dimension of the
      shape should be None if you have partial batches.

  Returns:
    A `tf.TypeSpec` that matches the values produced by the given strategy.
    This can be a `tf.TensorSpec` or a `PerReplicaSpec`.
  """
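  # Illustrative sketch (hypothetical inputs): for a two-replica mirrored
  # strategy and tensor_spec=tf.TensorSpec([None, 3], tf.float32), the result
  # is values.PerReplicaSpec(spec, spec); for a default single-device
  # strategy the spec is returned unchanged.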
  num_replicas = len(strategy.extended.worker_devices)

  # For a one-device strategy that is not MultiWorkerMirroredStrategy, return
  # the tensor_spec as is, since we don't wrap the output with PerReplica in
  # this case.
  # TODO(b/166464552): remove after we always wrap for all strategies.
  if not _always_wrap(strategy):
    return tensor_spec

  # For other cases we assume the input to tf.function is a per-replica type.
  def _get_value_per_replica(tensor_spec_per_input):
    value_specs = [tensor_spec_per_input for _ in range(num_replicas)]
    return values.PerReplicaSpec(*value_specs)

  return nest.map_structure(_get_value_per_replica, tensor_spec)


def _replace_per_replica_spec(spec, i):
  """If `spec` is a `PerReplicaSpec`, then return its `i`th value_spec."""
  if isinstance(spec, values.PerReplicaSpec):
    return spec._value_specs[i]  # pylint: disable=protected-access
  else:
    return spec


def _cardinality(dataset):
  """Returns the cardinality of the dataset."""
  if context.executing_eagerly():
    with ops.device(dataset._variant_tensor.device):  # pylint: disable=protected-access
      return dataset.cardinality().numpy()
  return cardinality_lib.UNKNOWN


def _enable_get_next_as_optional(strategy, dataset, cardinality):
  """Returns whether to enable using partial batch handling."""
  # TODO(b/133073708): we currently need a flag to control the usage because
  # there is a performance difference between get_next() and
  # get_next_as_optional(), and we only enable get_next_as_optional when the
  # output shapes are not static.
  #
  # TODO(rxsang): We want to always enable the get_next_as_optional behavior
  # when the user passes an input_fn instead of a dataset.
  if not getattr(
      strategy.extended, "enable_partial_batch_handling",
      getattr(strategy.extended, "experimental_enable_get_next_as_optional",
              False)):
    return False

  # If the dataset is infinite, we don't need to enable last partial batch
  # support. Note that we can only evaluate the cardinality of the dataset in
  # eager mode.
  if cardinality == cardinality_lib.INFINITE:
    return False

  return not _is_statically_shaped(
      dataset.element_spec) or strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access


def _create_per_replica(value_list, strategy):
  """Creates a PerReplica.

  For strategies other than OneDeviceStrategy, it creates a PerReplica whose
  type spec is set to the element spec of the dataset. This helps avoid
  retracing for partial batches. Retracing is problematic for multi-client
  setups when different clients retrace at different times, since retracing
  changes the collective keys in the tf.function and causes mismatches among
  clients.

  For single-client strategies, this simply calls distribute_utils.regroup().

  Args:
    value_list: a list of values, one for each replica.
    strategy: the `tf.distribute.Strategy`.

  Returns:
    a structure of PerReplica.

  """
  # TODO(b/166464552): always wrap for all one-device strategies as well.
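  # Illustrative sketch: with a two-replica strategy, value_list = [t0, t1]
  # is regrouped into a single PerReplica holding t0 and t1; with a
  # single-device strategy (where always_wrap is False), regroup simply
  # returns t0 unwrapped.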
  always_wrap = _always_wrap(strategy)
  per_replicas = distribute_utils.regroup(value_list, always_wrap=always_wrap)
  return per_replicas


def _always_wrap(strategy):
  """Returns whether to always wrap the values in a DistributedValues."""
  return strategy.extended._in_multi_worker_mode() or len(  # pylint: disable=protected-access
      strategy.extended.worker_devices) > 1


def _rebatch_as_dynamic(per_replica_spec):
  """Rebatch the spec to have a dynamic batch dimension."""
  assert isinstance(per_replica_spec, values.PerReplicaSpec), per_replica_spec

  # pylint: disable=protected-access
  def _rebatch(spec):
    # Rebatch if possible.
    try:
      return spec._unbatch()._batch(None)
    except ValueError:
      pass
    return spec

  return values.PerReplicaSpec(
      *nest.map_structure(_rebatch, per_replica_spec._value_specs))
  # pylint: enable=protected-access
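
# Illustrative example of the rebatching above (hypothetical specs): a
# values.PerReplicaSpec holding tf.TensorSpec([32, 3]) per replica becomes a
# values.PerReplicaSpec holding tf.TensorSpec([None, 3]); the leading (batch)
# dimension is made dynamic while the remaining dimensions are preserved.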