xref: /aosp_15_r20/external/tensorflow/tensorflow/python/data/ops/iterator_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrappers for Iterators."""
16import abc
17import threading
18import warnings
19
20from tensorflow.python.data.ops import optional_ops
21from tensorflow.python.data.ops import options as options_lib
22from tensorflow.python.data.util import nest
23from tensorflow.python.data.util import structure
24from tensorflow.python.eager import context
25from tensorflow.python.framework import composite_tensor
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import errors
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.framework import tensor_spec
31from tensorflow.python.framework import type_spec
32from tensorflow.python.ops import gen_dataset_ops
33from tensorflow.python.trackable import base as trackable
34from tensorflow.python.training.saver import BaseSaverBuilder
35from tensorflow.python.util import _pywrap_utils
36from tensorflow.python.util import deprecation
37from tensorflow.python.util import lazy_loader
38from tensorflow.python.util.compat import collections_abc
39from tensorflow.python.util.tf_export import tf_export
40
41
# NOTE(mrry): Calling `Iterator.get_next()` more than once is legitimate, for
# example to hand different elements to several devices within a single step.
# The common mistake is to call it in every iteration of a training loop:
# each call adds ops to the graph, and executing each op allocates resources
# (including threads), so the program slows down and eventually runs out of
# resources. To guard against that, a warning is logged once the number of
# calls exceeds the threshold below.
GET_NEXT_CALL_WARNING_THRESHOLD = 32

GET_NEXT_CALL_WARNING_MESSAGE = (
    "An unusually high number of `Iterator.get_next()` calls was detected. "
    "This often indicates that `Iterator.get_next()` is being called "
    "inside a training loop, which will cause gradual slowdown and "
    "eventual resource exhaustion. "
    "If this is the case, restructure your code to call "
    "`next_element = iterator.get_next()` once outside the loop, "
    "and use `next_element` as the input to some computation that is "
    "invoked inside the loop.")

# NOTE(jsimsa): Heuristic limit on `get_next()` calls, used to detect an
# infinite loop during tf.function tracing.
GET_NEXT_CALL_ERROR_THRESHOLD = 32

GET_NEXT_CALL_ERROR_MESSAGE = (
    "An unusually high number of `tf.data.Iterator.get_next()` calls "
    "was detected. "
    "This suggests that the `for elem in dataset: ...` idiom is used within "
    "tf.function with AutoGraph disabled. "
    "This idiom is only supported when AutoGraph is enabled.")

# Name of the graph collection that holds every IteratorResource in the
# `Graph`.
GLOBAL_ITERATORS = "iterators"
74
75
# Loaded through `LazyLoader` rather than a regular top-level import
# (presumably for the same circular-dependency reason as `type_utils` below;
# confirm).
autograph_ctx = lazy_loader.LazyLoader(
    "autograph_ctx",
    globals(),
    "tensorflow.python.autograph.core.ag_ctx",
)


# `type_utils` transitively depends on Autograph, which in turn depends on
# tf.data, so it is loaded lazily to avoid a circular dependency.
type_utils = lazy_loader.LazyLoader(
    "type_utils",
    globals(),
    "tensorflow.python.framework.type_utils",
)
86
87
def _device_stack_is_empty():
  """Returns whether no device is currently set.

  When executing eagerly, this is true when the context's `device_name` is
  `None`. Otherwise it is true when the default graph's
  `_device_functions_outer_to_inner` sequence is empty.

  Returns:
    A Python `bool`.
  """
  if not context.executing_eagerly():
    graph = ops.get_default_graph()
    return not graph._device_functions_outer_to_inner  # pylint: disable=protected-access
  return context.context().device_name is None
95
96
def _normalize_legacy_iterator_structure(output_types, output_shapes,
                                         output_classes):
  """Canonicalizes a legacy (types, shapes, classes) element description.

  Shared by `Iterator.from_structure()` and `Iterator.from_string_handle()`,
  which accept the same arguments and apply the same defaults.

  Args:
    output_types: A (nested) structure of values convertible to `tf.DType`.
    output_shapes: A (nested) structure of values convertible to
      `tf.TensorShape`, or `None` to give every component an unknown shape.
    output_classes: A (nested) structure of Python `type` objects, or `None`
      to use `tf.Tensor` for every component.

  Returns:
    A tuple `(output_types, output_shapes, output_classes, output_structure)`.
    The first three are the normalized arguments; `output_structure` is the
    result of `structure.convert_legacy_structure()` applied to them.
  """
  output_types = nest.map_structure(dtypes.as_dtype, output_types)
  if output_shapes is None:
    output_shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), output_types)
  else:
    output_shapes = nest.map_structure_up_to(output_types,
                                             tensor_shape.as_shape,
                                             output_shapes)
  if output_classes is None:
    output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
  nest.assert_same_structure(output_types, output_shapes)
  output_structure = structure.convert_legacy_structure(
      output_types, output_shapes, output_classes)
  return output_types, output_shapes, output_classes, output_structure


@tf_export(v1=["data.Iterator"])
class Iterator(trackable.Trackable):
  """Represents the state of iterating through a `Dataset`."""

  def __init__(self, iterator_resource, initializer, output_types,
               output_shapes, output_classes):
    """Creates a new iterator from the given iterator resource.

    Note: Most users will not call this initializer directly, and will
    instead use `Dataset.make_initializable_iterator()` or
    `Dataset.make_one_shot_iterator()`.

    Args:
      iterator_resource: A `tf.resource` scalar `tf.Tensor` representing the
        iterator.
      initializer: A `tf.Operation` that should be run to initialize this
        iterator.
      output_types: A (nested) structure of `tf.DType` objects corresponding to
        each component of an element of this iterator.
      output_shapes: A (nested) structure of `tf.TensorShape` objects
        corresponding to each component of an element of this iterator.
      output_classes: A (nested) structure of Python `type` objects
        corresponding to each component of an element of this iterator.

    Raises:
      ValueError: If `output_types`, `output_shapes`, or `output_classes` is
        not specified.
    """
    self._iterator_resource = iterator_resource
    self._initializer = initializer

    if (output_types is None or output_shapes is None
        or output_classes is None):
      raise ValueError(
          "All of `output_types`, `output_shapes`, and `output_classes` "
          "must be specified to create an iterator. Got "
          f"`output_types` = {output_types!r}, "
          f"`output_shapes` = {output_shapes!r}, "
          f"`output_classes` = {output_classes!r}.")
    self._element_spec = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    # Cache the flattened types and shapes; they are passed to the get-next
    # ops on every call.
    self._flat_tensor_shapes = structure.get_flat_tensor_shapes(
        self._element_spec)
    self._flat_tensor_types = structure.get_flat_tensor_types(
        self._element_spec)

    self._string_handle = gen_dataset_ops.iterator_to_string_handle(
        self._iterator_resource)
    # Number of `get_next()` calls so far; compared against
    # `GET_NEXT_CALL_WARNING_THRESHOLD`.
    self._get_next_call_count = 0
    ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource)

  @staticmethod
  def from_structure(output_types,
                     output_shapes=None,
                     shared_name=None,
                     output_classes=None):
    """Creates a new, uninitialized `Iterator` with the given structure.

    This iterator-constructing method can be used to create an iterator that
    is reusable with many different datasets.

    The returned iterator is not bound to a particular dataset, and it has
    no `initializer`. To initialize the iterator, run the operation returned by
    `Iterator.make_initializer(dataset)`.

    The following is an example

    ```python
    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

    dataset_range = Dataset.range(10)
    range_initializer = iterator.make_initializer(dataset_range)

    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
    evens_initializer = iterator.make_initializer(dataset_evens)

    # Define a model based on the iterator; in this example, the model_fn
    # is expected to take scalar tf.int64 Tensors as input (see
    # the definition of 'iterator' above).
    prediction, loss = model_fn(iterator.get_next())

    # Train for `num_epochs`, where for each epoch, we first iterate over
    # dataset_range, and then iterate over dataset_evens.
    for _ in range(num_epochs):
      # Initialize the iterator to `dataset_range`
      sess.run(range_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break

      # Initialize the iterator to `dataset_evens`
      sess.run(evens_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break
    ```

    Args:
      output_types: A (nested) structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A (nested) structure of `tf.TensorShape`
        objects corresponding to each component of an element of this dataset.
        If omitted, each component will have an unconstrained shape.
      shared_name: (Optional.) If non-empty, this iterator will be shared under
        the given name across multiple sessions that share the same devices
        (e.g. when using a remote server).
      output_classes: (Optional.) A (nested) structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.

    Raises:
      TypeError: If the structures of `output_shapes` and `output_types` are
        not the same.
    """
    (output_types, output_shapes, output_classes,
     output_structure) = _normalize_legacy_iterator_structure(
         output_types, output_shapes, output_classes)
    if shared_name is None:
      shared_name = ""
    iterator_resource = gen_dataset_ops.iterator_v2(
        container="",
        shared_name=shared_name,
        output_types=structure.get_flat_tensor_types(output_structure),
        output_shapes=structure.get_flat_tensor_shapes(
            output_structure))
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)

  @staticmethod
  def from_string_handle(string_handle,
                         output_types,
                         output_shapes=None,
                         output_classes=None):
    """Creates a new, uninitialized `Iterator` based on the given handle.

    This method allows you to define a "feedable" iterator where you can choose
    between concrete iterators by feeding a value in a `tf.Session.run` call.
    In that case, `string_handle` would be a `tf.compat.v1.placeholder`, and
    you would feed it with the value of `tf.data.Iterator.string_handle` in
    each step.

    For example, if you had two iterators that marked the current position in
    a training dataset and a test dataset, you could choose which to use in
    each step as follows:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.compat.v1.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates to
        a handle produced by the `Iterator.string_handle()` method.
      output_types: A (nested) structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A (nested) structure of `tf.TensorShape`
        objects corresponding to each component of an element of this dataset.
        If omitted, each component will have an unconstrained shape.
      output_classes: (Optional.) A (nested) structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.

    Raises:
      TypeError: If the structures of `output_shapes` and `output_types` are
        not the same.
    """
    (output_types, output_shapes, output_classes,
     output_structure) = _normalize_legacy_iterator_structure(
         output_types, output_shapes, output_classes)
    string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
    iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
        string_handle,
        output_types=structure.get_flat_tensor_types(output_structure),
        output_shapes=structure.get_flat_tensor_shapes(output_structure))
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)

  @property
  def initializer(self):
    """A `tf.Operation` that should be run to initialize this iterator.

    Returns:
      A `tf.Operation` that should be run to initialize this iterator

    Raises:
      ValueError: If this iterator initializes itself automatically.
    """
    if self._initializer is None:
      # TODO(mrry): Consider whether one-shot iterators should have
      # initializers that simply reset their state to the beginning.
      raise ValueError(
          "The iterator does not have an initializer. This means it was likely "
          "created using `tf.data.Dataset.make_one_shot_iterator()`. For an "
          "initializable iterator, use "
          "`tf.data.Dataset.make_initializable_iterator()` instead.")
    return self._initializer

  def make_initializer(self, dataset, name=None):
    """Returns a `tf.Operation` that initializes this iterator on `dataset`.

    Args:
      dataset: A `Dataset` whose `element_spec` is compatible with this
        iterator.
      name: (Optional.) A name for the created operation.

    Returns:
      A `tf.Operation` that can be run to initialize this iterator on the given
      `dataset`.

    Raises:
      TypeError: If `dataset` and this iterator do not have a compatible
        `element_spec`.
    """
    with ops.name_scope(name, "make_initializer") as name:
      # NOTE(mrry): Cannot depend on `dataset_ops.get_legacy_output*()` due
      # to that creating a circular dependency.
      # pylint: disable=protected-access
      dataset_output_types = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_types(),
          dataset.element_spec)
      dataset_output_shapes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_shapes(),
          dataset.element_spec)
      dataset_output_classes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_classes(),
          dataset.element_spec)
      # pylint: enable=protected-access

      # The nesting must match before the components are compared pairwise.
      nest.assert_same_structure(self.output_types, dataset_output_types)
      nest.assert_same_structure(self.output_shapes, dataset_output_shapes)
      for iterator_class, dataset_class in zip(
          nest.flatten(self.output_classes),
          nest.flatten(dataset_output_classes)):
        if iterator_class is not dataset_class:
          raise TypeError(
              f"Expected output classes {self.output_classes!r} but got "
              f"dataset with output classes {dataset_output_classes!r}.")
      for iterator_dtype, dataset_dtype in zip(
          nest.flatten(self.output_types), nest.flatten(dataset_output_types)):
        if iterator_dtype != dataset_dtype:
          raise TypeError(
              f"Expected output types {self.output_types!r} but got dataset "
              f"with output types {dataset_output_types!r}.")
      for iterator_shape, dataset_shape in zip(
          nest.flatten(self.output_shapes), nest.flatten(
              dataset_output_shapes)):
        if not iterator_shape.is_compatible_with(dataset_shape):
          raise TypeError(
              f"Expected output shapes compatible with {self.output_shapes!r} "
              f"but got dataset with output shapes {dataset_output_shapes!r}.")

    # `name` was rebound by the `ops.name_scope` block above and is reused as
    # the name of the op created here.
    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      # pylint: disable=protected-access
      return gen_dataset_ops.make_iterator(
          dataset._variant_tensor, self._iterator_resource, name=name)

  def get_next(self, name=None):
    """Returns the next element.

    In graph mode, you should typically call this method *once* and use its
    result as the input to another computation. A typical loop will then call
    `tf.Session.run` on the result of that computation. The loop will terminate
    when the `Iterator.get_next()` operation raises
    `tf.errors.OutOfRangeError`. The following skeleton shows how to use
    this method when building a training loop:

    ```python
    dataset = ...  # A `tf.data.Dataset` object.
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    # Build a TensorFlow graph that does something with each element.
    loss = model_function(next_element)
    optimizer = ...  # A `tf.compat.v1.train.Optimizer` object.
    train_op = optimizer.minimize(loss)

    with tf.compat.v1.Session() as sess:
      try:
        while True:
          sess.run(train_op)
      except tf.errors.OutOfRangeError:
        pass
    ```

    NOTE: It is legitimate to call `Iterator.get_next()` multiple times, e.g.
    when you are distributing different elements to multiple devices in a single
    step. However, a common pitfall arises when users call `Iterator.get_next()`
    in each iteration of their training loop. `Iterator.get_next()` adds ops to
    the graph, and executing each op allocates resources (including threads); as
    a consequence, invoking it in every iteration of a training loop causes
    slowdown and eventual resource exhaustion. To guard against this outcome, we
    log a warning when the number of uses crosses a fixed threshold of
    suspiciousness.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A (nested) structure of values matching `tf.data.Iterator.element_spec`.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      # pylint: disable=protected-access
      flat_ret = gen_dataset_ops.iterator_get_next(
          self._iterator_resource,
          output_types=self._flat_tensor_types,
          output_shapes=self._flat_tensor_shapes,
          name=name)
      return structure.from_tensor_list(self._element_spec, flat_ret)

  def get_next_as_optional(self):
    """Returns the next element wrapped in an optional.

    Unlike `get_next()`, this method does not count towards the
    `GET_NEXT_CALL_WARNING_THRESHOLD` warning.

    Returns:
      An `optional_ops._OptionalImpl` built from the output of the
      `iterator_get_next_as_optional` op and this iterator's `element_spec`.
    """
    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      # pylint: disable=protected-access
      # Reuse the flat types/shapes cached in `__init__`, as `get_next()` does,
      # instead of recomputing them from the element spec on every call.
      return optional_ops._OptionalImpl(
          gen_dataset_ops.iterator_get_next_as_optional(
              self._iterator_resource,
              output_types=self._flat_tensor_types,
              output_shapes=self._flat_tensor_shapes), self.element_spec)

  def string_handle(self, name=None):
    """Returns a string-valued `tf.Tensor` that represents this iterator.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A scalar `tf.Tensor` of type `tf.string`.
    """
    if name is None:
      # Reuse the handle tensor created in `__init__`.
      return self._string_handle
    else:
      return gen_dataset_ops.iterator_to_string_handle(
          self._iterator_resource, name=name)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_classes(iterator)`.")
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`.

    Returns:
      A (nested) structure of Python `type` objects corresponding to each
      component of an element of this iterator.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_shapes(iterator)`.")
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A (nested) structure of `tf.TensorShape` objects corresponding to each
      component of an element of this iterator.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_types(iterator)`.")
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A (nested) structure of `tf.DType` objects corresponding to each component
      of an element of this iterator.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def element_spec(self):
    """The type specification of an element of this iterator.

    For more information,
    read [this guide](https://www.tensorflow.org/guide/data#dataset_structure).

    Returns:
      A (nested) structure of `tf.TypeSpec` objects matching the structure of an
      element of this iterator and specifying the type of individual components.
    """

    return self._element_spec

  def _gather_saveables_for_checkpoint(self):
    """Returns the saveable factories used to checkpoint this iterator.

    Returns:
      A dict mapping the key `"ITERATOR"` to a factory that, given a name,
      builds an `_IteratorSaveable` for this iterator's resource.
    """

    def _saveable_factory(name):
      return _IteratorSaveable(self._iterator_resource, name)

    return {"ITERATOR": _saveable_factory}
