1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Python wrappers for Iterators.""" 16import abc 17import threading 18import warnings 19 20from tensorflow.python.data.ops import optional_ops 21from tensorflow.python.data.ops import options as options_lib 22from tensorflow.python.data.util import nest 23from tensorflow.python.data.util import structure 24from tensorflow.python.eager import context 25from tensorflow.python.framework import composite_tensor 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import errors 28from tensorflow.python.framework import ops 29from tensorflow.python.framework import tensor_shape 30from tensorflow.python.framework import tensor_spec 31from tensorflow.python.framework import type_spec 32from tensorflow.python.ops import gen_dataset_ops 33from tensorflow.python.trackable import base as trackable 34from tensorflow.python.training.saver import BaseSaverBuilder 35from tensorflow.python.util import _pywrap_utils 36from tensorflow.python.util import deprecation 37from tensorflow.python.util import lazy_loader 38from tensorflow.python.util.compat import collections_abc 39from tensorflow.python.util.tf_export import tf_export 40 41 42# NOTE(mrry): It is legitimate to call `Iterator.get_next()` multiple 43# times, e.g. 
when you are distributing different elements to multiple 44# devices in a single step. However, a common pitfall arises when 45# users call `Iterator.get_next()` in each iteration of their training 46# loop. `Iterator.get_next()` adds ops to the graph, and executing 47# each op allocates resources (including threads); as a consequence, 48# invoking it in every iteration of a training loop causes slowdown 49# and eventual resource exhaustion. To guard against this outcome, we 50# log a warning when the number of uses crosses a threshold of suspicion. 51GET_NEXT_CALL_WARNING_THRESHOLD = 32 52 53GET_NEXT_CALL_WARNING_MESSAGE = ( 54 "An unusually high number of `Iterator.get_next()` calls was detected. " 55 "This often indicates that `Iterator.get_next()` is being called inside " 56 "a training loop, which will cause gradual slowdown and eventual resource " 57 "exhaustion. If this is the case, restructure your code to call " 58 "`next_element = iterator.get_next()` once outside the loop, and use " 59 "`next_element` as the input to some computation that is invoked inside " 60 "the loop.") 61 62# NOTE(jsimsa): Threshold used as a heuristic to check for infinite loop during 63# tf.function tracing. 64GET_NEXT_CALL_ERROR_THRESHOLD = 32 65 66GET_NEXT_CALL_ERROR_MESSAGE = ( 67 "An unusually high number of `tf.data.Iterator.get_next()` calls was " 68 "detected. This suggests that the `for elem in dataset: ...` idiom is used " 69 "within tf.function with AutoGraph disabled. This idiom is only supported " 70 "when AutoGraph is enabled.") 71 72# Collection of all IteratorResources in the `Graph`. 73GLOBAL_ITERATORS = "iterators" 74 75 76autograph_ctx = lazy_loader.LazyLoader( 77 "autograph_ctx", globals(), 78 "tensorflow.python.autograph.core.ag_ctx") 79 80 81# Avoid circular dependency for `type_utils` which transitively depends 82# on Autograph which in turn depends on tf.data. 
83type_utils = lazy_loader.LazyLoader( 84 "type_utils", globals(), 85 "tensorflow.python.framework.type_utils") 86 87 88def _device_stack_is_empty(): 89 if context.executing_eagerly(): 90 return context.context().device_name is None 91 # pylint: disable=protected-access 92 device_stack = ops.get_default_graph()._device_functions_outer_to_inner 93 # pylint: enable=protected-access 94 return not bool(device_stack) 95 96 97@tf_export(v1=["data.Iterator"]) 98class Iterator(trackable.Trackable): 99 """Represents the state of iterating through a `Dataset`.""" 100 101 def __init__(self, iterator_resource, initializer, output_types, 102 output_shapes, output_classes): 103 """Creates a new iterator from the given iterator resource. 104 105 Note: Most users will not call this initializer directly, and will 106 instead use `Dataset.make_initializable_iterator()` or 107 `Dataset.make_one_shot_iterator()`. 108 109 Args: 110 iterator_resource: A `tf.resource` scalar `tf.Tensor` representing the 111 iterator. 112 initializer: A `tf.Operation` that should be run to initialize this 113 iterator. 114 output_types: A (nested) structure of `tf.DType` objects corresponding to 115 each component of an element of this iterator. 116 output_shapes: A (nested) structure of `tf.TensorShape` objects 117 corresponding to each component of an element of this iterator. 118 output_classes: A (nested) structure of Python `type` objects 119 corresponding to each component of an element of this iterator. 120 121 Raises: 122 TypeError: If `output_types`, `output_shapes`, or `output_classes` is not 123 specified. 124 """ 125 self._iterator_resource = iterator_resource 126 self._initializer = initializer 127 128 if (output_types is None or output_shapes is None 129 or output_classes is None): 130 raise ValueError( 131 "All of `output_types`, `output_shapes`, and `output_classes` " 132 "must be specified to create an iterator. 
Got " 133 f"`output_types` = {output_types!r}, " 134 f"`output_shapes` = {output_shapes!r}, " 135 f"`output_classes` = {output_classes!r}.") 136 self._element_spec = structure.convert_legacy_structure( 137 output_types, output_shapes, output_classes) 138 self._flat_tensor_shapes = structure.get_flat_tensor_shapes( 139 self._element_spec) 140 self._flat_tensor_types = structure.get_flat_tensor_types( 141 self._element_spec) 142 143 self._string_handle = gen_dataset_ops.iterator_to_string_handle( 144 self._iterator_resource) 145 self._get_next_call_count = 0 146 ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource) 147 148 @staticmethod 149 def from_structure(output_types, 150 output_shapes=None, 151 shared_name=None, 152 output_classes=None): 153 """Creates a new, uninitialized `Iterator` with the given structure. 154 155 This iterator-constructing method can be used to create an iterator that 156 is reusable with many different datasets. 157 158 The returned iterator is not bound to a particular dataset, and it has 159 no `initializer`. To initialize the iterator, run the operation returned by 160 `Iterator.make_initializer(dataset)`. 161 162 The following is an example 163 164 ```python 165 iterator = Iterator.from_structure(tf.int64, tf.TensorShape([])) 166 167 dataset_range = Dataset.range(10) 168 range_initializer = iterator.make_initializer(dataset_range) 169 170 dataset_evens = dataset_range.filter(lambda x: x % 2 == 0) 171 evens_initializer = iterator.make_initializer(dataset_evens) 172 173 # Define a model based on the iterator; in this example, the model_fn 174 # is expected to take scalar tf.int64 Tensors as input (see 175 # the definition of 'iterator' above). 176 prediction, loss = model_fn(iterator.get_next()) 177 178 # Train for `num_epochs`, where for each epoch, we first iterate over 179 # dataset_range, and then iterate over dataset_evens. 
180 for _ in range(num_epochs): 181 # Initialize the iterator to `dataset_range` 182 sess.run(range_initializer) 183 while True: 184 try: 185 pred, loss_val = sess.run([prediction, loss]) 186 except tf.errors.OutOfRangeError: 187 break 188 189 # Initialize the iterator to `dataset_evens` 190 sess.run(evens_initializer) 191 while True: 192 try: 193 pred, loss_val = sess.run([prediction, loss]) 194 except tf.errors.OutOfRangeError: 195 break 196 ``` 197 198 Args: 199 output_types: A (nested) structure of `tf.DType` objects corresponding to 200 each component of an element of this dataset. 201 output_shapes: (Optional.) A (nested) structure of `tf.TensorShape` 202 objects corresponding to each component of an element of this dataset. 203 If omitted, each component will have an unconstrainted shape. 204 shared_name: (Optional.) If non-empty, this iterator will be shared under 205 the given name across multiple sessions that share the same devices 206 (e.g. when using a remote server). 207 output_classes: (Optional.) A (nested) structure of Python `type` objects 208 corresponding to each component of an element of this iterator. If 209 omitted, each component is assumed to be of type `tf.Tensor`. 210 211 Returns: 212 An `Iterator`. 213 214 Raises: 215 TypeError: If the structures of `output_shapes` and `output_types` are 216 not the same. 
217 """ 218 output_types = nest.map_structure(dtypes.as_dtype, output_types) 219 if output_shapes is None: 220 output_shapes = nest.map_structure( 221 lambda _: tensor_shape.TensorShape(None), output_types) 222 else: 223 output_shapes = nest.map_structure_up_to(output_types, 224 tensor_shape.as_shape, 225 output_shapes) 226 if output_classes is None: 227 output_classes = nest.map_structure(lambda _: ops.Tensor, output_types) 228 nest.assert_same_structure(output_types, output_shapes) 229 output_structure = structure.convert_legacy_structure( 230 output_types, output_shapes, output_classes) 231 if shared_name is None: 232 shared_name = "" 233 iterator_resource = gen_dataset_ops.iterator_v2( 234 container="", 235 shared_name=shared_name, 236 output_types=structure.get_flat_tensor_types(output_structure), 237 output_shapes=structure.get_flat_tensor_shapes( 238 output_structure)) 239 return Iterator(iterator_resource, None, output_types, output_shapes, 240 output_classes) 241 242 @staticmethod 243 def from_string_handle(string_handle, 244 output_types, 245 output_shapes=None, 246 output_classes=None): 247 """Creates a new, uninitialized `Iterator` based on the given handle. 248 249 This method allows you to define a "feedable" iterator where you can choose 250 between concrete iterators by feeding a value in a `tf.Session.run` call. 251 In that case, `string_handle` would be a `tf.compat.v1.placeholder`, and you 252 would 253 feed it with the value of `tf.data.Iterator.string_handle` in each step. 
254 255 For example, if you had two iterators that marked the current position in 256 a training dataset and a test dataset, you could choose which to use in 257 each step as follows: 258 259 ```python 260 train_iterator = tf.data.Dataset(...).make_one_shot_iterator() 261 train_iterator_handle = sess.run(train_iterator.string_handle()) 262 263 test_iterator = tf.data.Dataset(...).make_one_shot_iterator() 264 test_iterator_handle = sess.run(test_iterator.string_handle()) 265 266 handle = tf.compat.v1.placeholder(tf.string, shape=[]) 267 iterator = tf.data.Iterator.from_string_handle( 268 handle, train_iterator.output_types) 269 270 next_element = iterator.get_next() 271 loss = f(next_element) 272 273 train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle}) 274 test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle}) 275 ``` 276 277 Args: 278 string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates to 279 a handle produced by the `Iterator.string_handle()` method. 280 output_types: A (nested) structure of `tf.DType` objects corresponding to 281 each component of an element of this dataset. 282 output_shapes: (Optional.) A (nested) structure of `tf.TensorShape` 283 objects corresponding to each component of an element of this dataset. 284 If omitted, each component will have an unconstrainted shape. 285 output_classes: (Optional.) A (nested) structure of Python `type` objects 286 corresponding to each component of an element of this iterator. If 287 omitted, each component is assumed to be of type `tf.Tensor`. 288 289 Returns: 290 An `Iterator`. 
291 """ 292 output_types = nest.map_structure(dtypes.as_dtype, output_types) 293 if output_shapes is None: 294 output_shapes = nest.map_structure( 295 lambda _: tensor_shape.TensorShape(None), output_types) 296 else: 297 output_shapes = nest.map_structure_up_to(output_types, 298 tensor_shape.as_shape, 299 output_shapes) 300 if output_classes is None: 301 output_classes = nest.map_structure(lambda _: ops.Tensor, output_types) 302 nest.assert_same_structure(output_types, output_shapes) 303 output_structure = structure.convert_legacy_structure( 304 output_types, output_shapes, output_classes) 305 string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string) 306 iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2( 307 string_handle, 308 output_types=structure.get_flat_tensor_types(output_structure), 309 output_shapes=structure.get_flat_tensor_shapes(output_structure)) 310 return Iterator(iterator_resource, None, output_types, output_shapes, 311 output_classes) 312 313 @property 314 def initializer(self): 315 """A `tf.Operation` that should be run to initialize this iterator. 316 317 Returns: 318 A `tf.Operation` that should be run to initialize this iterator 319 320 Raises: 321 ValueError: If this iterator initializes itself automatically. 322 """ 323 if self._initializer is not None: 324 return self._initializer 325 else: 326 # TODO(mrry): Consider whether one-shot iterators should have 327 # initializers that simply reset their state to the beginning. 328 raise ValueError( 329 "The iterator does not have an initializer. This means it was likely " 330 "created using `tf.data.Dataset.make_one_shot_iterator()`. For an " 331 "initializable iterator, use " 332 "`tf.data.Dataset.make_initializable_iterator()` instead.") 333 334 def make_initializer(self, dataset, name=None): 335 """Returns a `tf.Operation` that initializes this iterator on `dataset`. 336 337 Args: 338 dataset: A `Dataset` whose `element_spec` if compatible with this 339 iterator. 
340 name: (Optional.) A name for the created operation. 341 342 Returns: 343 A `tf.Operation` that can be run to initialize this iterator on the given 344 `dataset`. 345 346 Raises: 347 TypeError: If `dataset` and this iterator do not have a compatible 348 `element_spec`. 349 """ 350 with ops.name_scope(name, "make_initializer") as name: 351 # NOTE(mrry): Cannot depend on `dataset_ops.get_legacy_output*()` due 352 # to that creating a circular dependency. 353 # pylint: disable=protected-access 354 dataset_output_types = nest.map_structure( 355 lambda component_spec: component_spec._to_legacy_output_types(), 356 dataset.element_spec) 357 dataset_output_shapes = nest.map_structure( 358 lambda component_spec: component_spec._to_legacy_output_shapes(), 359 dataset.element_spec) 360 dataset_output_classes = nest.map_structure( 361 lambda component_spec: component_spec._to_legacy_output_classes(), 362 dataset.element_spec) 363 # pylint: enable=protected-access 364 365 nest.assert_same_structure(self.output_types, dataset_output_types) 366 nest.assert_same_structure(self.output_shapes, dataset_output_shapes) 367 for iterator_class, dataset_class in zip( 368 nest.flatten(self.output_classes), 369 nest.flatten(dataset_output_classes)): 370 if iterator_class is not dataset_class: 371 raise TypeError( 372 f"Expected output classes {self.output_classes!r} but got " 373 f"dataset with output classes {dataset_output_classes!r}.") 374 for iterator_dtype, dataset_dtype in zip( 375 nest.flatten(self.output_types), nest.flatten(dataset_output_types)): 376 if iterator_dtype != dataset_dtype: 377 raise TypeError( 378 f"Expected output types {self.output_types!r} but got dataset " 379 f"with output types {dataset_output_types!r}.") 380 for iterator_shape, dataset_shape in zip( 381 nest.flatten(self.output_shapes), nest.flatten( 382 dataset_output_shapes)): 383 if not iterator_shape.is_compatible_with(dataset_shape): 384 raise TypeError( 385 f"Expected output shapes compatible with 
{self.output_shapes!r} " 386 f"but got dataset with output shapes {dataset_output_shapes!r}.") 387 388 # TODO(b/169442955): Investigate the need for this colocation constraint. 389 with ops.colocate_with(self._iterator_resource): 390 # pylint: disable=protected-access 391 return gen_dataset_ops.make_iterator( 392 dataset._variant_tensor, self._iterator_resource, name=name) 393 394 def get_next(self, name=None): 395 """Returns the next element. 396 397 In graph mode, you should typically call this method *once* and use its 398 result as the input to another computation. A typical loop will then call 399 `tf.Session.run` on the result of that computation. The loop will terminate 400 when the `Iterator.get_next()` operation raises 401 `tf.errors.OutOfRangeError`. The following skeleton shows how to use 402 this method when building a training loop: 403 404 ```python 405 dataset = ... # A `tf.data.Dataset` object. 406 iterator = dataset.make_initializable_iterator() 407 next_element = iterator.get_next() 408 409 # Build a TensorFlow graph that does something with each element. 410 loss = model_function(next_element) 411 optimizer = ... # A `tf.compat.v1.train.Optimizer` object. 412 train_op = optimizer.minimize(loss) 413 414 with tf.compat.v1.Session() as sess: 415 try: 416 while True: 417 sess.run(train_op) 418 except tf.errors.OutOfRangeError: 419 pass 420 ``` 421 422 NOTE: It is legitimate to call `Iterator.get_next()` multiple times, e.g. 423 when you are distributing different elements to multiple devices in a single 424 step. However, a common pitfall arises when users call `Iterator.get_next()` 425 in each iteration of their training loop. `Iterator.get_next()` adds ops to 426 the graph, and executing each op allocates resources (including threads); as 427 a consequence, invoking it in every iteration of a training loop causes 428 slowdown and eventual resource exhaustion. 
To guard against this outcome, we 429 log a warning when the number of uses crosses a fixed threshold of 430 suspiciousness. 431 432 Args: 433 name: (Optional.) A name for the created operation. 434 435 Returns: 436 A (nested) structure of values matching `tf.data.Iterator.element_spec`. 437 """ 438 self._get_next_call_count += 1 439 if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD: 440 warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE) 441 442 # TODO(b/169442955): Investigate the need for this colocation constraint. 443 with ops.colocate_with(self._iterator_resource): 444 # pylint: disable=protected-access 445 flat_ret = gen_dataset_ops.iterator_get_next( 446 self._iterator_resource, 447 output_types=self._flat_tensor_types, 448 output_shapes=self._flat_tensor_shapes, 449 name=name) 450 return structure.from_tensor_list(self._element_spec, flat_ret) 451 452 def get_next_as_optional(self): 453 # TODO(b/169442955): Investigate the need for this colocation constraint. 454 with ops.colocate_with(self._iterator_resource): 455 # pylint: disable=protected-access 456 return optional_ops._OptionalImpl( 457 gen_dataset_ops.iterator_get_next_as_optional( 458 self._iterator_resource, 459 output_types=structure.get_flat_tensor_types(self.element_spec), 460 output_shapes=structure.get_flat_tensor_shapes( 461 self.element_spec)), self.element_spec) 462 463 def string_handle(self, name=None): 464 """Returns a string-valued `tf.Tensor` that represents this iterator. 465 466 Args: 467 name: (Optional.) A name for the created operation. 468 469 Returns: 470 A scalar `tf.Tensor` of type `tf.string`. 471 """ 472 if name is None: 473 return self._string_handle 474 else: 475 return gen_dataset_ops.iterator_to_string_handle( 476 self._iterator_resource, name=name) 477 478 @property 479 @deprecation.deprecated( 480 None, "Use `tf.compat.v1.data.get_output_classes(iterator)`.") 481 def output_classes(self): 482 """Returns the class of each component of an element of this iterator. 
483 484 The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`. 485 486 Returns: 487 A (nested) structure of Python `type` objects corresponding to each 488 component of an element of this dataset. 489 """ 490 return nest.map_structure( 491 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 492 self._element_spec) 493 494 @property 495 @deprecation.deprecated( 496 None, "Use `tf.compat.v1.data.get_output_shapes(iterator)`.") 497 def output_shapes(self): 498 """Returns the shape of each component of an element of this iterator. 499 500 Returns: 501 A (nested) structure of `tf.TensorShape` objects corresponding to each 502 component of an element of this dataset. 503 """ 504 return nest.map_structure( 505 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 506 self._element_spec) 507 508 @property 509 @deprecation.deprecated( 510 None, "Use `tf.compat.v1.data.get_output_types(iterator)`.") 511 def output_types(self): 512 """Returns the type of each component of an element of this iterator. 513 514 Returns: 515 A (nested) structure of `tf.DType` objects corresponding to each component 516 of an element of this dataset. 517 """ 518 return nest.map_structure( 519 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 520 self._element_spec) 521 522 @property 523 def element_spec(self): 524 """The type specification of an element of this iterator. 525 526 For more information, 527 read [this guide](https://www.tensorflow.org/guide/data#dataset_structure). 528 529 Returns: 530 A (nested) structure of `tf.TypeSpec` objects matching the structure of an 531 element of this iterator and specifying the type of individual components. 
532 """ 533 534 return self._element_spec 535 536 def _gather_saveables_for_checkpoint(self): 537 538 def _saveable_factory(name): 539 return _IteratorSaveable(self._iterator_resource, name) 540 541 return {"ITERATOR": _saveable_factory} 542 543 544_uid_counter = 0 545_uid_lock = threading.Lock() 546 547 548def _generate_shared_name(prefix): 549 with _uid_lock: 550 global _uid_counter 551 uid = _uid_counter 552 _uid_counter += 1 553 return "{}{}".format(prefix, uid) 554 555 556@tf_export("data.Iterator", v1=[]) 557class IteratorBase( 558 collections_abc.Iterator, 559 trackable.Trackable, 560 composite_tensor.CompositeTensor, 561 metaclass=abc.ABCMeta): 562 """Represents an iterator of a `tf.data.Dataset`. 563 564 `tf.data.Iterator` is the primary mechanism for enumerating elements of a 565 `tf.data.Dataset`. It supports the Python Iterator protocol, which means 566 it can be iterated over using a for-loop: 567 568 >>> dataset = tf.data.Dataset.range(2) 569 >>> for element in dataset: 570 ... print(element) 571 tf.Tensor(0, shape=(), dtype=int64) 572 tf.Tensor(1, shape=(), dtype=int64) 573 574 or by fetching individual elements explicitly via `get_next()`: 575 576 >>> dataset = tf.data.Dataset.range(2) 577 >>> iterator = iter(dataset) 578 >>> print(iterator.get_next()) 579 tf.Tensor(0, shape=(), dtype=int64) 580 >>> print(iterator.get_next()) 581 tf.Tensor(1, shape=(), dtype=int64) 582 583 In addition, non-raising iteration is supported via `get_next_as_optional()`, 584 which returns the next element (if available) wrapped in a 585 `tf.experimental.Optional`. 
586 587 >>> dataset = tf.data.Dataset.from_tensors(42) 588 >>> iterator = iter(dataset) 589 >>> optional = iterator.get_next_as_optional() 590 >>> print(optional.has_value()) 591 tf.Tensor(True, shape=(), dtype=bool) 592 >>> optional = iterator.get_next_as_optional() 593 >>> print(optional.has_value()) 594 tf.Tensor(False, shape=(), dtype=bool) 595 """ 596 597 @abc.abstractproperty 598 def element_spec(self): 599 """The type specification of an element of this iterator. 600 601 >>> dataset = tf.data.Dataset.from_tensors(42) 602 >>> iterator = iter(dataset) 603 >>> iterator.element_spec 604 tf.TensorSpec(shape=(), dtype=tf.int32, name=None) 605 606 For more information, 607 read [this guide](https://www.tensorflow.org/guide/data#dataset_structure). 608 609 Returns: 610 A (nested) structure of `tf.TypeSpec` objects matching the structure of an 611 element of this iterator, specifying the type of individual components. 612 """ 613 raise NotImplementedError("Iterator.element_spec") 614 615 @abc.abstractmethod 616 def get_next(self): 617 """Returns the next element. 618 619 >>> dataset = tf.data.Dataset.from_tensors(42) 620 >>> iterator = iter(dataset) 621 >>> print(iterator.get_next()) 622 tf.Tensor(42, shape=(), dtype=int32) 623 624 Returns: 625 A (nested) structure of values matching `tf.data.Iterator.element_spec`. 626 627 Raises: 628 `tf.errors.OutOfRangeError`: If the end of the iterator has been reached. 629 """ 630 raise NotImplementedError("Iterator.get_next()") 631 632 @abc.abstractmethod 633 def get_next_as_optional(self): 634 """Returns the next element wrapped in `tf.experimental.Optional`. 635 636 If the iterator has reached the end of the sequence, the returned 637 `tf.experimental.Optional` will have no value. 
638 639 >>> dataset = tf.data.Dataset.from_tensors(42) 640 >>> iterator = iter(dataset) 641 >>> optional = iterator.get_next_as_optional() 642 >>> print(optional.has_value()) 643 tf.Tensor(True, shape=(), dtype=bool) 644 >>> print(optional.get_value()) 645 tf.Tensor(42, shape=(), dtype=int32) 646 >>> optional = iterator.get_next_as_optional() 647 >>> print(optional.has_value()) 648 tf.Tensor(False, shape=(), dtype=bool) 649 650 Returns: 651 A `tf.experimental.Optional` object representing the next element. 652 """ 653 raise NotImplementedError("Iterator.get_next_as_optional()") 654 655 656class OwnedIterator(IteratorBase): 657 """An iterator producing tf.Tensor objects from a tf.data.Dataset. 658 659 The iterator resource created through `OwnedIterator` is owned by the Python 660 object and the life time of the underlying resource is tied to the life time 661 of the `OwnedIterator` object. This makes `OwnedIterator` appropriate for use 662 in eager mode and inside of tf.functions. 663 """ 664 665 def __init__(self, dataset=None, components=None, element_spec=None): 666 """Creates a new iterator from the given dataset. 667 668 If `dataset` is not specified, the iterator will be created from the given 669 tensor components and element structure. In particular, the alternative for 670 constructing the iterator is used when the iterator is reconstructed from 671 it `CompositeTensor` representation. 672 673 Args: 674 dataset: A `tf.data.Dataset` object. 675 components: Tensor components to construct the iterator from. 676 element_spec: A (nested) structure of `TypeSpec` objects that 677 represents the type specification of elements of the iterator. 678 679 Raises: 680 ValueError: If `dataset` is not provided and either `components` or 681 `element_spec` is not provided. Or `dataset` is provided and either 682 `components` and `element_spec` is provided. 
683 """ 684 super(OwnedIterator, self).__init__() 685 686 if dataset is None: 687 if (components is None or element_spec is None): 688 raise ValueError( 689 "When `dataset` is not provided, both `components` and " 690 "`element_spec` must be specified.") 691 # pylint: disable=protected-access 692 self._element_spec = element_spec 693 self._flat_output_types = structure.get_flat_tensor_types( 694 self._element_spec) 695 self._flat_output_shapes = structure.get_flat_tensor_shapes( 696 self._element_spec) 697 self._iterator_resource, = components 698 else: 699 if (components is not None or element_spec is not None): 700 raise ValueError( 701 "When `dataset` is provided, `element_spec` and `components` must " 702 "not be specified.") 703 self._create_iterator(dataset) 704 705 self._get_next_call_count = 0 706 707 def _create_iterator(self, dataset): 708 # pylint: disable=protected-access 709 dataset = dataset._apply_debug_options() 710 711 # Store dataset reference to ensure that dataset is alive when this iterator 712 # is being used. For example, `tf.data.Dataset.from_generator` registers 713 # a few py_funcs that are needed in `self._next_internal`. If the dataset 714 # is deleted, this iterator crashes on `self.__next__(...)` call. 715 self._dataset = dataset 716 717 ds_variant = dataset._variant_tensor 718 self._element_spec = dataset.element_spec 719 self._flat_output_types = structure.get_flat_tensor_types( 720 self._element_spec) 721 self._flat_output_shapes = structure.get_flat_tensor_shapes( 722 self._element_spec) 723 with ops.colocate_with(ds_variant): 724 self._iterator_resource = ( 725 gen_dataset_ops.anonymous_iterator_v3( 726 output_types=self._flat_output_types, 727 output_shapes=self._flat_output_shapes)) 728 if not context.executing_eagerly(): 729 # Add full type information to the graph so host memory types inside 730 # variants stay on CPU, e.g, ragged string tensors. 
# NOTE(review): the statements down to `make_iterator` are the tail of a method
# whose header precedes this chunk; their nesting is reconstructed from a
# whitespace-mangled source -- confirm against the enclosing method.
        # TODO(b/224776031) Remove this when AnonymousIterateV3 can use
        # (reverse) type inference and all other ops that are needed to
        # provide type information to the AnonymousIterateV3 also support
        # type inference (esp. cross-function type inference) instead of
        # setting the full type information manually.
        fulltype = type_utils.iterator_full_type_from_spec(
            self._element_spec)
        # fulltype is PRODUCT[ITERATOR[PRODUCT[...]]]
        assert len(fulltype.args[0].args[0].args) == len(
            self._flat_output_types)
        self._iterator_resource.op.experimental_set_type(fulltype)
      gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource)

  def __iter__(self):
    """Returns `self`; this object acts as its own Python iterator."""
    return self

  def next(self):  # For Python 2 compatibility
    """Forwards to `__next__()`."""
    return self.__next__()

  def _next_internal(self):
    """Issues an `iterator_get_next` op and repacks its flat outputs.

    When not executing eagerly and the AutoGraph control status is `DISABLED`,
    every call increments `self._get_next_call_count`; once the count exceeds
    `GET_NEXT_CALL_ERROR_THRESHOLD` a `ValueError` is raised instead of adding
    another op.

    Returns:
      The op outputs converted back through `self._element_spec`.

    Raises:
      ValueError: If the call-count heuristic described above trips.
    """
    autograph_status = autograph_ctx.control_status_ctx().status
    autograph_disabled = autograph_status == autograph_ctx.Status.DISABLED
    if not context.executing_eagerly() and autograph_disabled:
      self._get_next_call_count += 1
      if self._get_next_call_count > GET_NEXT_CALL_ERROR_THRESHOLD:
        raise ValueError(GET_NEXT_CALL_ERROR_MESSAGE)

    # Graph (non-eager) path: build the op colocated with the iterator resource.
    if not context.executing_eagerly():
      # TODO(b/169442955): Investigate the need for this colocation constraint.
      with ops.colocate_with(self._iterator_resource):
        ret = gen_dataset_ops.iterator_get_next(
            self._iterator_resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)
      return structure.from_compatible_tensor_list(self._element_spec, ret)

    # Eager path.
    # TODO(b/77291417): This runs in sync mode as iterators use an error status
    # to communicate that there is no more data to iterate over.
    with context.execution_mode(context.SYNC):
      ret = gen_dataset_ops.iterator_get_next(
          self._iterator_resource,
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)

      try:
        # Fast path for the case `self._element_spec` is not a nested
        # structure: a single spec object exposes the conversion method
        # directly, while a nested container raises `AttributeError`.
        return self._element_spec._from_compatible_tensor_list(ret)  # pylint: disable=protected-access
      except AttributeError:
        return structure.from_compatible_tensor_list(self._element_spec, ret)

  @property
  def _type_spec(self):
    """An `IteratorSpec` built from this iterator's `element_spec`."""
    return IteratorSpec(self.element_spec)

  def __next__(self):
    """Returns the next element via `_next_internal()`.

    Raises:
      StopIteration: When `_next_internal()` raises `errors.OutOfRangeError`.
    """
    try:
      return self._next_internal()
    except errors.OutOfRangeError:
      raise StopIteration

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_classes(iterator)`.")
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`.

    Returns:
      A (nested) structure of Python `type` objects corresponding to each
      component of an element of this iterator.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_shapes(iterator)`.")
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A (nested) structure of `tf.TensorShape` objects corresponding to each
      component of an element of this iterator.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_types(iterator)`.")
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A (nested) structure of `tf.DType` objects corresponding to each component
      of an element of this iterator.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def element_spec(self):
    """The element spec stored on this iterator (`self._element_spec`)."""
    return self._element_spec

  def get_next(self):
    """Returns the next element via `_next_internal()`.

    Unlike `__next__()`, this method does not catch `errors.OutOfRangeError`;
    whatever `_next_internal()` raises propagates to the caller.
    """
    return self._next_internal()

  def get_next_as_optional(self):
    """Returns the next element wrapped in an optional.

    Returns:
      An `optional_ops._OptionalImpl` built from the output of the
      `iterator_get_next_as_optional` op and this iterator's `element_spec`.
    """
    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      # pylint: disable=protected-access
      return optional_ops._OptionalImpl(
          gen_dataset_ops.iterator_get_next_as_optional(
              self._iterator_resource,
              output_types=structure.get_flat_tensor_types(self.element_spec),
              output_shapes=structure.get_flat_tensor_shapes(
                  self.element_spec)), self.element_spec)

  def _gather_saveables_for_checkpoint(self):
    """Returns the saveable factories used to checkpoint this iterator.

    Returns:
      A dict with the single key `"ITERATOR"` mapped to a factory that takes a
      name and returns an `_IteratorSaveable` for `self._iterator_resource`.
    """

    def _saveable_factory(name):
      """Returns a SaveableObject for serialization/deserialization."""
      policy = None
      # The external state policy is read from the dataset options only when
      # a dataset reference is held.
      if self._dataset:
        policy = self._dataset.options().experimental_external_state_policy
      if policy:
        return _IteratorSaveable(
            self._iterator_resource,
            name,
            external_state_policy=policy)
      else:
        # No policy available: rely on `_IteratorSaveable`'s default.
        return _IteratorSaveable(self._iterator_resource, name)

    return {"ITERATOR": _saveable_factory}

  def __tf_tracing_type__(self, signature_context):
    """Returns a reference tracing type for this iterator.

    The type is built by `signature_context.make_reference_type` from this
    iterator's `_type_spec` and the `_id` of its iterator resource.
    """
    return signature_context.make_reference_type(self._type_spec,
                                                 self._iterator_resource._id)  # pylint:disable=protected-access


@tf_export("data.IteratorSpec", v1=[])
class IteratorSpec(type_spec.TypeSpec):
  """Type specification for `tf.data.Iterator`.

  A `tf.data.IteratorSpec` can appear in a `tf.function` input signature so
  that the function accepts a `tf.data.Iterator` argument, for example:

  >>> @tf.function(input_signature=[tf.data.IteratorSpec(
  ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))])
  ... def square(iterator):
  ...   x = iterator.get_next()
  ...   return x * x
  >>> dataset = tf.data.Dataset.from_tensors(5)
  >>> iterator = iter(dataset)
  >>> print(square(iterator))
  tf.Tensor(25, shape=(), dtype=int32)

  Attributes:
    element_spec: A (nested) structure of `tf.TypeSpec` objects that represents
      the type specification of the iterator elements.
  """

  __slots__ = ["_element_spec"]

  def __init__(self, element_spec):
    """Stores `element_spec` as the only state of this spec."""
    self._element_spec = element_spec

  @property
  def value_type(self):
    """The Python class this spec describes (`OwnedIterator`)."""
    return OwnedIterator

  def _serialize(self):
    """Returns a 1-tuple holding the element spec."""
    serialized = (self._element_spec,)
    return serialized

  @property
  def _component_specs(self):
    """A 1-tuple with one scalar `resource`-dtype `TensorSpec`."""
    resource_spec = tensor_spec.TensorSpec([], dtypes.resource)
    return (resource_spec,)

  def _to_components(self, value):
    """Returns a 1-tuple holding `value`'s iterator resource."""
    resource = value._iterator_resource  # pylint: disable=protected-access
    return (resource,)

  def _from_components(self, components):
    """Builds an `OwnedIterator` from `components` and this element spec.

    No dataset is attached to the resulting iterator (`dataset=None`).
    """
    return OwnedIterator(
        components=components,
        element_spec=self._element_spec,
        dataset=None)

  @staticmethod
  def from_value(value):
    """Returns an `IteratorSpec` carrying `value.element_spec`."""
    value_element_spec = value.element_spec
    return IteratorSpec(value_element_spec)

  def __tf_tracing_type__(self, signature_context):
    """Returns a reference tracing type keyed by this instance's `id()`."""
    # TODO(b/202772221): Validate and enforce this assumption of uniqueness per
    # spec instance.
    instance_id = id(self)
    return signature_context.make_reference_type(self, instance_id)


# TODO(b/71645805): Expose trackable stateful objects from dataset.
class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
  """Saveable that stores and reloads the state of an iterator resource."""

  def __init__(
      self,
      iterator_resource,
      name,
      external_state_policy=options_lib.ExternalStatePolicy.FAIL):
    """Creates the save spec holding the serialized iterator state.

    Args:
      iterator_resource: The iterator resource passed to `serialize_iterator`;
        its `device` attribute is used as the device of the save spec.
      name: Name of this saveable; the spec is named `name + "_STATE"`.
      external_state_policy: Object whose `.value` is forwarded to the
        `serialize_iterator` op. Defaults to `ExternalStatePolicy.FAIL`.
    """
    policy_value = external_state_policy.value
    state_tensor = gen_dataset_ops.serialize_iterator(
        iterator_resource, external_state_policy=policy_value)
    state_spec = BaseSaverBuilder.SaveSpec(
        state_tensor,
        "",
        name + "_STATE",
        device=iterator_resource.device)
    super(_IteratorSaveable, self).__init__(
        iterator_resource, [state_spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Returns the op that loads serialized state back into the iterator.

    Args:
      restored_tensors: Sequence whose first entry is the serialized state.
      restored_shapes: Not used by this saveable.

    Returns:
      The result of `deserialize_iterator`, colocated with `self.op`.
    """
    iterator_handle = self.op
    with ops.colocate_with(iterator_handle):
      serialized_state = restored_tensors[0]
      return gen_dataset_ops.deserialize_iterator(
          iterator_handle, serialized_state)


@deprecation.deprecated(
    None, "Use `tf.data.Iterator.get_next_as_optional()` instead.")
@tf_export("data.experimental.get_next_as_optional")
def get_next_as_optional(iterator):
  """Returns a `tf.experimental.Optional` with the next element of the iterator.

  When the iterator has no more elements, the returned
  `tf.experimental.Optional` holds no value.

  Args:
    iterator: A `tf.data.Iterator`.

  Returns:
    A `tf.experimental.Optional` object that holds the next element of the
    iterator if there is one, and no value otherwise.
  """
  next_optional = iterator.get_next_as_optional()
  return next_optional


_pywrap_utils.RegisterType("OwnedIterator", OwnedIterator)