542
543
# Next integer suffix handed out by `_generate_shared_name()`. Reads and
# updates happen while holding `_uid_lock`.
_uid_counter = 0
_uid_lock = threading.Lock()


def _generate_shared_name(prefix):
  """Returns `prefix` followed by the next value of the module-level counter.

  The counter is read and incremented under `_uid_lock`, so each call gets a
  different suffix.

  Args:
    prefix: Value placed in front of the integer suffix.

  Returns:
    A string of the form `"<prefix><uid>"`.
  """
  global _uid_counter
  with _uid_lock:
    uid = _uid_counter
    _uid_counter = uid + 1
  return f"{prefix}{uid}"
554
555
@tf_export("data.Iterator", v1=[])
class IteratorBase(
    collections_abc.Iterator,
    trackable.Trackable,
    composite_tensor.CompositeTensor,
    metaclass=abc.ABCMeta):
  """Represents an iterator of a `tf.data.Dataset`.

  `tf.data.Iterator` is the primary mechanism for enumerating elements of a
  `tf.data.Dataset`. It supports the Python Iterator protocol, which means
  it can be iterated over using a for-loop:

  >>> dataset = tf.data.Dataset.range(2)
  >>> for element in dataset:
  ...   print(element)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)

  or by fetching individual elements explicitly via `get_next()`:

  >>> dataset = tf.data.Dataset.range(2)
  >>> iterator = iter(dataset)
  >>> print(iterator.get_next())
  tf.Tensor(0, shape=(), dtype=int64)
  >>> print(iterator.get_next())
  tf.Tensor(1, shape=(), dtype=int64)

  In addition, non-raising iteration is supported via `get_next_as_optional()`,
  which returns the next element (if available) wrapped in a
  `tf.experimental.Optional`.

  >>> dataset = tf.data.Dataset.from_tensors(42)
  >>> iterator = iter(dataset)
  >>> optional = iterator.get_next_as_optional()
  >>> print(optional.has_value())
  tf.Tensor(True, shape=(), dtype=bool)
  >>> optional = iterator.get_next_as_optional()
  >>> print(optional.has_value())
  tf.Tensor(False, shape=(), dtype=bool)

  This class is abstract: `element_spec`, `get_next()` and
  `get_next_as_optional()` are declared abstract here and must be implemented
  by concrete subclasses.
  """

  @abc.abstractproperty
  def element_spec(self):
    """The type specification of an element of this iterator.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> iterator.element_spec
    tf.TensorSpec(shape=(), dtype=tf.int32, name=None)

    For more information,
    read [this guide](https://www.tensorflow.org/guide/data#dataset_structure).

    Returns:
      A (nested) structure of `tf.TypeSpec` objects matching the structure of an
      element of this iterator, specifying the type of individual components.
    """
    # Abstract member: this base implementation only raises.
    raise NotImplementedError("Iterator.element_spec")

  @abc.abstractmethod
  def get_next(self):
    """Returns the next element.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> print(iterator.get_next())
    tf.Tensor(42, shape=(), dtype=int32)

    Returns:
      A (nested) structure of values matching `tf.data.Iterator.element_spec`.

    Raises:
      `tf.errors.OutOfRangeError`: If the end of the iterator has been reached.
    """
    # Abstract member: this base implementation only raises.
    raise NotImplementedError("Iterator.get_next()")

  @abc.abstractmethod
  def get_next_as_optional(self):
    """Returns the next element wrapped in `tf.experimental.Optional`.

    If the iterator has reached the end of the sequence, the returned
    `tf.experimental.Optional` will have no value.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> optional = iterator.get_next_as_optional()
    >>> print(optional.has_value())
    tf.Tensor(True, shape=(), dtype=bool)
    >>> print(optional.get_value())
    tf.Tensor(42, shape=(), dtype=int32)
    >>> optional = iterator.get_next_as_optional()
    >>> print(optional.has_value())
    tf.Tensor(False, shape=(), dtype=bool)

    Returns:
      A `tf.experimental.Optional` object representing the next element.
    """
    # Abstract member: this base implementation only raises.
    raise NotImplementedError("Iterator.get_next_as_optional()")
654
655
656class OwnedIterator(IteratorBase):
657  """An iterator producing tf.Tensor objects from a tf.data.Dataset.
658
659  The iterator resource  created through `OwnedIterator` is owned by the Python
660  object and the life time of the underlying resource is tied to the life time
661  of the `OwnedIterator` object. This makes `OwnedIterator` appropriate for use
662  in eager mode and inside of tf.functions.
663  """
664
665  def __init__(self, dataset=None, components=None, element_spec=None):
666    """Creates a new iterator from the given dataset.
667
668    If `dataset` is not specified, the iterator will be created from the given
669    tensor components and element structure. In particular, the alternative for
670    constructing the iterator is used when the iterator is reconstructed from
671    it `CompositeTensor` representation.
672
673    Args:
674      dataset: A `tf.data.Dataset` object.
675      components: Tensor components to construct the iterator from.
676      element_spec: A (nested) structure of `TypeSpec` objects that
677        represents the type specification of elements of the iterator.
678
679    Raises:
680      ValueError: If `dataset` is not provided and either `components` or
681        `element_spec` is not provided. Or `dataset` is provided and either
682        `components` and `element_spec` is provided.
683    """
684    super(OwnedIterator, self).__init__()
685
686    if dataset is None:
687      if (components is None or element_spec is None):
688        raise ValueError(
689            "When `dataset` is not provided, both `components` and "
690            "`element_spec` must be specified.")
691      # pylint: disable=protected-access
692      self._element_spec = element_spec
693      self._flat_output_types = structure.get_flat_tensor_types(
694          self._element_spec)
695      self._flat_output_shapes = structure.get_flat_tensor_shapes(
696          self._element_spec)
697      self._iterator_resource, = components
698    else:
699      if (components is not None or element_spec is not None):
700        raise ValueError(
701            "When `dataset` is provided, `element_spec` and `components` must "
702            "not be specified.")
703      self._create_iterator(dataset)
704
705    self._get_next_call_count = 0
706
  def _create_iterator(self, dataset):
    """Creates the iterator resource for `dataset` and initializes it.

    Sets `self._dataset`, `self._element_spec`, `self._flat_output_types`,
    `self._flat_output_shapes` and `self._iterator_resource`.

    Args:
      dataset: The dataset to iterate over. Its `_apply_debug_options()`
        result is the dataset that is actually stored and iterated.
    """
    # pylint: disable=protected-access
    dataset = dataset._apply_debug_options()

    # Store dataset reference to ensure that dataset is alive when this iterator
    # is being used. For example, `tf.data.Dataset.from_generator` registers
    # a few py_funcs that are needed in `self._next_internal`.  If the dataset
    # is deleted, this iterator crashes on `self.__next__(...)` call.
    self._dataset = dataset

    ds_variant = dataset._variant_tensor
    self._element_spec = dataset.element_spec
    self._flat_output_types = structure.get_flat_tensor_types(
        self._element_spec)
    self._flat_output_shapes = structure.get_flat_tensor_shapes(
        self._element_spec)
    # The iterator resource is created colocated with the dataset variant.
    with ops.colocate_with(ds_variant):
      self._iterator_resource = (
          gen_dataset_ops.anonymous_iterator_v3(
              output_types=self._flat_output_types,
              output_shapes=self._flat_output_shapes))
      if not context.executing_eagerly():
        # Add full type information to the graph so host memory types inside
        # variants stay on CPU, e.g, ragged string tensors.
        # TODO(b/224776031) Remove this when AnonymousIteratorV3 can use
        # (reverse) type inference and all other ops that are needed to
        # provide type information to the AnonymousIteratorV3 also support
        # type inference (esp. cross-function type inference) instead of
        # setting the full type information manually.
        fulltype = type_utils.iterator_full_type_from_spec(
            self._element_spec)
        # fulltype is PRODUCT[ITERATOR[PRODUCT[...]]]
        # Sanity check: one innermost type argument per flat component.
        assert len(fulltype.args[0].args[0].args) == len(
            self._flat_output_types)
        self._iterator_resource.op.experimental_set_type(fulltype)
      # Bind the newly created iterator resource to the dataset.
      gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource)
743
  def __iter__(self):
    """Returns `self`; with `__next__` this forms the Python iterator protocol."""
    return self
746
  def next(self):  # For Python 2 compatibility
    """Alias for `__next__`; returns whatever `self.__next__()` returns."""
    return self.__next__()
749
750  def _next_internal(self):
751    autograph_status = autograph_ctx.control_status_ctx().status
752    autograph_disabled = autograph_status == autograph_ctx.Status.DISABLED
753    if not context.executing_eagerly() and autograph_disabled:
754      self._get_next_call_count += 1
755      if self._get_next_call_count > GET_NEXT_CALL_ERROR_THRESHOLD:
756        raise ValueError(GET_NEXT_CALL_ERROR_MESSAGE)
757
758    if not context.executing_eagerly():
759      # TODO(b/169442955): Investigate the need for this colocation constraint.
760      with ops.colocate_with(self._iterator_resource):
761        ret = gen_dataset_ops.iterator_get_next(
762            self._iterator_resource,
763            output_types=self._flat_output_types,
764            output_shapes=self._flat_output_shapes)
765      return structure.from_compatible_tensor_list(self._element_spec, ret)
766
767    # TODO(b/77291417): This runs in sync mode as iterators use an error status
768    # to communicate that there is no more data to iterate over.
769    with context.execution_mode(context.SYNC):
770      ret = gen_dataset_ops.iterator_get_next(
771          self._iterator_resource,
772          output_types=self._flat_output_types,
773          output_shapes=self._flat_output_shapes)
774
775      try:
776        # Fast path for the case `self._structure` is not a nested structure.
777        return self._element_spec._from_compatible_tensor_list(ret)  # pylint: disable=protected-access
778      except AttributeError:
779        return structure.from_compatible_tensor_list(self._element_spec, ret)
780
  @property
  def _type_spec(self):
    """Returns a new `IteratorSpec` built from this iterator's `element_spec`."""
    return IteratorSpec(self.element_spec)
784
785  def __next__(self):
786    try:
787      return self._next_internal()
788    except errors.OutOfRangeError:
789      raise StopIteration
790
791  @property
792  @deprecation.deprecated(
793      None, "Use `tf.compat.v1.data.get_output_classes(iterator)`.")
794  def output_classes(self):
795    """Returns the class of each component of an element of this iterator.
796
797    The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`.
798
799    Returns:
800      A (nested) structure of Python `type` objects corresponding to each
801      component of an element of this dataset.
802    """
803    return nest.map_structure(
804        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
805        self._element_spec)
806
807  @property
808  @deprecation.deprecated(
809      None, "Use `tf.compat.v1.data.get_output_shapes(iterator)`.")
810  def output_shapes(self):
811    """Returns the shape of each component of an element of this iterator.
812
813    Returns:
814      A (nested) structure of `tf.TensorShape` objects corresponding to each
815      component of an element of this dataset.
816    """
817    return nest.map_structure(
818        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
819        self._element_spec)
820
821  @property
822  @deprecation.deprecated(
823      None, "Use `tf.compat.v1.data.get_output_types(iterator)`.")
824  def output_types(self):
825    """Returns the type of each component of an element of this iterator.
826
827    Returns:
828      A (nested) structure of `tf.DType` objects corresponding to each component
829      of an element of this dataset.
830    """
831    return nest.map_structure(
832        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
833        self._element_spec)
834
  @property
  def element_spec(self):
    """Returns the element spec stored on this iterator (`_element_spec`)."""
    return self._element_spec
838
  def get_next(self):
    """Returns the next element by delegating to `_next_internal`.

    Unlike `__next__`, this does not catch `errors.OutOfRangeError`, so any
    error raised by `_next_internal` propagates to the caller unchanged.
    """
    return self._next_internal()
841
842  def get_next_as_optional(self):
843    # TODO(b/169442955): Investigate the need for this colocation constraint.
844    with ops.colocate_with(self._iterator_resource):
845      # pylint: disable=protected-access
846      return optional_ops._OptionalImpl(
847          gen_dataset_ops.iterator_get_next_as_optional(
848              self._iterator_resource,
849              output_types=structure.get_flat_tensor_types(self.element_spec),
850              output_shapes=structure.get_flat_tensor_shapes(
851                  self.element_spec)), self.element_spec)
852
853  def _gather_saveables_for_checkpoint(self):
854
855    def _saveable_factory(name):
856      """Returns a SaveableObject for serialization/deserialization."""
857      policy = None
858      if self._dataset:
859        policy = self._dataset.options().experimental_external_state_policy
860      if policy:
861        return _IteratorSaveable(
862            self._iterator_resource,
863            name,
864            external_state_policy=policy)
865      else:
866        return _IteratorSaveable(self._iterator_resource, name)
867
868    return {"ITERATOR": _saveable_factory}
869
  def __tf_tracing_type__(self, signature_context):
    """Returns a reference trace type for this iterator.

    The reference is built from this iterator's `_type_spec` and the `_id` of
    its iterator resource.

    Args:
      signature_context: Object providing `make_reference_type`.
    """
    return signature_context.make_reference_type(self._type_spec,
                                                 self._iterator_resource._id)  # pylint:disable=protected-access
873
874
@tf_export("data.IteratorSpec", v1=[])
class IteratorSpec(type_spec.TypeSpec):
  """Type specification for `tf.data.Iterator`.

  For instance, `tf.data.IteratorSpec` can be used to define a tf.function that
  takes `tf.data.Iterator` as an input argument:

  >>> @tf.function(input_signature=[tf.data.IteratorSpec(
  ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))])
  ... def square(iterator):
  ...   x = iterator.get_next()
  ...   return x * x
  >>> dataset = tf.data.Dataset.from_tensors(5)
  >>> iterator = iter(dataset)
  >>> print(square(iterator))
  tf.Tensor(25, shape=(), dtype=int32)

  Attributes:
    element_spec: A (nested) structure of `tf.TypeSpec` objects that represents
      the type specification of the iterator elements.
  """

  __slots__ = ["_element_spec"]

  def __init__(self, element_spec):
    """Stores `element_spec` as the spec of the iterator's elements."""
    self._element_spec = element_spec

  @property
  def value_type(self):
    """The Python class of values described by this spec (`OwnedIterator`)."""
    return OwnedIterator

  def _serialize(self):
    """Returns a 1-tuple holding the element spec."""
    element_spec = self._element_spec
    return (element_spec,)

  @property
  def _component_specs(self):
    """Returns a 1-tuple with a scalar `resource` `TensorSpec`."""
    resource_spec = tensor_spec.TensorSpec([], dtypes.resource)
    return (resource_spec,)

  def _to_components(self, value):
    """Returns a 1-tuple holding `value`'s iterator resource."""
    resource = value._iterator_resource  # pylint: disable=protected-access
    return (resource,)

  def _from_components(self, components):
    """Builds an `OwnedIterator` from `components` and this spec's elements."""
    return OwnedIterator(
        element_spec=self._element_spec,
        components=components,
        dataset=None)

  @staticmethod
  def from_value(value):
    """Returns an `IteratorSpec` wrapping `value.element_spec`."""
    element_spec = value.element_spec
    return IteratorSpec(element_spec)

  def __tf_tracing_type__(self, signature_context):
    """Returns a reference trace type keyed on this spec instance's `id`."""
    # TODO(b/202772221): Validate and enforce this assumption of uniqueness per
    # spec instance.
    instance_id = id(self)
    return signature_context.make_reference_type(self, instance_id)
930
931
932# TODO(b/71645805): Expose trackable stateful objects from dataset.
class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
  """SaveableObject for saving/restoring iterator state."""

  def __init__(
      self,
      iterator_resource,
      name,
      external_state_policy=options_lib.ExternalStatePolicy.FAIL):
    """Creates the serialization op and the save spec for the iterator.

    Args:
      iterator_resource: The iterator resource to serialize; its `device` is
        used for the save spec.
      name: Name of this saveable; the spec is stored under `name + "_STATE"`.
      external_state_policy: Policy object whose `.value` is forwarded to the
        `serialize_iterator` op.
    """
    policy_value = external_state_policy.value
    serialized_state = gen_dataset_ops.serialize_iterator(
        iterator_resource, external_state_policy=policy_value)
    state_spec = BaseSaverBuilder.SaveSpec(
        serialized_state,
        "",
        name + "_STATE",
        device=iterator_resource.device)
    super(_IteratorSaveable, self).__init__(
        iterator_resource, [state_spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Returns an op that deserializes `restored_tensors[0]` into `self.op`.

    `restored_shapes` is not used.
    """
    target = self.op
    with ops.colocate_with(target):
      serialized_state = restored_tensors[0]
      return gen_dataset_ops.deserialize_iterator(target, serialized_state)
955
956
@deprecation.deprecated(
    None, "Use `tf.data.Iterator.get_next_as_optional()` instead.")
@tf_export("data.experimental.get_next_as_optional")
def get_next_as_optional(iterator):
  """Returns a `tf.experimental.Optional` with the next element of the iterator.

  If the iterator has reached the end of the sequence, the returned
  `tf.experimental.Optional` will have no value.

  Args:
    iterator: A `tf.data.Iterator`.

  Returns:
    A `tf.experimental.Optional` object which either contains the next element
    of the iterator (if it exists) or no value.
  """
  # Thin compatibility shim: all of the work is done by the iterator's own
  # `get_next_as_optional()` method, which the deprecation notice points to.
  return iterator.get_next_as_optional()
974
975
# Register `OwnedIterator` with `_pywrap_utils` under the name "OwnedIterator".
# NOTE(review): presumably this lets the native utilities recognize iterator
# instances by type -- confirm against the `_pywrap_utils` implementation.
_pywrap_utils.RegisterType("OwnedIterator", OwnedIterator)
977