xref: /aosp_15_r20/external/tensorflow/tensorflow/python/data/ops/dataset_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrappers for Datasets."""
16import abc
17import functools
18import multiprocessing
19import queue
20import threading
21import warnings
22
23import numpy as np
24
25from tensorflow.core.framework import dataset_metadata_pb2
26from tensorflow.core.framework import dataset_options_pb2
27from tensorflow.core.framework import graph_pb2
28from tensorflow.python import tf2
29from tensorflow.python.data.ops import iterator_ops
30from tensorflow.python.data.ops import options as options_lib
31from tensorflow.python.data.ops import structured_function
32from tensorflow.python.data.util import nest
33from tensorflow.python.data.util import random_seed
34from tensorflow.python.data.util import structure
35from tensorflow.python.data.util import traverse
36from tensorflow.python.eager import context
37from tensorflow.python.framework import auto_control_deps
38from tensorflow.python.framework import auto_control_deps_utils as acd_utils
39from tensorflow.python.framework import composite_tensor
40from tensorflow.python.framework import constant_op
41from tensorflow.python.framework import dtypes
42from tensorflow.python.framework import function
43from tensorflow.python.framework import ops
44from tensorflow.python.framework import random_seed as core_random_seed
45from tensorflow.python.framework import smart_cond
46from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
47from tensorflow.python.framework import tensor_shape
48from tensorflow.python.framework import tensor_spec
49from tensorflow.python.framework import tensor_util
50from tensorflow.python.framework import type_spec
51from tensorflow.python.ops import array_ops
52from tensorflow.python.ops import check_ops
53from tensorflow.python.ops import control_flow_ops
54from tensorflow.python.ops import gen_dataset_ops
55from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
56from tensorflow.python.ops import gen_io_ops
57from tensorflow.python.ops import gen_stateless_random_ops
58from tensorflow.python.ops import logging_ops
59from tensorflow.python.ops import math_ops
60from tensorflow.python.ops import random_ops
61from tensorflow.python.ops import script_ops
62from tensorflow.python.ops import string_ops
63from tensorflow.python.ops.ragged import ragged_tensor
64from tensorflow.python.trackable import asset
65from tensorflow.python.trackable import base as tracking_base
66from tensorflow.python.trackable import resource as resource_lib
67from tensorflow.python.types import trace
68from tensorflow.python.util import deprecation
69from tensorflow.python.util import lazy_loader
70from tensorflow.python.util import nest as tf_nest
71from tensorflow.python.util.compat import collections_abc
72from tensorflow.python.util.tf_export import tf_export
73
74# Symbols forwarded for legacy access through dataset_ops.py. These forwarded
75# symbols can be removed once all internal uses are updated.
76StructuredFunctionWrapper = structured_function.StructuredFunctionWrapper
77
78# Loaded lazily due to a circular dependency (roughly
79# tf.function->wrap_function->dataset->autograph->tf.function).
80# TODO(b/133251390): Use a regular import.
81wrap_function = lazy_loader.LazyLoader(
82    "wrap_function", globals(),
83    "tensorflow.python.eager.wrap_function")
84# Loaded lazily due to a circular dependency
85# dataset_ops->def_function->func_graph->autograph->dataset_ops
86# TODO(kathywu): Use a regular import.
87def_function = lazy_loader.LazyLoader(
88    "def_function", globals(),
89    "tensorflow.python.eager.def_function")
90# Loaded lazily due to a circular dependency
91# dataset_ops->parsing_ops->dataset_ops
92# TODO(varshaan): Use a regular import.
93parsing_ops = lazy_loader.LazyLoader(
94    "parsing_ops", globals(),
95    "tensorflow.python.ops.parsing_ops")
96
97
98ops.NotDifferentiable("ReduceDataset")
99
100# A constant that can be used to enable auto-tuning.
101AUTOTUNE = -1
102tf_export("data.AUTOTUNE").export_constant(__name__, "AUTOTUNE")
103# TODO(b/168128531): Deprecate and remove this symbol.
104tf_export("data.experimental.AUTOTUNE").export_constant(__name__, "AUTOTUNE")
105
106# Constants representing infinite and unknown cardinalities.
107INFINITE = -1
108UNKNOWN = -2
109COMPRESSION_GZIP = "GZIP"
110COMPRESSION_SNAPPY = "NONE"
111DATASET_SPEC_FILENAME = "dataset_spec.pb"
112tf_export("data.INFINITE_CARDINALITY").export_constant(__name__, "INFINITE")
113tf_export("data.UNKNOWN_CARDINALITY").export_constant(__name__, "UNKNOWN")
114
115
116def _validate_and_encode(name):
117  if not name.isidentifier():
118    raise ValueError("Invalid `name`. The argument `name` needs to be a valid "
119                     "identifier. Value is considered a valid identifier if it "
120                     "only contains alphanumeric characters (a-z), (A-Z), and "
121                     "(0-9), or underscores (_). A valid identifier cannot "
122                     "start with a number, or contain any spaces.")
123  return name.encode("utf-8")
124
125
126def _get_type(value):
127  """Returns the type of `value` if it is a TypeSpec."""
128
129  if isinstance(value, type_spec.TypeSpec):
130    return value.value_type()
131  else:
132    return type(value)
133
134
135@tf_export("data.Dataset", v1=[])
136class DatasetV2(
137    collections_abc.Iterable,
138    tracking_base.Trackable,
139    composite_tensor.CompositeTensor,
140    metaclass=abc.ABCMeta):
141  """Represents a potentially large set of elements.
142
143  The `tf.data.Dataset` API supports writing descriptive and efficient input
144  pipelines. `Dataset` usage follows a common pattern:
145
146  1. Create a source dataset from your input data.
147  2. Apply dataset transformations to preprocess the data.
148  3. Iterate over the dataset and process the elements.
149
150  Iteration happens in a streaming fashion, so the full dataset does not need to
151  fit into memory.
152
153  Source Datasets:
154
155  The simplest way to create a dataset is to create it from a python `list`:
156
157  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
158  >>> for element in dataset:
159  ...   print(element)
160  tf.Tensor(1, shape=(), dtype=int32)
161  tf.Tensor(2, shape=(), dtype=int32)
162  tf.Tensor(3, shape=(), dtype=int32)
163
164  To process lines from files, use `tf.data.TextLineDataset`:
165
166  >>> dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])
167
168  To process records written in the `TFRecord` format, use `TFRecordDataset`:
169
170  >>> dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"])
171
172  To create a dataset of all files matching a pattern, use
173  `tf.data.Dataset.list_files`:
174
175  ```python
176  dataset = tf.data.Dataset.list_files("/path/*.txt")
177  ```
178
179  See `tf.data.FixedLengthRecordDataset` and `tf.data.Dataset.from_generator`
180  for more ways to create datasets.
181
182  Transformations:
183
184  Once you have a dataset, you can apply transformations to prepare the data for
185  your model:
186
187  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
188  >>> dataset = dataset.map(lambda x: x*2)
189  >>> list(dataset.as_numpy_iterator())
190  [2, 4, 6]
191
192  Common Terms:
193
194  **Element**: A single output from calling `next()` on a dataset iterator.
195    Elements may be nested structures containing multiple components. For
196    example, the element `(1, (3, "apple"))` has one tuple nested in another
197    tuple. The components are `1`, `3`, and `"apple"`.
198
199  **Component**: The leaf in the nested structure of an element.
200
201  Supported types:
202
203  Elements can be nested structures of tuples, named tuples, and dictionaries.
204  Note that Python lists are *not* treated as nested structures of components.
205  Instead, lists are converted to tensors and treated as components. For
206  example, the element `(1, [1, 2, 3])` has only two components; the tensor `1`
207  and the tensor `[1, 2, 3]`. Element components can be of any type
208  representable by `tf.TypeSpec`, including `tf.Tensor`, `tf.data.Dataset`,
209  `tf.sparse.SparseTensor`, `tf.RaggedTensor`, and `tf.TensorArray`.
210
211  ```python
212  a = 1 # Integer element
213  b = 2.0 # Float element
214  c = (1, 2) # Tuple element with 2 components
215  d = {"a": (2, 2), "b": 3} # Dict element with 3 components
216  Point = collections.namedtuple("Point", ["x", "y"])
217  e = Point(1, 2) # Named tuple
218  f = tf.data.Dataset.range(10) # Dataset element
219  ```
220
221  For more information,
222  read [this guide](https://www.tensorflow.org/guide/data).
223  """
224
225  def __init__(self, variant_tensor):
226    """Creates a DatasetV2 object.
227
228    This is a difference between DatasetV1 and DatasetV2. DatasetV1 does not
229    take anything in its constructor whereas in the DatasetV2, we expect
230    subclasses to create a variant_tensor and pass it in to the super() call.
231
232    Args:
233      variant_tensor: A DT_VARIANT tensor that represents the dataset.
234    """
235    self._variant_tensor_attr = variant_tensor
236    self._graph_attr = ops.get_default_graph()
237
238    # Initialize the options for this dataset and its inputs.
239    self._options_attr = options_lib.Options()
240    for input_dataset in self._inputs():
241      input_options = None
242      if isinstance(input_dataset, DatasetV1):
243        # If the V1 dataset does not have the `_dataset` attribute, we assume it
244        # is a dataset source and hence does not have options. Otherwise, we
245        # grab the options of `_dataset` object
246        if hasattr(input_dataset, "_dataset"):
247          if not isinstance(input_dataset._dataset, DatasetV2):
248            raise TypeError(
249                f"Each input of dataset {type(self)} should be a subclass of "
250                f"`tf.data.Dataset` but encountered "
251                f"{type(input_dataset._dataset)}.")
252          input_options = input_dataset._dataset._options_attr
253      elif isinstance(input_dataset, DatasetV2):
254        input_options = input_dataset._options_attr
255      else:
256        raise TypeError(
257            f"Each input of dataset {type(self)} should be a subclass of "
258            f"`tf.data.Dataset` but encountered {type(input_dataset)}.")
259      if input_options is not None:
260        self._options_attr = self._options_attr.merge(input_options)
261    self._options_attr._set_mutable(False)  # pylint: disable=protected-access
262
263  @property
264  def _variant_tensor(self):
265    return self._variant_tensor_attr
266
267  @_variant_tensor.setter
268  def _variant_tensor(self, _):
269    raise ValueError("The `_variant_tensor` property cannot be modified.")
270
271  @deprecation.deprecated_args(None, "Use external_state_policy instead",
272                               "allow_stateful")
273  def _as_serialized_graph(
274      self,
275      allow_stateful=None,
276      strip_device_assignment=None,
277      external_state_policy=options_lib.ExternalStatePolicy.WARN):
278    """Produces serialized graph representation of the dataset.
279
280    Args:
281      allow_stateful: If true, we allow stateful ops to be present in the graph
282        def. In that case, the state in these ops would be thrown away.
283      strip_device_assignment: If true, non-local (i.e. job and task) device
284        assignment is stripped from ops in the serialized graph.
285      external_state_policy: The ExternalStatePolicy enum that determines how we
286        handle input pipelines that depend on external state. By default, its
287        set to WARN.
288
289    Returns:
290      A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
291      serialized graph.
292    """
293    if external_state_policy:
294      policy = external_state_policy.value
295      return gen_dataset_ops.dataset_to_graph_v2(
296          self._variant_tensor,
297          external_state_policy=policy,
298          strip_device_assignment=strip_device_assignment)
299    if strip_device_assignment:
300      return gen_dataset_ops.dataset_to_graph(
301          self._variant_tensor,
302          allow_stateful=allow_stateful,
303          strip_device_assignment=strip_device_assignment)
304    return gen_dataset_ops.dataset_to_graph(
305        self._variant_tensor, allow_stateful=allow_stateful)
306
307  def _maybe_track_assets(self, graph_def):
308    """Finds and tracks nodes in `graph_def` that refer to asset files.
309
310    Args:
311      graph_def: Serialized graph representation of this dataset.
312
313    Returns:
314      A dictionary mapping the node name of an asset constant to a tracked
315      `asset.Asset` object.
316    """
317    asset_tracker = {}
318    for node in graph_def.node:
319      if node.name.startswith("FileIdentity"):
320        asset_tracker[node.input[0]] = None
321
322    if not asset_tracker:
323      return {}
324
325    for node in graph_def.node:
326      if node.name in asset_tracker:
327        tensor_proto = node.attr["value"].tensor
328        with context.eager_mode(), ops.device("CPU"):
329          node_value = parsing_ops.parse_tensor(
330              tensor_proto.SerializeToString(), dtypes.string).numpy()
331        asset_tracker[node.name] = ([
332            self._track_trackable(asset.Asset(n),
333                                  name=node.name + "_" + str(i), overwrite=True)
334            for i, n in enumerate(node_value)
335        ])
336    return asset_tracker
337
338  def _trackable_children(self,
339                          save_type=tracking_base.SaveType.CHECKPOINT,
340                          **kwargs):
341    if save_type != tracking_base.SaveType.SAVEDMODEL:
342      return {}
343
344    # _trace_variant_creation only works when executing eagerly, so we don't
345    # want to run it in the object initialization.
346    @def_function.function(input_signature=[], autograph=False)
347    def _creator():
348      resource = self._trace_variant_creation()()  # pylint: disable=protected-access
349      return resource
350    _creator.get_concrete_function()  # Trigger asset tracking
351
352    children = super(DatasetV2, self)._trackable_children(save_type, **kwargs)
353    children["_variant_tracker"] = _VariantTracker(self._variant_tensor,
354                                                   _creator)
355    return children
356
357  def _trace_variant_creation(self):
358    """Traces a function which outputs a variant `tf.Tensor` for this dataset.
359
360    Note that creating this function involves evaluating an op, and is currently
361    only supported when executing eagerly.
362
363    Returns:
364      A zero-argument `ConcreteFunction` which outputs a variant `tf.Tensor`.
365    """
366    variant = self._variant_tensor
367    if not isinstance(variant, ops.EagerTensor):
368      raise NotImplementedError(
369          "Constructing a tf.function that reproduces a given dataset is only "
370          "supported for datasets created eagerly. Please file a feature "
371          "request if this is important to you.")
372    with context.eager_mode(), ops.device("CPU"):
373      # pylint: disable=protected-access
374      graph_def = graph_pb2.GraphDef().FromString(
375          self._as_serialized_graph(external_state_policy=options_lib
376                                    .ExternalStatePolicy.FAIL).numpy())
377    output_node_names = []
378    for node in graph_def.node:
379      if node.op == "_Retval":
380        output_node_names = node.input
381
382    if len(output_node_names) != 1:
383      raise AssertionError(
384          f"Dataset graph is expected to only have one return value but found "
385          f"{len(output_node_names)} return values: {output_node_names}.")
386
387    output_node_name = output_node_names[0]
388
389    file_path_nodes = {}
390    # When building a tf.function, track files as `saved_model.Asset`s.
391    if ops.get_default_graph().building_function:
392      asset_tracker = self._maybe_track_assets(graph_def)
393      for key in asset_tracker:
394        assets_list = [
395            array_ops.expand_dims(asset.asset_path, axis=0)
396            for asset in asset_tracker[key]
397        ]
398        file_path_nodes[key] = array_ops.concat(assets_list, axis=0)
399
400    # Add functions used in this Dataset to the function's graph, since they
401    # need to follow it around (and for example be added to a SavedModel which
402    # references the dataset).
403    variant_function = wrap_function.function_from_graph_def(
404        graph_def,
405        inputs=[],
406        outputs=output_node_name + ":0",
407        captures=file_path_nodes)
408    for used_function in self._functions():
409      used_function.function.add_to_graph(variant_function.graph)
410    return variant_function
411
412  @abc.abstractmethod
413  def _inputs(self):
414    """Returns a list of the input datasets of the dataset."""
415
416    raise NotImplementedError(f"{type(self)}._inputs()")
417
418  @property
419  def _graph(self):
420    return self._graph_attr
421
422  @_graph.setter
423  def _graph(self, _):
424    raise ValueError("The `_graph` property cannot be modified.")
425
426  # TODO(jsimsa): Change this to be the transitive closure of functions used
427  # by this dataset and its inputs.
428  def _functions(self):
429    """Returns a list of functions associated with this dataset.
430
431    Returns:
432      A list of `StructuredFunctionWrapper` objects.
433    """
434    return []
435
436  def _options(self):
437    """Returns the options tensor for this dataset."""
438    # pylint: disable=protected-access
439    return gen_dataset_ops.get_options(self._variant_tensor)
440
441  @classmethod
442  def _options_tensor_to_options(cls, serialized_options):
443    """Converts options tensor to tf.data.Options object."""
444    options = options_lib.Options()
445    if tensor_util.constant_value(serialized_options) is not None:
446      pb = dataset_options_pb2.Options.FromString(tensor_util.constant_value(
447          serialized_options))
448      options._from_proto(pb)  # pylint: disable=protected-access
449    return options
450
451  def options(self):
452    """Returns the options for this dataset and its inputs.
453
454    Returns:
455      A `tf.data.Options` object representing the dataset options.
456    """
457    if context.executing_eagerly():
458      options = self._options_tensor_to_options(self._options())
459      options._set_mutable(False)  # pylint: disable=protected-access
460      return options
461    warnings.warn("To make it possible to preserve tf.data options across "
462                  "serialization boundaries, their implementation has moved to "
463                  "be part of the TensorFlow graph. As a consequence, the "
464                  "options value is in general no longer known at graph "
465                  "construction time. Invoking this method in graph mode "
466                  "retains the legacy behavior of the original implementation, "
467                  "but note that the returned value might not reflect the "
468                  "actual value of the options.")
469    return self._options_attr
470
471  def _apply_debug_options(self):
472    if DEBUG_MODE:
473      # Disable autotuning and static optimizations that could introduce
474      # parallelism or asynchrony.
475      options = options_lib.Options()
476      options.autotune.enabled = False
477      options.experimental_optimization.filter_parallelization = False
478      options.experimental_optimization.map_and_batch_fusion = False
479      options.experimental_optimization.map_parallelization = False
480      dataset = _OptionsDataset(self, options)
481    else:
482      dataset = self
483
484    return dataset
485
486  def __iter__(self):
487    """Creates an iterator for elements of this dataset.
488
489    The returned iterator implements the Python Iterator protocol.
490
491    Returns:
492      An `tf.data.Iterator` for the elements of this dataset.
493
494    Raises:
495      RuntimeError: If not inside of tf.function and not executing eagerly.
496    """
497    if context.executing_eagerly() or ops.inside_function():
498      with ops.colocate_with(self._variant_tensor):
499        return iterator_ops.OwnedIterator(self)
500    else:
501      raise RuntimeError("`tf.data.Dataset` only supports Python-style "
502                         "iteration in eager mode or within tf.function.")
503
504  def __bool__(self):
505    return True  # Required as __len__ is defined
506
507  __nonzero__ = __bool__  # Python 2 backward compatibility
508
509  def __len__(self):
510    """Returns the length of the dataset if it is known and finite.
511
512    This method requires that you are running in eager mode, and that the
513    length of the dataset is known and non-infinite. When the length may be
514    unknown or infinite, or if you are running in graph mode, use
515    `tf.data.Dataset.cardinality` instead.
516
517    Returns:
518      An integer representing the length of the dataset.
519
520    Raises:
521      RuntimeError: If the dataset length is unknown or infinite, or if eager
522        execution is not enabled.
523    """
524    if not context.executing_eagerly():
525      raise TypeError("`tf.data.Dataset` only supports `len` in eager mode. "
526                      "Use `tf.data.Dataset.cardinality()` instead.")
527    length = self.cardinality()
528    if length.numpy() == INFINITE:
529      raise TypeError("The dataset is infinite.")
530    if length.numpy() == UNKNOWN:
531      raise TypeError("The dataset length is unknown.")
532    return length
533
534  @abc.abstractproperty
535  def element_spec(self):
536    """The type specification of an element of this dataset.
537
538    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
539    >>> dataset.element_spec
540    TensorSpec(shape=(), dtype=tf.int32, name=None)
541
542    For more information,
543    read [this guide](https://www.tensorflow.org/guide/data#dataset_structure).
544
545    Returns:
546      A (nested) structure of `tf.TypeSpec` objects matching the structure of an
547      element of this dataset and specifying the type of individual components.
548    """
549    raise NotImplementedError(f"{type(self)}.element_spec()")
550
551  def __repr__(self):
552    type_ = type(self._dataset if isinstance(self, DatasetV1Adapter) else self)
553    return f"<{type_.__name__} element_spec={self.element_spec}>"
554
555  def __debug_string__(self):
556    """Returns a string showing the type of the dataset and its inputs.
557
558    This string is intended only for debugging purposes, and may change without
559    warning.
560    """
561    lines = []
562    to_process = [(self, 0)]  # Stack of (dataset, depth) pairs.
563    while to_process:
564      dataset, depth = to_process.pop()
565      lines.append("-"*2*depth + repr(dataset))
566      to_process.extend([(ds, depth+1) for ds in dataset._inputs()])  # pylint: disable=protected-access
567    return "\n".join(lines)
568
569  def as_numpy_iterator(self):
570    """Returns an iterator which converts all elements of the dataset to numpy.
571
572    Use `as_numpy_iterator` to inspect the content of your dataset. To see
573    element shapes and types, print dataset elements directly instead of using
574    `as_numpy_iterator`.
575
576    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
577    >>> for element in dataset:
578    ...   print(element)
579    tf.Tensor(1, shape=(), dtype=int32)
580    tf.Tensor(2, shape=(), dtype=int32)
581    tf.Tensor(3, shape=(), dtype=int32)
582
583    This method requires that you are running in eager mode and the dataset's
584    element_spec contains only `TensorSpec` components.
585
586    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
587    >>> for element in dataset.as_numpy_iterator():
588    ...   print(element)
589    1
590    2
591    3
592
593    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
594    >>> print(list(dataset.as_numpy_iterator()))
595    [1, 2, 3]
596
597    `as_numpy_iterator()` will preserve the nested structure of dataset
598    elements.
599
600    >>> dataset = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]),
601    ...                                               'b': [5, 6]})
602    >>> list(dataset.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5},
603    ...                                       {'a': (2, 4), 'b': 6}]
604    True
605
606    Returns:
607      An iterable over the elements of the dataset, with their tensors converted
608      to numpy arrays.
609
610    Raises:
611      TypeError: if an element contains a non-`Tensor` value.
612      RuntimeError: if eager execution is not enabled.
613    """
614    if not context.executing_eagerly():
615      raise RuntimeError("`tf.data.Dataset.as_numpy_iterator()` is only "
616                         "supported in eager mode.")
617    for component_spec in nest.flatten(self.element_spec):
618      if not isinstance(component_spec,
619                        (tensor_spec.TensorSpec, ragged_tensor.RaggedTensorSpec,
620                         structure.NoneTensorSpec)):
621        raise TypeError(
622            f"`tf.data.Dataset.as_numpy_iterator()` is not supported for "
623            f"datasets that produce values of type {component_spec.value_type}")
624
625    return _NumpyIterator(self)
626
627  @property
628  def _flat_shapes(self):
629    """Returns a list `tf.TensorShapes`s for the element tensor representation.
630
631    Returns:
632      A list `tf.TensorShapes`s for the element tensor representation.
633    """
634    return structure.get_flat_tensor_shapes(self.element_spec)
635
636  @property
637  def _flat_types(self):
638    """Returns a list `tf.DType`s for the element tensor representation.
639
640    Returns:
641      A list `tf.DType`s for the element tensor representation.
642    """
643    return structure.get_flat_tensor_types(self.element_spec)
644
645  @property
646  def _flat_structure(self):
647    """Helper for setting `output_shapes` and `output_types` attrs of an op.
648
649    Most dataset op constructors expect `output_shapes` and `output_types`
650    arguments that represent the flattened structure of an element. This helper
651    function generates these attrs as a keyword argument dictionary, allowing
652    `Dataset._variant_tensor` implementations to pass `**self._flat_structure`
653    to the op constructor.
654
655    Returns:
656      A dictionary of keyword arguments that can be passed to a dataset op
657      constructor.
658    """
659    return {
660        "output_shapes": self._flat_shapes,
661        "output_types": self._flat_types,
662    }
663
664  @property
665  def _metadata(self):
666    """Helper for generating dataset metadata."""
667    metadata = dataset_metadata_pb2.Metadata()
668    if self._name:
669      metadata.name = _validate_and_encode(self._name)
670    return metadata
671
672  @property
673  def _common_args(self):
674    """Helper for generating arguments that are common across most dataset ops.
675
676    Most dataset op constructors expect `output_shapes` and `output_types`
677    arguments that represent the flattened structure of an element, as well as a
678    `metadata` argument for additional metadata such as user-defined dataset
679    name. This helper function generates common attributes as a keyword argument
680    dictionary, allowing `Dataset._variant_tensor` implementations to pass
681    `**self._common_args` to the op constructor.
682
683    Returns:
684      A dictionary of keyword arguments that can be passed to a dataset op
685      constructor.
686    """
687    return {
688        "metadata": self._metadata.SerializeToString(),
689        "output_shapes": self._flat_shapes,
690        "output_types": self._flat_types,
691    }
692
693  @property
694  def _type_spec(self):
695    return DatasetSpec(self.element_spec)
696
697  @staticmethod
698  def from_tensors(tensors, name=None):
699    """Creates a `Dataset` with a single element, comprising the given tensors.
700
701    `from_tensors` produces a dataset containing only a single element. To slice
702    the input tensor into multiple elements, use `from_tensor_slices` instead.
703
704    >>> dataset = tf.data.Dataset.from_tensors([1, 2, 3])
705    >>> list(dataset.as_numpy_iterator())
706    [array([1, 2, 3], dtype=int32)]
707    >>> dataset = tf.data.Dataset.from_tensors(([1, 2, 3], 'A'))
708    >>> list(dataset.as_numpy_iterator())
709    [(array([1, 2, 3], dtype=int32), b'A')]
710
711    >>> # You can use `from_tensors` to produce a dataset which repeats
712    >>> # the same example many times.
713    >>> example = tf.constant([1,2,3])
714    >>> dataset = tf.data.Dataset.from_tensors(example).repeat(2)
715    >>> list(dataset.as_numpy_iterator())
716    [array([1, 2, 3], dtype=int32), array([1, 2, 3], dtype=int32)]
717
718    Note that if `tensors` contains a NumPy array, and eager execution is not
719    enabled, the values will be embedded in the graph as one or more
720    `tf.constant` operations. For large datasets (> 1 GB), this can waste
721    memory and run into byte limits of graph serialization. If `tensors`
722    contains one or more large NumPy arrays, consider the alternative described
723    in [this
724    guide](https://tensorflow.org/guide/data#consuming_numpy_arrays).
725
726    Args:
727      tensors: A dataset "element". Supported values are documented
728        [here](https://www.tensorflow.org/guide/data#dataset_structure).
729      name: (Optional.) A name for the tf.data operation.
730
731    Returns:
732      Dataset: A `Dataset`.
733    """
734    return TensorDataset(tensors, name=name)
735
736  @staticmethod
737  def from_tensor_slices(tensors, name=None):
738    """Creates a `Dataset` whose elements are slices of the given tensors.
739
740    The given tensors are sliced along their first dimension. This operation
741    preserves the structure of the input tensors, removing the first dimension
742    of each tensor and using it as the dataset dimension. All input tensors
743    must have the same size in their first dimensions.
744
745    >>> # Slicing a 1D tensor produces scalar tensor elements.
746    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
747    >>> list(dataset.as_numpy_iterator())
748    [1, 2, 3]
749
750    >>> # Slicing a 2D tensor produces 1D tensor elements.
751    >>> dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
752    >>> list(dataset.as_numpy_iterator())
753    [array([1, 2], dtype=int32), array([3, 4], dtype=int32)]
754
755    >>> # Slicing a tuple of 1D tensors produces tuple elements containing
756    >>> # scalar tensors.
757    >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
758    >>> list(dataset.as_numpy_iterator())
759    [(1, 3, 5), (2, 4, 6)]
760
761    >>> # Dictionary structure is also preserved.
762    >>> dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
763    >>> list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3},
764    ...                                       {'a': 2, 'b': 4}]
765    True
766
767    >>> # Two tensors can be combined into one Dataset object.
768    >>> features = tf.constant([[1, 3], [2, 1], [3, 3]]) # ==> 3x2 tensor
769    >>> labels = tf.constant(['A', 'B', 'A']) # ==> 3x1 tensor
770    >>> dataset = Dataset.from_tensor_slices((features, labels))
771    >>> # Both the features and the labels tensors can be converted
772    >>> # to a Dataset object separately and combined after.
773    >>> features_dataset = Dataset.from_tensor_slices(features)
774    >>> labels_dataset = Dataset.from_tensor_slices(labels)
775    >>> dataset = Dataset.zip((features_dataset, labels_dataset))
776    >>> # A batched feature and label set can be converted to a Dataset
777    >>> # in similar fashion.
778    >>> batched_features = tf.constant([[[1, 3], [2, 3]],
779    ...                                 [[2, 1], [1, 2]],
780    ...                                 [[3, 3], [3, 2]]], shape=(3, 2, 2))
781    >>> batched_labels = tf.constant([['A', 'A'],
782    ...                               ['B', 'B'],
783    ...                               ['A', 'B']], shape=(3, 2, 1))
784    >>> dataset = Dataset.from_tensor_slices((batched_features, batched_labels))
785    >>> for element in dataset.as_numpy_iterator():
786    ...   print(element)
787    (array([[1, 3],
788           [2, 3]], dtype=int32), array([[b'A'],
789           [b'A']], dtype=object))
790    (array([[2, 1],
791           [1, 2]], dtype=int32), array([[b'B'],
792           [b'B']], dtype=object))
793    (array([[3, 3],
794           [3, 2]], dtype=int32), array([[b'A'],
795           [b'B']], dtype=object))
796
797    Note that if `tensors` contains a NumPy array, and eager execution is not
798    enabled, the values will be embedded in the graph as one or more
799    `tf.constant` operations. For large datasets (> 1 GB), this can waste
800    memory and run into byte limits of graph serialization. If `tensors`
801    contains one or more large NumPy arrays, consider the alternative described
802    in [this guide](
803    https://tensorflow.org/guide/data#consuming_numpy_arrays).
804
805    Args:
806      tensors: A dataset element, whose components have the same first
807        dimension. Supported values are documented
808        [here](https://www.tensorflow.org/guide/data#dataset_structure).
809      name: (Optional.) A name for the tf.data operation.
810
811    Returns:
812      Dataset: A `Dataset`.
813    """
814    return TensorSliceDataset(tensors, name=name)
815
816  class _GeneratorState:
817    """Stores outstanding iterators created from a Python generator.
818
819    This class keeps track of potentially multiple iterators that may have
820    been created from a generator, e.g. in the case that the dataset is
821    repeated, or nested within a parallel computation.
822    """
823
824    def __init__(self, generator):
825      self._generator = generator
826      self._lock = threading.Lock()
827      self._next_id = 0  # GUARDED_BY(self._lock)
828      self._args = {}
829      self._iterators = {}
830
831    def _normalize_id(self, iterator_id):
832      # In debug mode, iterator ids may be eagerly-generated np.arrays instead
833      # of Tensors. We convert them to scalars to make them hashable.
834      if isinstance(iterator_id, np.ndarray):
835        return iterator_id.item()
836      return iterator_id
837
838    def get_next_id(self, *args):
839      with self._lock:
840        ret = self._next_id
841        self._next_id += 1
842      self._args[ret] = args
843      # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
844      # casting in `py_func()` will create an array of `np.int32` on Windows,
845      # leading to a runtime error.
846      return np.array(ret, dtype=np.int64)
847
848    def get_iterator(self, iterator_id):
849      iterator_id = self._normalize_id(iterator_id)
850      try:
851        return self._iterators[iterator_id]
852      except KeyError:
853        iterator = iter(self._generator(*self._args.pop(iterator_id)))
854        self._iterators[iterator_id] = iterator
855        return iterator
856
857    def iterator_completed(self, iterator_id):
858      del self._iterators[self._normalize_id(iterator_id)]
859
860  @staticmethod
861  @deprecation.deprecated_args(None, "Use output_signature instead",
862                               "output_types", "output_shapes")
863  def from_generator(generator,
864                     output_types=None,
865                     output_shapes=None,
866                     args=None,
867                     output_signature=None,
868                     name=None):
869    """Creates a `Dataset` whose elements are generated by `generator`.
870
871    Note: The current implementation of `Dataset.from_generator()` uses
872    `tf.numpy_function` and inherits the same constraints. In particular, it
873    requires the dataset and iterator related operations to be placed
874    on a device in the same process as the Python program that called
875    `Dataset.from_generator()`. In particular, using `from_generator` will
876    preclude the use of tf.data service for scaling out dataset processing.
877    The body of `generator` will not be serialized in a `GraphDef`, and you
878    should not use this method if you need to serialize your model and restore
879    it in a different environment.
880
881    The `generator` argument must be a callable object that returns
882    an object that supports the `iter()` protocol (e.g. a generator function).
883
884    The elements generated by `generator` must be compatible with either the
885    given `output_signature` argument or with the given `output_types` and
886    (optionally) `output_shapes` arguments, whichever was specified.
887
888    The recommended way to call `from_generator` is to use the
889    `output_signature` argument. In this case the output will be assumed to
890    consist of objects with the classes, shapes and types defined by
891    `tf.TypeSpec` objects from `output_signature` argument:
892
893    >>> def gen():
894    ...   ragged_tensor = tf.ragged.constant([[1, 2], [3]])
895    ...   yield 42, ragged_tensor
896    >>>
897    >>> dataset = tf.data.Dataset.from_generator(
898    ...      gen,
899    ...      output_signature=(
900    ...          tf.TensorSpec(shape=(), dtype=tf.int32),
901    ...          tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32)))
902    >>>
903    >>> list(dataset.take(1))
904    [(<tf.Tensor: shape=(), dtype=int32, numpy=42>,
905    <tf.RaggedTensor [[1, 2], [3]]>)]
906
907    There is also a deprecated way to call `from_generator` by either with
908    `output_types` argument alone or together with `output_shapes` argument.
909    In this case the output of the function will be assumed to consist of
910    `tf.Tensor` objects with the types defined by `output_types` and with the
911    shapes which are either unknown or defined by `output_shapes`.
912
913    Note: If `generator` depends on mutable global variables or other external
914    state, be aware that the runtime may invoke `generator` multiple times
915    (in order to support repeating the `Dataset`) and at any time
916    between the call to `Dataset.from_generator()` and the production of the
917    first element from the generator. Mutating global variables or external
918    state can cause undefined behavior, and we recommend that you explicitly
919    cache any external state in `generator` before calling
920    `Dataset.from_generator()`.
921
922    Note: While the `output_signature` parameter makes it possible to yield
923    `Dataset` elements, the scope of `Dataset.from_generator()` should be
924    limited to logic that cannot be expressed through tf.data operations. Using
925    tf.data operations within the generator function is an anti-pattern and may
926    result in incremental memory growth.
927
928    Args:
929      generator: A callable object that returns an object that supports the
930        `iter()` protocol. If `args` is not specified, `generator` must take no
931        arguments; otherwise it must take as many arguments as there are values
932        in `args`.
933      output_types: (Optional.) A (nested) structure of `tf.DType` objects
934        corresponding to each component of an element yielded by `generator`.
935      output_shapes: (Optional.) A (nested) structure of `tf.TensorShape`
936        objects corresponding to each component of an element yielded by
937        `generator`.
938      args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
939        and passed to `generator` as NumPy-array arguments.
940      output_signature: (Optional.) A (nested) structure of `tf.TypeSpec`
941        objects corresponding to each component of an element yielded by
942        `generator`.
943      name: (Optional.) A name for the tf.data operations used by
944        `from_generator`.
945
946    Returns:
947      Dataset: A `Dataset`.
948    """
949    if not callable(generator):
950      raise TypeError("`generator` must be a Python callable.")
951
952    if output_signature is not None:
953      if output_types is not None:
954        raise TypeError("The `output_types` argument can not be used together "
955                        "with the `output_signature` argument.")
956      if output_shapes is not None:
957        raise TypeError("The `output_shapes` argument can not be used together "
958                        "with the `output_signature` argument.")
959      for spec in nest.flatten(output_signature):
960        if not isinstance(spec, type_spec.TypeSpec):
961          raise TypeError(f"`output_signature` must contain objects that are "
962                          f"subclass of `tf.TypeSpec` but found {type(spec)} "
963                          f"which is not.")
964    else:
965      if output_types is None:
966        raise TypeError("To specify the output signature you need to provide "
967                        "either the `output_signature` argument or the "
968                        "`output_types` argument.")
969
970    if output_signature is None:
971      if output_shapes is None:
972        output_shapes = nest.map_structure(
973            lambda _: tensor_shape.TensorShape(None), output_types)
974      else:
975        output_shapes = nest.map_structure_up_to(output_types,
976                                                 tensor_shape.as_shape,
977                                                 output_shapes)
978      output_signature = nest.map_structure_up_to(output_types,
979                                                  tensor_spec.TensorSpec,
980                                                  output_shapes, output_types)
981    if all(
982        isinstance(x, tensor_spec.TensorSpec)
983        for x in nest.flatten(output_signature)):
984      output_types = nest.pack_sequence_as(
985          output_signature, [x.dtype for x in nest.flatten(output_signature)])
986      output_shapes = nest.pack_sequence_as(
987          output_signature, [x.shape for x in nest.flatten(output_signature)])
988
989    if args is None:
990      args = ()
991    else:
992      args = tuple(ops.convert_n_to_tensor(args, name="args"))
993
994    generator_state = DatasetV2._GeneratorState(generator)
995
996    def get_iterator_id_fn(unused_dummy):
997      """Creates a unique `iterator_id` for each pass over the dataset.
998
999      The returned `iterator_id` disambiguates between multiple concurrently
1000      existing iterators.
1001
1002      Args:
1003        unused_dummy: Ignored value.
1004
1005      Returns:
1006        A `tf.int64` tensor whose value uniquely identifies an iterator in
1007        `generator_state`.
1008      """
1009      return script_ops.numpy_function(generator_state.get_next_id, args,
1010                                       dtypes.int64)
1011
1012    def generator_next_fn(iterator_id_t):
1013      """Generates the next element from iterator with ID `iterator_id_t`.
1014
1015      We map this function across an infinite repetition of the
1016      `iterator_id_t`, and raise `StopIteration` to terminate the iteration.
1017
1018      Args:
1019        iterator_id_t: A `tf.int64` tensor whose value uniquely identifies the
1020          iterator in `generator_state` from which to generate an element.
1021
1022      Returns:
1023        The next element to generate from the iterator.
1024      """
1025      if output_types and output_shapes:
1026        flattened_types = [
1027            dtypes.as_dtype(dt) for dt in nest.flatten(output_types)
1028        ]
1029        flattened_shapes = nest.flatten(output_shapes)
1030
1031        def generator_py_func(iterator_id):
1032          """A `py_func` that will be called to invoke the iterator."""
1033          # `next()` raises `StopIteration` when there are no more
1034          # elements remaining to be generated.
1035          values = next(generator_state.get_iterator(iterator_id))
1036
1037          # Use the same _convert function from the py_func() implementation to
1038          # convert the returned values to arrays early, so that we can inspect
1039          # their values.
1040          try:
1041            flattened_values = nest.flatten_up_to(output_types, values)
1042          except (TypeError, ValueError) as e:
1043            raise TypeError(
1044                f"`generator` yielded an element that did not match the "
1045                f"expected structure. The expected structure was "
1046                f"{output_types}, but the yielded element was {values}.") from e
1047          ret_arrays = []
1048          for ret, dtype in zip(flattened_values, flattened_types):
1049            try:
1050              ret_arrays.append(
1051                  script_ops.FuncRegistry._convert(  # pylint: disable=protected-access
1052                      ret,
1053                      dtype=dtype.as_numpy_dtype))
1054            except (TypeError, ValueError) as e:
1055              raise TypeError(
1056                  f"`generator` yielded an element that could not be "
1057                  f"converted to the expected type. The expected type was "
1058                  f"{dtype.name}, but the yielded element was {ret}.") from e
1059
1060          # Additional type and shape checking to ensure that the components of
1061          # the generated element match the `output_types` and `output_shapes`
1062          # arguments.
1063          for (ret_array, expected_dtype,
1064               expected_shape) in zip(ret_arrays, flattened_types,
1065                                      flattened_shapes):
1066            if ret_array.dtype != expected_dtype.as_numpy_dtype:
1067              raise TypeError(
1068                  f"`generator` yielded an element of type {ret_array.dtype} "
1069                  f"where an element of type {expected_dtype.as_numpy_dtype} "
1070                  f"was expected.")
1071            if not expected_shape.is_compatible_with(ret_array.shape):
1072              raise TypeError(
1073                  f"`generator` yielded an element of shape {ret_array.shape} "
1074                  f"where an element of shape {expected_shape} was expected.")
1075
1076          return ret_arrays
1077
1078        flat_values = script_ops.numpy_function(generator_py_func,
1079                                                [iterator_id_t],
1080                                                flattened_types)
1081
1082        # In debug mode the numpy_function will return a scalar if
1083        # generator_py_func produces only a single value.
1084        if not isinstance(flat_values, (list, tuple)):
1085          flat_values = [flat_values]
1086
1087        # The `py_func()` op drops the inferred shapes, so we add them back in
1088        # here.
1089        if output_shapes is not None:
1090          for ret_t, shape in zip(flat_values, flattened_shapes):
1091            ret_t.set_shape(shape)
1092
1093        return nest.pack_sequence_as(output_types, flat_values)
1094      else:
1095        flat_output_types = structure.get_flat_tensor_types(output_signature)
1096
1097        def generator_py_func(iterator_id):
1098          """A `py_func` that will be called to invoke the iterator."""
1099          # `next()` raises `StopIteration` when there are no more
1100          # elements remaining to be generated.
1101          values = next(generator_state.get_iterator(iterator_id.numpy()))
1102
1103          try:
1104            values = structure.normalize_element(values, output_signature)
1105          except (TypeError, ValueError) as e:
1106            raise TypeError(
1107                f"`generator` yielded an element that did not match the "
1108                f"expected structure. The expected structure was "
1109                f"{output_signature}, but the yielded element was "
1110                f"{values}.") from e
1111
1112          values_spec = structure.type_spec_from_value(values)
1113
1114          if not structure.are_compatible(values_spec, output_signature):
1115            raise TypeError(
1116                f"`generator` yielded an element of {values_spec} where an "
1117                f"element of {output_signature} was expected.")
1118
1119          return structure.to_tensor_list(output_signature, values)
1120
1121        return script_ops.eager_py_func(
1122            generator_py_func, inp=[iterator_id_t], Tout=flat_output_types)
1123
1124    def finalize_fn(iterator_id_t):
1125      """Releases host-side state for the iterator with ID `iterator_id_t`."""
1126
1127      def finalize_py_func(iterator_id):
1128        generator_state.iterator_completed(iterator_id)
1129        # We return a dummy value so that the `finalize_fn` has a valid
1130        # signature.
1131        # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
1132        # casting in `py_func()` will create an array of `np.int32` on Windows,
1133        # leading to a runtime error.
1134        return np.array(0, dtype=np.int64)
1135
1136      return script_ops.numpy_function(finalize_py_func, [iterator_id_t],
1137                                       dtypes.int64)
1138
1139    # This function associates each traversal of `generator` with a unique
1140    # iterator ID.
1141    def flat_map_fn(dummy_arg):
1142      # The `get_iterator_id_fn` gets a unique ID for the current instance of
1143      # of the generator.
1144      # The `generator_next_fn` gets the next element from the iterator with the
1145      # given ID, and raises StopIteration when that iterator contains no
1146      # more elements.
1147      return _GeneratorDataset(
1148          dummy_arg,
1149          get_iterator_id_fn,
1150          generator_next_fn,
1151          finalize_fn,
1152          output_signature,
1153          name=name)
1154
1155    # A single-element dataset that, each time it is evaluated, contains a
1156    # freshly-generated and unique (for the returned dataset) int64
1157    # ID that will be used to identify the appropriate Python state, which
1158    # is encapsulated in `generator_state`, and captured in
1159    # `get_iterator_id_map_fn`.
1160    dummy = 0
1161    id_dataset = Dataset.from_tensors(dummy, name=name)
1162
1163    # A dataset that contains all of the elements generated by a
1164    # single iterator created from `generator`, identified by the
1165    # iterator ID contained in `id_dataset`. Lifting the iteration
1166    # into a flat_map here enables multiple repetitions and/or nested
1167    # versions of the returned dataset to be created, because it forces
1168    # the generation of a new ID for each version.
1169    return id_dataset.flat_map(flat_map_fn, name=name)
1170
1171  @staticmethod
1172  def range(*args, **kwargs):
1173    """Creates a `Dataset` of a step-separated range of values.
1174
1175    >>> list(Dataset.range(5).as_numpy_iterator())
1176    [0, 1, 2, 3, 4]
1177    >>> list(Dataset.range(2, 5).as_numpy_iterator())
1178    [2, 3, 4]
1179    >>> list(Dataset.range(1, 5, 2).as_numpy_iterator())
1180    [1, 3]
1181    >>> list(Dataset.range(1, 5, -2).as_numpy_iterator())
1182    []
1183    >>> list(Dataset.range(5, 1).as_numpy_iterator())
1184    []
1185    >>> list(Dataset.range(5, 1, -2).as_numpy_iterator())
1186    [5, 3]
1187    >>> list(Dataset.range(2, 5, output_type=tf.int32).as_numpy_iterator())
1188    [2, 3, 4]
1189    >>> list(Dataset.range(1, 5, 2, output_type=tf.float32).as_numpy_iterator())
1190    [1.0, 3.0]
1191
1192    Args:
1193      *args: follows the same semantics as python's range.
1194        len(args) == 1 -> start = 0, stop = args[0], step = 1.
1195        len(args) == 2 -> start = args[0], stop = args[1], step = 1.
1196        len(args) == 3 -> start = args[0], stop = args[1], step = args[2].
1197      **kwargs:
1198        - output_type: Its expected dtype. (Optional, default: `tf.int64`).
1199        - name: (Optional.) A name for the tf.data operation.
1200
1201    Returns:
1202      Dataset: A `RangeDataset`.
1203
1204    Raises:
1205      ValueError: if len(args) == 0.
1206    """
1207    return RangeDataset(*args, **kwargs)
1208
1209  @staticmethod
1210  def zip(datasets, name=None):
1211    """Creates a `Dataset` by zipping together the given datasets.
1212
1213    This method has similar semantics to the built-in `zip()` function
1214    in Python, with the main difference being that the `datasets`
1215    argument can be a (nested) structure of `Dataset` objects. The supported
1216    nesting mechanisms are documented
1217    [here] (https://www.tensorflow.org/guide/data#dataset_structure).
1218
1219    >>> # The nested structure of the `datasets` argument determines the
1220    >>> # structure of elements in the resulting dataset.
1221    >>> a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
1222    >>> b = tf.data.Dataset.range(4, 7)  # ==> [ 4, 5, 6 ]
1223    >>> ds = tf.data.Dataset.zip((a, b))
1224    >>> list(ds.as_numpy_iterator())
1225    [(1, 4), (2, 5), (3, 6)]
1226    >>> ds = tf.data.Dataset.zip((b, a))
1227    >>> list(ds.as_numpy_iterator())
1228    [(4, 1), (5, 2), (6, 3)]
1229    >>>
1230    >>> # The `datasets` argument may contain an arbitrary number of datasets.
1231    >>> c = tf.data.Dataset.range(7, 13).batch(2)  # ==> [ [7, 8],
1232    ...                                            #       [9, 10],
1233    ...                                            #       [11, 12] ]
1234    >>> ds = tf.data.Dataset.zip((a, b, c))
1235    >>> for element in ds.as_numpy_iterator():
1236    ...   print(element)
1237    (1, 4, array([7, 8]))
1238    (2, 5, array([ 9, 10]))
1239    (3, 6, array([11, 12]))
1240    >>>
1241    >>> # The number of elements in the resulting dataset is the same as
1242    >>> # the size of the smallest dataset in `datasets`.
1243    >>> d = tf.data.Dataset.range(13, 15)  # ==> [ 13, 14 ]
1244    >>> ds = tf.data.Dataset.zip((a, d))
1245    >>> list(ds.as_numpy_iterator())
1246    [(1, 13), (2, 14)]
1247
1248    Args:
1249      datasets: A (nested) structure of datasets.
1250      name: (Optional.) A name for the tf.data operation.
1251
1252    Returns:
1253      Dataset: A `Dataset`.
1254    """
1255    return ZipDataset(datasets, name=name)
1256
1257  def concatenate(self, dataset, name=None):
1258    """Creates a `Dataset` by concatenating the given dataset with this dataset.
1259
1260    >>> a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
1261    >>> b = tf.data.Dataset.range(4, 8)  # ==> [ 4, 5, 6, 7 ]
1262    >>> ds = a.concatenate(b)
1263    >>> list(ds.as_numpy_iterator())
1264    [1, 2, 3, 4, 5, 6, 7]
1265    >>> # The input dataset and dataset to be concatenated should have
1266    >>> # compatible element specs.
1267    >>> c = tf.data.Dataset.zip((a, b))
1268    >>> a.concatenate(c)
1269    Traceback (most recent call last):
1270    TypeError: Two datasets to concatenate have different types
1271    <dtype: 'int64'> and (tf.int64, tf.int64)
1272    >>> d = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
1273    >>> a.concatenate(d)
1274    Traceback (most recent call last):
1275    TypeError: Two datasets to concatenate have different types
1276    <dtype: 'int64'> and <dtype: 'string'>
1277
1278    Args:
1279      dataset: `Dataset` to be concatenated.
1280      name: (Optional.) A name for the tf.data operation.
1281
1282    Returns:
1283      Dataset: A `Dataset`.
1284    """
1285    return ConcatenateDataset(self, dataset, name=name)
1286
1287  def prefetch(self, buffer_size, name=None):
1288    """Creates a `Dataset` that prefetches elements from this dataset.
1289
1290    Most dataset input pipelines should end with a call to `prefetch`. This
1291    allows later elements to be prepared while the current element is being
1292    processed. This often improves latency and throughput, at the cost of
1293    using additional memory to store prefetched elements.
1294
1295    Note: Like other `Dataset` methods, prefetch operates on the
1296    elements of the input dataset. It has no concept of examples vs. batches.
1297    `examples.prefetch(2)` will prefetch two elements (2 examples),
1298    while `examples.batch(20).prefetch(2)` will prefetch 2 elements
1299    (2 batches, of 20 examples each).
1300
1301    >>> dataset = tf.data.Dataset.range(3)
1302    >>> dataset = dataset.prefetch(2)
1303    >>> list(dataset.as_numpy_iterator())
1304    [0, 1, 2]
1305
1306    Args:
1307      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the maximum
1308        number of elements that will be buffered when prefetching. If the value
1309        `tf.data.AUTOTUNE` is used, then the buffer size is dynamically tuned.
1310      name: Optional. A name for the tf.data transformation.
1311
1312    Returns:
1313      Dataset: A `Dataset`.
1314    """
1315    if DEBUG_MODE:
1316      return self
1317    return PrefetchDataset(self, buffer_size, name=name)
1318
1319  @staticmethod
1320  def list_files(file_pattern, shuffle=None, seed=None, name=None):
1321    """A dataset of all files matching one or more glob patterns.
1322
1323    The `file_pattern` argument should be a small number of glob patterns.
1324    If your filenames have already been globbed, use
1325    `Dataset.from_tensor_slices(filenames)` instead, as re-globbing every
1326    filename with `list_files` may result in poor performance with remote
1327    storage systems.
1328
1329    Note: The default behavior of this method is to return filenames in
1330    a non-deterministic random shuffled order. Pass a `seed` or `shuffle=False`
1331    to get results in a deterministic order.
1332
1333    Example:
1334      If we had the following files on our filesystem:
1335
1336        - /path/to/dir/a.txt
1337        - /path/to/dir/b.py
1338        - /path/to/dir/c.py
1339
1340      If we pass "/path/to/dir/*.py" as the directory, the dataset
1341      would produce:
1342
1343        - /path/to/dir/b.py
1344        - /path/to/dir/c.py
1345
1346    Args:
1347      file_pattern: A string, a list of strings, or a `tf.Tensor` of string type
1348        (scalar or vector), representing the filename glob (i.e. shell wildcard)
1349        pattern(s) that will be matched.
1350      shuffle: (Optional.) If `True`, the file names will be shuffled randomly.
1351        Defaults to `True`.
1352      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
1353        seed that will be used to create the distribution. See
1354        `tf.random.set_seed` for behavior.
1355      name: Optional. A name for the tf.data operations used by `list_files`.
1356
1357    Returns:
1358     Dataset: A `Dataset` of strings corresponding to file names.
1359    """
1360    with ops.name_scope("list_files"):
1361      if shuffle is None:
1362        shuffle = True
1363      file_pattern = ops.convert_to_tensor(
1364          file_pattern, dtype=dtypes.string, name="file_pattern")
1365      matching_files = gen_io_ops.matching_files(file_pattern)
1366
1367      # Raise an exception if `file_pattern` does not match any files.
1368      condition = math_ops.greater(array_ops.shape(matching_files)[0], 0,
1369                                   name="match_not_empty")
1370
1371      message = math_ops.add(
1372          "No files matched pattern: ",
1373          string_ops.reduce_join(file_pattern, separator=", "), name="message")
1374
1375      assert_not_empty = control_flow_ops.Assert(
1376          condition, [message], summarize=1, name="assert_not_empty")
1377      with ops.control_dependencies([assert_not_empty]):
1378        matching_files = array_ops.identity(matching_files)
1379
1380      dataset = TensorSliceDataset(matching_files, is_files=True, name=name)
1381      if issubclass(Dataset, DatasetV1):
1382        dataset = DatasetV1Adapter(dataset)
1383      if shuffle:
1384        # NOTE(mrry): The shuffle buffer size must be greater than zero, but the
1385        # list of files might be empty.
1386        buffer_size = math_ops.maximum(
1387            array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
1388        dataset = dataset.shuffle(buffer_size, seed=seed, name=name)
1389      return dataset
1390
1391  def repeat(self, count=None, name=None):
1392    """Repeats this dataset so each original value is seen `count` times.
1393
1394    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
1395    >>> dataset = dataset.repeat(3)
1396    >>> list(dataset.as_numpy_iterator())
1397    [1, 2, 3, 1, 2, 3, 1, 2, 3]
1398
1399    Note: If the input dataset depends on global state (e.g. a random number
1400    generator) or its output is non-deterministic (e.g. because of upstream
1401    `shuffle`), then different repetitions may produce different elements.
1402
1403    Args:
1404      count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
1405        number of times the dataset should be repeated. The default behavior (if
1406        `count` is `None` or `-1`) is for the dataset be repeated indefinitely.
1407      name: (Optional.) A name for the tf.data operation.
1408
1409    Returns:
1410      Dataset: A `Dataset`.
1411    """
1412    return RepeatDataset(self, count, name=name)
1413
1414  def enumerate(self, start=0, name=None):
1415    """Enumerates the elements of this dataset.
1416
1417    It is similar to python's `enumerate`.
1418
1419    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
1420    >>> dataset = dataset.enumerate(start=5)
1421    >>> for element in dataset.as_numpy_iterator():
1422    ...   print(element)
1423    (5, 1)
1424    (6, 2)
1425    (7, 3)
1426
1427    >>> # The (nested) structure of the input dataset determines the
1428    >>> # structure of elements in the resulting dataset.
1429    >>> dataset = tf.data.Dataset.from_tensor_slices([(7, 8), (9, 10)])
1430    >>> dataset = dataset.enumerate()
1431    >>> for element in dataset.as_numpy_iterator():
1432    ...   print(element)
1433    (0, array([7, 8], dtype=int32))
1434    (1, array([ 9, 10], dtype=int32))
1435
1436    Args:
1437      start: A `tf.int64` scalar `tf.Tensor`, representing the start value for
1438        enumeration.
1439      name: Optional. A name for the tf.data operations used by `enumerate`.
1440
1441    Returns:
1442      Dataset: A `Dataset`.
1443    """
1444
1445    max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
1446    range_dataset = Dataset.range(start, max_value, name=name)
1447    # Replicate the range component so that each split is enumerated
1448    # independently. This avoids the need for prohibitively expensive
1449    # cross-split coordination.
1450    range_dataset = _apply_rewrite(range_dataset, "replicate_on_split")
1451    return Dataset.zip((range_dataset, self), name=name)
1452
1453  def shuffle(self,
1454              buffer_size,
1455              seed=None,
1456              reshuffle_each_iteration=None,
1457              name=None):
1458    """Randomly shuffles the elements of this dataset.
1459
1460    This dataset fills a buffer with `buffer_size` elements, then randomly
1461    samples elements from this buffer, replacing the selected elements with new
1462    elements. For perfect shuffling, a buffer size greater than or equal to the
1463    full size of the dataset is required.
1464
1465    For instance, if your dataset contains 10,000 elements but `buffer_size` is
1466    set to 1,000, then `shuffle` will initially select a random element from
1467    only the first 1,000 elements in the buffer. Once an element is selected,
1468    its space in the buffer is replaced by the next (i.e. 1,001-st) element,
1469    maintaining the 1,000 element buffer.
1470
1471    `reshuffle_each_iteration` controls whether the shuffle order should be
1472    different for each epoch. In TF 1.X, the idiomatic way to create epochs
1473    was through the `repeat` transformation:
1474
1475    ```python
1476    dataset = tf.data.Dataset.range(3)
1477    dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
1478    dataset = dataset.repeat(2)
1479    # [1, 0, 2, 1, 2, 0]
1480
1481    dataset = tf.data.Dataset.range(3)
1482    dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
1483    dataset = dataset.repeat(2)
1484    # [1, 0, 2, 1, 0, 2]
1485    ```
1486
1487    In TF 2.0, `tf.data.Dataset` objects are Python iterables which makes it
1488    possible to also create epochs through Python iteration:
1489
1490    ```python
1491    dataset = tf.data.Dataset.range(3)
1492    dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
1493    list(dataset.as_numpy_iterator())
1494    # [1, 0, 2]
1495    list(dataset.as_numpy_iterator())
1496    # [1, 2, 0]
1497    ```
1498
1499    ```python
1500    dataset = tf.data.Dataset.range(3)
1501    dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
1502    list(dataset.as_numpy_iterator())
1503    # [1, 0, 2]
1504    list(dataset.as_numpy_iterator())
1505    # [1, 0, 2]
1506    ```
1507
1508    Args:
1509      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
1510        elements from this dataset from which the new dataset will sample.
1511      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
1512        seed that will be used to create the distribution. See
1513        `tf.random.set_seed` for behavior.
1514      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
1515        that the dataset should be pseudorandomly reshuffled each time it is
1516        iterated over. (Defaults to `True`.)
1517      name: (Optional.) A name for the tf.data operation.
1518
1519    Returns:
1520      Dataset: A `Dataset`.
1521    """
1522    return ShuffleDataset(
1523        self, buffer_size, seed, reshuffle_each_iteration, name=name)
1524
1525  def cache(self, filename="", name=None):
1526    """Caches the elements in this dataset.
1527
1528    The first time the dataset is iterated over, its elements will be cached
1529    either in the specified file or in memory. Subsequent iterations will
1530    use the cached data.
1531
1532    Note: To guarantee that the cache gets finalized, the input dataset must be
1533    iterated through in its entirety, until it raises StopIteration. Otherwise,
1534    subsequent iterations may not use cached data.
1535
1536    >>> dataset = tf.data.Dataset.range(5)
1537    >>> dataset = dataset.map(lambda x: x**2)
1538    >>> dataset = dataset.cache()
1539    >>> # The first time reading through the data will generate the data using
1540    >>> # `range` and `map`.
1541    >>> list(dataset.as_numpy_iterator())
1542    [0, 1, 4, 9, 16]
1543    >>> # Subsequent iterations read from the cache.
1544    >>> list(dataset.as_numpy_iterator())
1545    [0, 1, 4, 9, 16]
1546
1547    When caching to a file, the cached data will persist across runs. Even the
1548    first iteration through the data will read from the cache file. Changing
1549    the input pipeline before the call to `.cache()` will have no effect until
1550    the cache file is removed or the filename is changed.
1551
1552    ```python
1553    dataset = tf.data.Dataset.range(5)
1554    dataset = dataset.cache("/path/to/file")
1555    list(dataset.as_numpy_iterator())
1556    # [0, 1, 2, 3, 4]
1557    dataset = tf.data.Dataset.range(10)
1558    dataset = dataset.cache("/path/to/file")  # Same file!
1559    list(dataset.as_numpy_iterator())
1560    # [0, 1, 2, 3, 4]
1561    ```
1562
1563    Note: `cache` will produce exactly the same elements during each iteration
1564    through the dataset. If you wish to randomize the iteration order, make sure
1565    to call `shuffle` *after* calling `cache`.
1566
1567    Args:
1568      filename: A `tf.string` scalar `tf.Tensor`, representing the name of a
1569        directory on the filesystem to use for caching elements in this Dataset.
1570        If a filename is not provided, the dataset will be cached in memory.
1571      name: (Optional.) A name for the tf.data operation.
1572
1573    Returns:
1574      Dataset: A `Dataset`.
1575    """
1576    return CacheDataset(self, filename, name=name)
1577
1578  def take(self, count, name=None):
1579    """Creates a `Dataset` with at most `count` elements from this dataset.
1580
1581    >>> dataset = tf.data.Dataset.range(10)
1582    >>> dataset = dataset.take(3)
1583    >>> list(dataset.as_numpy_iterator())
1584    [0, 1, 2]
1585
1586    Args:
1587      count: A `tf.int64` scalar `tf.Tensor`, representing the number of
1588        elements of this dataset that should be taken to form the new dataset.
1589        If `count` is -1, or if `count` is greater than the size of this
1590        dataset, the new dataset will contain all elements of this dataset.
1591      name: (Optional.) A name for the tf.data operation.
1592
1593    Returns:
1594      Dataset: A `Dataset`.
1595    """
1596    return TakeDataset(self, count, name=name)
1597
1598  def skip(self, count, name=None):
1599    """Creates a `Dataset` that skips `count` elements from this dataset.
1600
1601    >>> dataset = tf.data.Dataset.range(10)
1602    >>> dataset = dataset.skip(7)
1603    >>> list(dataset.as_numpy_iterator())
1604    [7, 8, 9]
1605
1606    Args:
1607      count: A `tf.int64` scalar `tf.Tensor`, representing the number of
1608        elements of this dataset that should be skipped to form the new dataset.
1609        If `count` is greater than the size of this dataset, the new dataset
1610        will contain no elements.  If `count` is -1, skips the entire dataset.
1611      name: (Optional.) A name for the tf.data operation.
1612
1613    Returns:
1614      Dataset: A `Dataset`.
1615    """
1616    return SkipDataset(self, count, name=name)
1617
1618  def shard(self, num_shards, index, name=None):
1619    """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
1620
1621    `shard` is deterministic. The Dataset produced by `A.shard(n, i)` will
1622    contain all elements of A whose index mod n = i.
1623
1624    >>> A = tf.data.Dataset.range(10)
1625    >>> B = A.shard(num_shards=3, index=0)
1626    >>> list(B.as_numpy_iterator())
1627    [0, 3, 6, 9]
1628    >>> C = A.shard(num_shards=3, index=1)
1629    >>> list(C.as_numpy_iterator())
1630    [1, 4, 7]
1631    >>> D = A.shard(num_shards=3, index=2)
1632    >>> list(D.as_numpy_iterator())
1633    [2, 5, 8]
1634
1635    This dataset operator is very useful when running distributed training, as
1636    it allows each worker to read a unique subset.
1637
1638    When reading a single input file, you can shard elements as follows:
1639
1640    ```python
1641    d = tf.data.TFRecordDataset(input_file)
1642    d = d.shard(num_workers, worker_index)
1643    d = d.repeat(num_epochs)
1644    d = d.shuffle(shuffle_buffer_size)
1645    d = d.map(parser_fn, num_parallel_calls=num_map_threads)
1646    ```
1647
1648    Important caveats:
1649
1650    - Be sure to shard before you use any randomizing operator (such as
1651      shuffle).
1652    - Generally it is best if the shard operator is used early in the dataset
1653      pipeline. For example, when reading from a set of TFRecord files, shard
1654      before converting the dataset to input samples. This avoids reading every
1655      file on every worker. The following is an example of an efficient
1656      sharding strategy within a complete pipeline:
1657
1658    ```python
1659    d = Dataset.list_files(pattern)
1660    d = d.shard(num_workers, worker_index)
1661    d = d.repeat(num_epochs)
1662    d = d.shuffle(shuffle_buffer_size)
1663    d = d.interleave(tf.data.TFRecordDataset,
1664                     cycle_length=num_readers, block_length=1)
1665    d = d.map(parser_fn, num_parallel_calls=num_map_threads)
1666    ```
1667
1668    Args:
1669      num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
1670        shards operating in parallel.
1671      index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
1672      name: (Optional.) A name for the tf.data operation.
1673
1674    Returns:
1675      Dataset: A `Dataset`.
1676
1677    Raises:
1678      InvalidArgumentError: if `num_shards` or `index` are illegal values.
1679
1680        Note: error checking is done on a best-effort basis, and errors aren't
1681        guaranteed to be caught upon dataset creation. (e.g. providing in a
1682        placeholder tensor bypasses the early checking, and will instead result
1683        in an error during a session.run call.)
1684    """
1685    return ShardDataset(self, num_shards, index, name=name)
1686
1687  def save(self,
1688           path,
1689           compression=None,
1690           shard_func=None,
1691           checkpoint_args=None):
1692    """Saves the content of the given dataset.
1693
1694      Example usage:
1695
1696      >>> import tempfile
1697      >>> path = os.path.join(tempfile.gettempdir(), "saved_data")
1698      >>> # Save a dataset
1699      >>> dataset = tf.data.Dataset.range(2)
1700      >>> dataset.save(path)
1701      >>> new_dataset = tf.data.Dataset.load(path)
1702      >>> for elem in new_dataset:
1703      ...   print(elem)
1704      tf.Tensor(0, shape=(), dtype=int64)
1705      tf.Tensor(1, shape=(), dtype=int64)
1706
1707      The saved dataset is saved in multiple file "shards". By default, the
1708      dataset output is divided to shards in a round-robin fashion but custom
1709      sharding can be specified via the `shard_func` function. For example, you
1710      can save the dataset to using a single shard as follows:
1711
1712      ```python
1713      dataset = make_dataset()
1714      def custom_shard_func(element):
1715        return np.int64(0)
1716      dataset.save(
1717          path="/path/to/data", ..., shard_func=custom_shard_func)
1718      ```
1719
1720      To enable checkpointing, pass in `checkpoint_args` to the `save` method
1721      as follows:
1722
1723      ```python
1724      dataset = tf.data.Dataset.range(100)
1725      save_dir = "..."
1726      checkpoint_prefix = "..."
1727      step_counter = tf.Variable(0, trainable=False)
1728      checkpoint_args = {
1729        "checkpoint_interval": 50,
1730        "step_counter": step_counter,
1731        "directory": checkpoint_prefix,
1732        "max_to_keep": 20,
1733      }
1734      dataset.save(dataset, save_dir, checkpoint_args=checkpoint_args)
1735      ```
1736
1737      NOTE: The directory layout and file format used for saving the dataset is
1738      considered an implementation detail and may change. For this reason,
1739      datasets saved through `tf.data.Dataset.save` should only be consumed
1740      through `tf.data.Dataset.load`, which is guaranteed to be
1741      backwards compatible.
1742
1743    Args:
1744     path: Required. A directory to use for saving the dataset.
1745     compression: Optional. The algorithm to use to compress data when writing
1746          it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
1747     shard_func: Optional. A function to control the mapping of dataset
1748          elements to file shards. The function is expected to map elements of
1749          the input dataset to int64 shard IDs. If present, the function will be
1750          traced and executed as graph computation.
1751     checkpoint_args: Optional args for checkpointing which will be passed into
1752          the `tf.train.CheckpointManager`. If `checkpoint_args` are not
1753          specified, then checkpointing will not be performed. The `save()`
1754          implementation creates a `tf.train.Checkpoint` object internally, so
1755          users should not set the `checkpoint` argument in `checkpoint_args`.
1756
1757    Raises:
1758      ValueError if `checkpoint` is passed into `checkpoint_args`.
1759    """
1760    # Loaded lazily due to a circular dependency
1761    # dataset_ops->save_ops->dataset_ops
1762    from tensorflow.python.data.ops import save_op  # pylint: disable=g-import-not-at-top
1763    save_op.save(self, path, compression, shard_func, checkpoint_args)
1764
1765  @staticmethod
1766  def load(path, element_spec=None, compression=None, reader_func=None):
1767    """Loads a previously saved dataset.
1768
1769    Example usage:
1770
1771    >>> import tempfile
1772    >>> path = os.path.join(tempfile.gettempdir(), "saved_data")
1773    >>> # Save a dataset
1774    >>> dataset = tf.data.Dataset.range(2)
1775    >>> tf.data.Dataset.save(dataset, path)
1776    >>> new_dataset = tf.data.Dataset.load(path)
1777    >>> for elem in new_dataset:
1778    ...   print(elem)
1779    tf.Tensor(0, shape=(), dtype=int64)
1780    tf.Tensor(1, shape=(), dtype=int64)
1781
1782
1783    If the default option of sharding the saved dataset was used, the element
1784    order of the saved dataset will be preserved when loading it.
1785
1786    The `reader_func` argument can be used to specify a custom order in which
1787    elements should be loaded from the individual shards. The `reader_func` is
1788    expected to take a single argument -- a dataset of datasets, each containing
1789    elements of one of the shards -- and return a dataset of elements. For
1790    example, the order of shards can be shuffled when loading them as follows:
1791
1792    ```python
1793    def custom_reader_func(datasets):
1794      datasets = datasets.shuffle(NUM_SHARDS)
1795      return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)
1796
1797    dataset = tf.data.Dataset.load(
1798        path="/path/to/data", ..., reader_func=custom_reader_func)
1799    ```
1800
1801    Args:
1802      path: Required. A path pointing to a previously saved dataset.
1803      element_spec: Optional. A nested structure of `tf.TypeSpec` objects
1804        matching the structure of an element of the saved dataset and specifying
1805        the type of individual element components. If not provided, the nested
1806        structure of `tf.TypeSpec` saved with the saved dataset is used. Note
1807        that this argument is required in graph mode.
1808      compression: Optional. The algorithm to use to decompress the data when
1809        reading it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
1810      reader_func: Optional. A function to control how to read data from shards.
1811        If present, the function will be traced and executed as graph
1812        computation.
1813
1814    Returns:
1815      A `tf.data.Dataset` instance.
1816
1817    Raises:
1818      FileNotFoundError: If `element_spec` is not specified and the saved nested
1819        structure of `tf.TypeSpec` can not be located with the saved dataset.
1820      ValueError: If `element_spec` is not specified and the method is executed
1821        in graph mode.
1822    """
1823    # Loaded lazily due to a circular dependency
1824    # dataset_ops->load_ops->dataset_ops
1825    from tensorflow.python.data.ops import load_op  # pylint: disable=g-import-not-at-top
1826    return load_op.load(
1827        path=path,
1828        element_spec=element_spec,
1829        compression=compression,
1830        reader_func=reader_func)
1831
1832  def batch(self,
1833            batch_size,
1834            drop_remainder=False,
1835            num_parallel_calls=None,
1836            deterministic=None,
1837            name=None):
1838    """Combines consecutive elements of this dataset into batches.
1839
1840    >>> dataset = tf.data.Dataset.range(8)
1841    >>> dataset = dataset.batch(3)
1842    >>> list(dataset.as_numpy_iterator())
1843    [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]
1844
1845    >>> dataset = tf.data.Dataset.range(8)
1846    >>> dataset = dataset.batch(3, drop_remainder=True)
1847    >>> list(dataset.as_numpy_iterator())
1848    [array([0, 1, 2]), array([3, 4, 5])]
1849
1850    The components of the resulting element will have an additional outer
1851    dimension, which will be `batch_size` (or `N % batch_size` for the last
1852    element if `batch_size` does not divide the number of input elements `N`
1853    evenly and `drop_remainder` is `False`). If your program depends on the
1854    batches having the same outer dimension, you should set the `drop_remainder`
1855    argument to `True` to prevent the smaller batch from being produced.
1856
1857    Note: If your program requires data to have a statically known shape (e.g.,
1858    when using XLA), you should use `drop_remainder=True`. Without
1859    `drop_remainder=True` the shape of the output dataset will have an unknown
1860    leading dimension due to the possibility of a smaller final batch.
1861
1862    Args:
1863      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
1864        consecutive elements of this dataset to combine in a single batch.
1865      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
1866        whether the last batch should be dropped in the case it has fewer than
1867        `batch_size` elements; the default behavior is not to drop the smaller
1868        batch.
1869      num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`,
1870        representing the number of batches to compute asynchronously in
1871        parallel.
1872        If not specified, batches will be computed sequentially. If the value
1873        `tf.data.AUTOTUNE` is used, then the number of parallel
1874        calls is set dynamically based on available resources.
1875      deterministic: (Optional.) When `num_parallel_calls` is specified, if this
1876        boolean is specified (`True` or `False`), it controls the order in which
1877        the transformation produces elements. If set to `False`, the
1878        transformation is allowed to yield elements out of order to trade
1879        determinism for performance. If not specified, the
1880        `tf.data.Options.deterministic` option (`True` by default) controls the
1881        behavior.
1882      name: (Optional.) A name for the tf.data operation.
1883
1884    Returns:
1885      Dataset: A `Dataset`.
1886    """
1887    if num_parallel_calls is None or DEBUG_MODE:
1888      if deterministic is not None and not DEBUG_MODE:
1889        warnings.warn("The `deterministic` argument has no effect unless the "
1890                      "`num_parallel_calls` argument is specified.")
1891      return BatchDataset(self, batch_size, drop_remainder, name=name)
1892    else:
1893      return ParallelBatchDataset(
1894          self,
1895          batch_size,
1896          drop_remainder,
1897          num_parallel_calls,
1898          deterministic,
1899          name=name)
1900
1901  def padded_batch(self,
1902                   batch_size,
1903                   padded_shapes=None,
1904                   padding_values=None,
1905                   drop_remainder=False,
1906                   name=None):
1907    """Combines consecutive elements of this dataset into padded batches.
1908
1909    This transformation combines multiple consecutive elements of the input
1910    dataset into a single element.
1911
1912    Like `tf.data.Dataset.batch`, the components of the resulting element will
1913    have an additional outer dimension, which will be `batch_size` (or
1914    `N % batch_size` for the last element if `batch_size` does not divide the
1915    number of input elements `N` evenly and `drop_remainder` is `False`). If
1916    your program depends on the batches having the same outer dimension, you
1917    should set the `drop_remainder` argument to `True` to prevent the smaller
1918    batch from being produced.
1919
1920    Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
1921    different shapes, and this transformation will pad each component to the
1922    respective shape in `padded_shapes`. The `padded_shapes` argument
1923    determines the resulting shape for each dimension of each component in an
1924    output element:
1925
1926    * If the dimension is a constant, the component will be padded out to that
1927      length in that dimension.
1928    * If the dimension is unknown, the component will be padded out to the
1929      maximum length of all elements in that dimension.
1930
1931    >>> A = (tf.data.Dataset
1932    ...      .range(1, 5, output_type=tf.int32)
1933    ...      .map(lambda x: tf.fill([x], x)))
1934    >>> # Pad to the smallest per-batch size that fits all elements.
1935    >>> B = A.padded_batch(2)
1936    >>> for element in B.as_numpy_iterator():
1937    ...   print(element)
1938    [[1 0]
1939     [2 2]]
1940    [[3 3 3 0]
1941     [4 4 4 4]]
1942    >>> # Pad to a fixed size.
1943    >>> C = A.padded_batch(2, padded_shapes=5)
1944    >>> for element in C.as_numpy_iterator():
1945    ...   print(element)
1946    [[1 0 0 0 0]
1947     [2 2 0 0 0]]
1948    [[3 3 3 0 0]
1949     [4 4 4 4 0]]
1950    >>> # Pad with a custom value.
1951    >>> D = A.padded_batch(2, padded_shapes=5, padding_values=-1)
1952    >>> for element in D.as_numpy_iterator():
1953    ...   print(element)
1954    [[ 1 -1 -1 -1 -1]
1955     [ 2  2 -1 -1 -1]]
1956    [[ 3  3  3 -1 -1]
1957     [ 4  4  4  4 -1]]
1958    >>> # Components of nested elements can be padded independently.
1959    >>> elements = [([1, 2, 3], [10]),
1960    ...             ([4, 5], [11, 12])]
1961    >>> dataset = tf.data.Dataset.from_generator(
1962    ...     lambda: iter(elements), (tf.int32, tf.int32))
1963    >>> # Pad the first component of the tuple to length 4, and the second
1964    >>> # component to the smallest size that fits.
1965    >>> dataset = dataset.padded_batch(2,
1966    ...     padded_shapes=([4], [None]),
1967    ...     padding_values=(-1, 100))
1968    >>> list(dataset.as_numpy_iterator())
1969    [(array([[ 1,  2,  3, -1], [ 4,  5, -1, -1]], dtype=int32),
1970      array([[ 10, 100], [ 11,  12]], dtype=int32))]
1971    >>> # Pad with a single value and multiple components.
1972    >>> E = tf.data.Dataset.zip((A, A)).padded_batch(2, padding_values=-1)
1973    >>> for element in E.as_numpy_iterator():
1974    ...   print(element)
1975    (array([[ 1, -1],
1976           [ 2,  2]], dtype=int32), array([[ 1, -1],
1977           [ 2,  2]], dtype=int32))
1978    (array([[ 3,  3,  3, -1],
1979           [ 4,  4,  4,  4]], dtype=int32), array([[ 3,  3,  3, -1],
1980           [ 4,  4,  4,  4]], dtype=int32))
1981
1982    See also `tf.data.experimental.dense_to_sparse_batch`, which combines
1983    elements that may have different shapes into a `tf.sparse.SparseTensor`.
1984
1985    Args:
1986      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
1987        consecutive elements of this dataset to combine in a single batch.
1988      padded_shapes: (Optional.) A (nested) structure of `tf.TensorShape` or
1989        `tf.int64` vector tensor-like objects representing the shape to which
1990        the respective component of each input element should be padded prior
1991        to batching. Any unknown dimensions will be padded to the maximum size
1992        of that dimension in each batch. If unset, all dimensions of all
1993        components are padded to the maximum size in the batch. `padded_shapes`
1994        must be set if any component has an unknown rank.
1995      padding_values: (Optional.) A (nested) structure of scalar-shaped
1996        `tf.Tensor`, representing the padding values to use for the respective
1997        components. None represents that the (nested) structure should be padded
1998        with default values.  Defaults are `0` for numeric types and the empty
1999        string for string types. The `padding_values` should have the same
2000        (nested) structure as the input dataset. If `padding_values` is a single
2001        element and the input dataset has multiple components, then the same
2002        `padding_values` will be used to pad every component of the dataset.
2003        If `padding_values` is a scalar, then its value will be broadcasted
2004        to match the shape of each component.
2005      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
2006        whether the last batch should be dropped in the case it has fewer than
2007        `batch_size` elements; the default behavior is not to drop the smaller
2008        batch.
2009      name: (Optional.) A name for the tf.data operation.
2010
2011    Returns:
2012      Dataset: A `Dataset`.
2013
2014    Raises:
2015      ValueError: If a component has an unknown rank, and the `padded_shapes`
2016        argument is not set.
2017      TypeError: If a component is of an unsupported type. The list of supported
2018        types is documented in
2019        https://www.tensorflow.org/guide/data#dataset_structure.
2020    """
2021    if padded_shapes is None:
2022      padded_shapes = get_legacy_output_shapes(self)
2023      for i, shape in enumerate(nest.flatten(padded_shapes)):
2024        # A `tf.TensorShape` is only false if its *rank* is unknown.
2025        if not shape:
2026          raise ValueError(f"You must provide `padded_shapes` argument because "
2027                           f"component {i} has unknown rank.")
2028    return PaddedBatchDataset(
2029        self,
2030        batch_size,
2031        padded_shapes,
2032        padding_values,
2033        drop_remainder,
2034        name=name)
2035
2036  def map(self,
2037          map_func,
2038          num_parallel_calls=None,
2039          deterministic=None,
2040          name=None):
2041    """Maps `map_func` across the elements of this dataset.
2042
2043    This transformation applies `map_func` to each element of this dataset, and
2044    returns a new dataset containing the transformed elements, in the same
2045    order as they appeared in the input. `map_func` can be used to change both
2046    the values and the structure of a dataset's elements. Supported structure
2047    constructs are documented
2048    [here](https://www.tensorflow.org/guide/data#dataset_structure).
2049
2050    For example, `map` can be used for adding 1 to each element, or projecting a
2051    subset of element components.
2052
2053    >>> dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
2054    >>> dataset = dataset.map(lambda x: x + 1)
2055    >>> list(dataset.as_numpy_iterator())
2056    [2, 3, 4, 5, 6]
2057
2058    The input signature of `map_func` is determined by the structure of each
2059    element in this dataset.
2060
2061    >>> dataset = Dataset.range(5)
2062    >>> # `map_func` takes a single argument of type `tf.Tensor` with the same
2063    >>> # shape and dtype.
2064    >>> result = dataset.map(lambda x: x + 1)
2065
2066    >>> # Each element is a tuple containing two `tf.Tensor` objects.
2067    >>> elements = [(1, "foo"), (2, "bar"), (3, "baz")]
2068    >>> dataset = tf.data.Dataset.from_generator(
2069    ...     lambda: elements, (tf.int32, tf.string))
2070    >>> # `map_func` takes two arguments of type `tf.Tensor`. This function
2071    >>> # projects out just the first component.
2072    >>> result = dataset.map(lambda x_int, y_str: x_int)
2073    >>> list(result.as_numpy_iterator())
2074    [1, 2, 3]
2075
2076    >>> # Each element is a dictionary mapping strings to `tf.Tensor` objects.
2077    >>> elements =  ([{"a": 1, "b": "foo"},
2078    ...               {"a": 2, "b": "bar"},
2079    ...               {"a": 3, "b": "baz"}])
2080    >>> dataset = tf.data.Dataset.from_generator(
2081    ...     lambda: elements, {"a": tf.int32, "b": tf.string})
2082    >>> # `map_func` takes a single argument of type `dict` with the same keys
2083    >>> # as the elements.
2084    >>> result = dataset.map(lambda d: str(d["a"]) + d["b"])
2085
2086    The value or values returned by `map_func` determine the structure of each
2087    element in the returned dataset.
2088
2089    >>> dataset = tf.data.Dataset.range(3)
2090    >>> # `map_func` returns two `tf.Tensor` objects.
2091    >>> def g(x):
2092    ...   return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
2093    >>> result = dataset.map(g)
2094    >>> result.element_spec
2095    (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(3,), \
2096dtype=tf.string, name=None))
2097    >>> # Python primitives, lists, and NumPy arrays are implicitly converted to
2098    >>> # `tf.Tensor`.
2099    >>> def h(x):
2100    ...   return 37.0, ["Foo", "Bar"], np.array([1.0, 2.0], dtype=np.float64)
2101    >>> result = dataset.map(h)
2102    >>> result.element_spec
2103    (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(2,), \
2104dtype=tf.string, name=None), TensorSpec(shape=(2,), dtype=tf.float64, \
2105name=None))
2106    >>> # `map_func` can return nested structures.
2107    >>> def i(x):
2108    ...   return (37.0, [42, 16]), "foo"
2109    >>> result = dataset.map(i)
2110    >>> result.element_spec
2111    ((TensorSpec(shape=(), dtype=tf.float32, name=None),
2112      TensorSpec(shape=(2,), dtype=tf.int32, name=None)),
2113     TensorSpec(shape=(), dtype=tf.string, name=None))
2114
2115    `map_func` can accept as arguments and return any type of dataset element.
2116
2117    Note that irrespective of the context in which `map_func` is defined (eager
2118    vs. graph), tf.data traces the function and executes it as a graph. To use
2119    Python code inside of the function you have a few options:
2120
2121    1) Rely on AutoGraph to convert Python code into an equivalent graph
2122    computation. The downside of this approach is that AutoGraph can convert
2123    some but not all Python code.
2124
2125    2) Use `tf.py_function`, which allows you to write arbitrary Python code but
2126    will generally result in worse performance than 1). For example:
2127
2128    >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world'])
2129    >>> # transform a string tensor to upper case string using a Python function
2130    >>> def upper_case_fn(t: tf.Tensor):
2131    ...   return t.numpy().decode('utf-8').upper()
2132    >>> d = d.map(lambda x: tf.py_function(func=upper_case_fn,
2133    ...           inp=[x], Tout=tf.string))
2134    >>> list(d.as_numpy_iterator())
2135    [b'HELLO', b'WORLD']
2136
2137    3) Use `tf.numpy_function`, which also allows you to write arbitrary
2138    Python code. Note that `tf.py_function` accepts `tf.Tensor` whereas
2139    `tf.numpy_function` accepts numpy arrays and returns only numpy arrays.
2140    For example:
2141
2142    >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world'])
2143    >>> def upper_case_fn(t: np.ndarray):
2144    ...   return t.decode('utf-8').upper()
2145    >>> d = d.map(lambda x: tf.numpy_function(func=upper_case_fn,
2146    ...           inp=[x], Tout=tf.string))
2147    >>> list(d.as_numpy_iterator())
2148    [b'HELLO', b'WORLD']
2149
2150    Note that the use of `tf.numpy_function` and `tf.py_function`
2151    in general precludes the possibility of executing user-defined
2152    transformations in parallel (because of Python GIL).
2153
2154    Performance can often be improved by setting `num_parallel_calls` so that
2155    `map` will use multiple threads to process elements. If deterministic order
2156    isn't required, it can also improve performance to set
2157    `deterministic=False`.
2158
2159    >>> dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
2160    >>> dataset = dataset.map(lambda x: x + 1,
2161    ...     num_parallel_calls=tf.data.AUTOTUNE,
2162    ...     deterministic=False)
2163
2164    The order of elements yielded by this transformation is deterministic if
2165    `deterministic=True`. If `map_func` contains stateful operations and
2166    `num_parallel_calls > 1`, the order in which that state is accessed is
2167    undefined, so the values of output elements may not be deterministic
2168    regardless of the `deterministic` flag value.
2169
2170    Args:
2171      map_func: A function mapping a dataset element to another dataset element.
2172      num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`,
2173        representing the number elements to process asynchronously in parallel.
2174        If not specified, elements will be processed sequentially. If the value
2175        `tf.data.AUTOTUNE` is used, then the number of parallel
2176        calls is set dynamically based on available CPU.
2177      deterministic: (Optional.) When `num_parallel_calls` is specified, if this
2178        boolean is specified (`True` or `False`), it controls the order in which
2179        the transformation produces elements. If set to `False`, the
2180        transformation is allowed to yield elements out of order to trade
2181        determinism for performance. If not specified, the
2182        `tf.data.Options.deterministic` option (`True` by default) controls the
2183        behavior.
2184      name: (Optional.) A name for the tf.data operation.
2185
2186    Returns:
2187      Dataset: A `Dataset`.
2188    """
2189    if num_parallel_calls is None or DEBUG_MODE:
2190      if deterministic is not None and not DEBUG_MODE:
2191        warnings.warn("The `deterministic` argument has no effect unless the "
2192                      "`num_parallel_calls` argument is specified.")
2193      return MapDataset(self, map_func, preserve_cardinality=True, name=name)
2194    else:
2195      return ParallelMapDataset(
2196          self,
2197          map_func,
2198          num_parallel_calls,
2199          deterministic,
2200          preserve_cardinality=True,
2201          name=name)
2202
2203  def flat_map(self, map_func, name=None):
2204    """Maps `map_func` across this dataset and flattens the result.
2205
2206    The type signature is:
2207
2208    ```
2209    def flat_map(
2210      self: Dataset[T],
2211      map_func: Callable[[T], Dataset[S]]
2212    ) -> Dataset[S]
2213    ```
2214
2215    Use `flat_map` if you want to make sure that the order of your dataset
2216    stays the same. For example, to flatten a dataset of batches into a
2217    dataset of their elements:
2218
2219    >>> dataset = tf.data.Dataset.from_tensor_slices(
2220    ...     [[1, 2, 3], [4, 5, 6], [7, 8, 9]])
2221    >>> dataset = dataset.flat_map(tf.data.Dataset.from_tensor_slices)
2222    >>> list(dataset.as_numpy_iterator())
2223    [1, 2, 3, 4, 5, 6, 7, 8, 9]
2224
2225    `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
2226    `flat_map` produces the same output as
2227    `tf.data.Dataset.interleave(cycle_length=1)`
2228
2229    Args:
2230      map_func: A function mapping a dataset element to a dataset.
2231      name: (Optional.) A name for the tf.data operation.
2232
2233    Returns:
2234      Dataset: A `Dataset`.
2235    """
2236    return FlatMapDataset(self, map_func, name=name)
2237
2238  def interleave(self,
2239                 map_func,
2240                 cycle_length=None,
2241                 block_length=None,
2242                 num_parallel_calls=None,
2243                 deterministic=None,
2244                 name=None):
2245    """Maps `map_func` across this dataset, and interleaves the results.
2246
2247    The type signature is:
2248
2249    ```
2250    def interleave(
2251      self: Dataset[T],
2252      map_func: Callable[[T], Dataset[S]]
2253    ) -> Dataset[S]
2254    ```
2255
2256    For example, you can use `Dataset.interleave()` to process many input files
2257    concurrently:
2258
2259    >>> # Preprocess 4 files concurrently, and interleave blocks of 16 records
2260    >>> # from each file.
2261    >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt",
2262    ...              "/var/data/file3.txt", "/var/data/file4.txt"]
2263    >>> dataset = tf.data.Dataset.from_tensor_slices(filenames)
2264    >>> def parse_fn(filename):
2265    ...   return tf.data.Dataset.range(10)
2266    >>> dataset = dataset.interleave(lambda x:
2267    ...     tf.data.TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
2268    ...     cycle_length=4, block_length=16)
2269
2270    The `cycle_length` and `block_length` arguments control the order in which
2271    elements are produced. `cycle_length` controls the number of input elements
2272    that are processed concurrently. If you set `cycle_length` to 1, this
2273    transformation will handle one input element at a time, and will produce
2274    identical results to `tf.data.Dataset.flat_map`. In general,
2275    this transformation will apply `map_func` to `cycle_length` input elements,
2276    open iterators on the returned `Dataset` objects, and cycle through them
2277    producing `block_length` consecutive elements from each iterator, and
2278    consuming the next input element each time it reaches the end of an
2279    iterator.
2280
2281    For example:
2282
2283    >>> dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
2284    >>> # NOTE: New lines indicate "block" boundaries.
2285    >>> dataset = dataset.interleave(
2286    ...     lambda x: Dataset.from_tensors(x).repeat(6),
2287    ...     cycle_length=2, block_length=4)
2288    >>> list(dataset.as_numpy_iterator())
2289    [1, 1, 1, 1,
2290     2, 2, 2, 2,
2291     1, 1,
2292     2, 2,
2293     3, 3, 3, 3,
2294     4, 4, 4, 4,
2295     3, 3,
2296     4, 4,
2297     5, 5, 5, 5,
2298     5, 5]
2299
2300    Note: The order of elements yielded by this transformation is
2301    deterministic, as long as `map_func` is a pure function and
2302    `deterministic=True`. If `map_func` contains any stateful operations, the
2303    order in which that state is accessed is undefined.
2304
2305    Performance can often be improved by setting `num_parallel_calls` so that
2306    `interleave` will use multiple threads to fetch elements. If determinism
2307    isn't required, it can also improve performance to set
2308    `deterministic=False`.
2309
2310    >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt",
2311    ...              "/var/data/file3.txt", "/var/data/file4.txt"]
2312    >>> dataset = tf.data.Dataset.from_tensor_slices(filenames)
2313    >>> dataset = dataset.interleave(lambda x: tf.data.TFRecordDataset(x),
2314    ...     cycle_length=4, num_parallel_calls=tf.data.AUTOTUNE,
2315    ...     deterministic=False)
2316
2317    Args:
2318      map_func: A function that takes a dataset element and returns a
2319        `tf.data.Dataset`.
2320      cycle_length: (Optional.) The number of input elements that will be
2321        processed concurrently. If not set, the tf.data runtime decides what it
2322        should be based on available CPU. If `num_parallel_calls` is set to
2323        `tf.data.AUTOTUNE`, the `cycle_length` argument identifies
2324        the maximum degree of parallelism.
2325      block_length: (Optional.) The number of consecutive elements to produce
2326        from each input element before cycling to another input element. If not
2327        set, defaults to 1.
2328      num_parallel_calls: (Optional.) If specified, the implementation creates a
2329        threadpool, which is used to fetch inputs from cycle elements
2330        asynchronously and in parallel. The default behavior is to fetch inputs
2331        from cycle elements synchronously with no parallelism. If the value
2332        `tf.data.AUTOTUNE` is used, then the number of parallel
2333        calls is set dynamically based on available CPU.
2334      deterministic: (Optional.) When `num_parallel_calls` is specified, if this
2335        boolean is specified (`True` or `False`), it controls the order in which
2336        the transformation produces elements. If set to `False`, the
2337        transformation is allowed to yield elements out of order to trade
2338        determinism for performance. If not specified, the
2339        `tf.data.Options.deterministic` option (`True` by default) controls the
2340        behavior.
2341      name: (Optional.) A name for the tf.data operation.
2342
2343    Returns:
2344      Dataset: A `Dataset`.
2345    """
2346    if block_length is None:
2347      block_length = 1
2348
2349    if cycle_length is None:
2350      cycle_length = AUTOTUNE
2351
2352    if num_parallel_calls is None or DEBUG_MODE:
2353      if deterministic is not None and not DEBUG_MODE:
2354        warnings.warn("The `deterministic` argument has no effect unless the "
2355                      "`num_parallel_calls` argument is specified.")
2356      return InterleaveDataset(
2357          self, map_func, cycle_length, block_length, name=name)
2358    else:
2359      return ParallelInterleaveDataset(
2360          self,
2361          map_func,
2362          cycle_length,
2363          block_length,
2364          num_parallel_calls,
2365          deterministic=deterministic,
2366          name=name)
2367
2368  def filter(self, predicate, name=None):
2369    """Filters this dataset according to `predicate`.
2370
2371    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
2372    >>> dataset = dataset.filter(lambda x: x < 3)
2373    >>> list(dataset.as_numpy_iterator())
2374    [1, 2]
2375    >>> # `tf.math.equal(x, y)` is required for equality comparison
2376    >>> def filter_fn(x):
2377    ...   return tf.math.equal(x, 1)
2378    >>> dataset = dataset.filter(filter_fn)
2379    >>> list(dataset.as_numpy_iterator())
2380    [1]
2381
2382    Args:
2383      predicate: A function mapping a dataset element to a boolean.
2384      name: (Optional.) A name for the tf.data operation.
2385
2386    Returns:
2387      Dataset: The `Dataset` containing the elements of this dataset for which
2388          `predicate` is `True`.
2389    """
2390    return FilterDataset(self, predicate, name=name)
2391
2392  def apply(self, transformation_func):
2393    """Applies a transformation function to this dataset.
2394
2395    `apply` enables chaining of custom `Dataset` transformations, which are
2396    represented as functions that take one `Dataset` argument and return a
2397    transformed `Dataset`.
2398
2399    >>> dataset = tf.data.Dataset.range(100)
2400    >>> def dataset_fn(ds):
2401    ...   return ds.filter(lambda x: x < 5)
2402    >>> dataset = dataset.apply(dataset_fn)
2403    >>> list(dataset.as_numpy_iterator())
2404    [0, 1, 2, 3, 4]
2405
2406    Args:
2407      transformation_func: A function that takes one `Dataset` argument and
2408        returns a `Dataset`.
2409
2410    Returns:
2411      Dataset: The `Dataset` returned by applying `transformation_func` to this
2412          dataset.
2413    """
2414    dataset = transformation_func(self)
2415    if not isinstance(dataset, DatasetV2):
2416      raise TypeError(
2417          f"`transformation_func` must return a `tf.data.Dataset` object. "
2418          f"Got {type(dataset)}.")
2419    dataset._input_datasets = [self]  # pylint: disable=protected-access
2420    return dataset
2421
2422  def window(self, size, shift=None, stride=1, drop_remainder=False, name=None):
2423    """Returns a dataset of "windows".
2424
2425    Each "window" is a dataset that contains a subset of elements of the
2426    input dataset. These are finite datasets of size `size` (or possibly fewer
2427    if there are not enough input elements to fill the window and
2428    `drop_remainder` evaluates to `False`).
2429
2430    For example:
2431
2432    >>> dataset = tf.data.Dataset.range(7).window(3)
2433    >>> for window in dataset:
2434    ...   print(window)
2435    <...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>
2436    <...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>
2437    <...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>
2438
2439    Since windows are datasets, they can be iterated over:
2440
2441    >>> for window in dataset:
2442    ...   print([item.numpy() for item in window])
2443    [0, 1, 2]
2444    [3, 4, 5]
2445    [6]
2446
2447    #### Shift
2448
2449    The `shift` argument determines the number of input elements to shift
2450    between the start of each window. If windows and elements are both numbered
2451    starting at 0, the first element in window `k` will be element `k * shift`
2452    of the input dataset. In particular, the first element of the first window
2453    will always be the first element of the input dataset.
2454
2455    >>> dataset = tf.data.Dataset.range(7).window(3, shift=1,
2456    ...                                           drop_remainder=True)
2457    >>> for window in dataset:
2458    ...   print(list(window.as_numpy_iterator()))
2459    [0, 1, 2]
2460    [1, 2, 3]
2461    [2, 3, 4]
2462    [3, 4, 5]
2463    [4, 5, 6]
2464
2465    #### Stride
2466
2467    The `stride` argument determines the stride between input elements within a
2468    window.
2469
2470    >>> dataset = tf.data.Dataset.range(7).window(3, shift=1, stride=2,
2471    ...                                           drop_remainder=True)
2472    >>> for window in dataset:
2473    ...   print(list(window.as_numpy_iterator()))
2474    [0, 2, 4]
2475    [1, 3, 5]
2476    [2, 4, 6]
2477
2478    #### Nested elements
2479
2480    When the `window` transformation is applied to a dataset whos elements are
2481    nested structures, it produces a dataset where the elements have the same
2482    nested structure but each leaf is replaced by a window. In other words,
2483    the nesting is applied outside of the windows as opposed inside of them.
2484
2485    The type signature is:
2486
2487    ```
2488    def window(
2489        self: Dataset[Nest[T]], ...
2490    ) -> Dataset[Nest[Dataset[T]]]
2491    ```
2492
2493    Applying `window` to a `Dataset` of tuples gives a tuple of windows:
2494
2495    >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5],
2496    ...                                               [6, 7, 8, 9, 10]))
2497    >>> dataset = dataset.window(2)
2498    >>> windows = next(iter(dataset))
2499    >>> windows
2500    (<...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int32, name=None)>,
2501     <...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int32, name=None)>)
2502
2503    >>> def to_numpy(ds):
2504    ...   return list(ds.as_numpy_iterator())
2505    >>>
2506    >>> for windows in dataset:
2507    ...   print(to_numpy(windows[0]), to_numpy(windows[1]))
2508    [1, 2] [6, 7]
2509    [3, 4] [8, 9]
2510    [5] [10]
2511
2512    Applying `window` to a `Dataset` of dictionaries gives a dictionary of
2513    `Datasets`:
2514
2515    >>> dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3],
2516    ...                                               'b': [4, 5, 6],
2517    ...                                               'c': [7, 8, 9]})
2518    >>> dataset = dataset.window(2)
2519    >>> def to_numpy(ds):
2520    ...   return list(ds.as_numpy_iterator())
2521    >>>
2522    >>> for windows in dataset:
2523    ...   print(tf.nest.map_structure(to_numpy, windows))
2524    {'a': [1, 2], 'b': [4, 5], 'c': [7, 8]}
2525    {'a': [3], 'b': [6], 'c': [9]}
2526
2527    #### Flatten a dataset of windows
2528
2529    The `Dataset.flat_map` and `Dataset.interleave` methods can be used to
2530    flatten a dataset of windows into a single dataset.
2531
2532    The argument to `flat_map` is a function that takes an element from the
2533    dataset and returns a `Dataset`. `flat_map` chains together the resulting
2534    datasets sequentially.
2535
2536    For example, to turn each window into a dense tensor:
2537
2538    >>> size = 3
2539    >>> dataset = tf.data.Dataset.range(7).window(size, shift=1,
2540    ...                                           drop_remainder=True)
2541    >>> batched = dataset.flat_map(lambda x:x.batch(3))
2542    >>> for batch in batched:
2543    ...   print(batch.numpy())
2544    [0 1 2]
2545    [1 2 3]
2546    [2 3 4]
2547    [3 4 5]
2548    [4 5 6]
2549
2550    Args:
2551      size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements
2552        of the input dataset to combine into a window. Must be positive.
2553      shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
2554        number of input elements by which the window moves in each iteration.
2555        Defaults to `size`. Must be positive.
2556      stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
2557        stride of the input elements in the sliding window. Must be positive.
2558        The default value of 1 means "retain every input element".
2559      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
2560        whether the last windows should be dropped if their size is smaller than
2561        `size`.
2562      name: (Optional.) A name for the tf.data operation.
2563
2564    Returns:
2565      Dataset: A `Dataset` of (nests of) windows. Each window is a finite
2566        datasets of flat elements.
2567    """
2568    if shift is None:
2569      shift = size
2570    return WindowDataset(self, size, shift, stride, drop_remainder, name=name)
2571
2572  def reduce(self, initial_state, reduce_func, name=None):
2573    """Reduces the input dataset to a single element.
2574
2575    The transformation calls `reduce_func` successively on every element of
2576    the input dataset until the dataset is exhausted, aggregating information in
2577    its internal state. The `initial_state` argument is used for the initial
2578    state and the final state is returned as the result.
2579
2580    >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1).numpy()
2581    5
2582    >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y).numpy()
2583    10
2584
2585    Args:
2586      initial_state: An element representing the initial state of the
2587        transformation.
2588      reduce_func: A function that maps `(old_state, input_element)` to
2589        `new_state`. It must take two arguments and return a new element
2590        The structure of `new_state` must match the structure of
2591        `initial_state`.
2592      name: (Optional.) A name for the tf.data operation.
2593
2594    Returns:
2595      A dataset element corresponding to the final state of the transformation.
2596
2597    """
2598
2599    with ops.name_scope("initial_state"):
2600      initial_state = structure.normalize_element(initial_state)
2601    state_structure = structure.type_spec_from_value(initial_state)
2602
2603    # Iteratively rerun the reduce function until reaching a fixed point on
2604    # `state_structure`.
2605    need_to_rerun = True
2606    while need_to_rerun:
2607
2608      wrapped_func = structured_function.StructuredFunctionWrapper(
2609          reduce_func,
2610          "reduce()",
2611          input_structure=(state_structure, self.element_spec),
2612          add_to_graph=False)
2613
2614      # Extract and validate class information from the returned values.
2615      output_classes = wrapped_func.output_classes
2616      state_classes = nest.map_structure(
2617          lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
2618          state_structure)
2619      for new_state_class, state_class in zip(
2620          nest.flatten(output_classes), nest.flatten(state_classes)):
2621        if not issubclass(new_state_class, state_class):
2622          raise TypeError(
2623              f"The element classes for the new state must match the initial "
2624              f"state. Expected {state_classes} but got "
2625              f"{wrapped_func.output_classes}.")
2626
2627      # Extract and validate type information from the returned values.
2628      output_types = wrapped_func.output_types
2629      state_types = nest.map_structure(
2630          lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
2631          state_structure)
2632      for new_state_type, state_type in zip(
2633          nest.flatten(output_types), nest.flatten(state_types)):
2634        if new_state_type != state_type:
2635          raise TypeError(
2636              f"The element types for the new state must match the initial "
2637              f"state. Expected {state_types} but got "
2638              f"{wrapped_func.output_types}.")
2639
2640      # Extract shape information from the returned values.
2641      output_shapes = wrapped_func.output_shapes
2642      state_shapes = nest.map_structure(
2643          lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
2644          state_structure)
2645      flat_state_shapes = nest.flatten(state_shapes)
2646      flat_new_state_shapes = nest.flatten(output_shapes)
2647      weakened_state_shapes = [
2648          original.most_specific_compatible_shape(new)
2649          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
2650      ]
2651
2652      need_to_rerun = False
2653      for original_shape, weakened_shape in zip(flat_state_shapes,
2654                                                weakened_state_shapes):
2655        if original_shape.ndims is not None and (
2656            weakened_shape.ndims is None or
2657            original_shape.as_list() != weakened_shape.as_list()):
2658          need_to_rerun = True
2659          break
2660
2661      if need_to_rerun:
2662        # TODO(b/110122868): Support a "most specific compatible structure"
2663        # method for combining structures, to avoid using legacy structures
2664        # here.
2665        state_structure = structure.convert_legacy_structure(
2666            state_types,
2667            nest.pack_sequence_as(state_shapes, weakened_state_shapes),
2668            state_classes)
2669
2670    reduce_func = wrapped_func.function
2671    reduce_func.add_to_graph(ops.get_default_graph())
2672
2673    dataset = self._apply_debug_options()
2674
2675    # pylint: disable=protected-access
2676    metadata = dataset_metadata_pb2.Metadata()
2677    if name:
2678      metadata.name = _validate_and_encode(name)
2679    return structure.from_compatible_tensor_list(
2680        state_structure,
2681        gen_dataset_ops.reduce_dataset(
2682            dataset._variant_tensor,
2683            structure.to_tensor_list(state_structure, initial_state),
2684            reduce_func.captured_inputs,
2685            f=reduce_func,
2686            output_shapes=structure.get_flat_tensor_shapes(state_structure),
2687            output_types=structure.get_flat_tensor_types(state_structure),
2688            metadata=metadata.SerializeToString()))
2689
2690  def get_single_element(self, name=None):
2691    """Returns the single element of the `dataset`.
2692
2693    The function enables you to use a `tf.data.Dataset` in a stateless
2694    "tensor-in tensor-out" expression, without creating an iterator.
2695    This facilitates the ease of data transformation on tensors using the
2696    optimized `tf.data.Dataset` abstraction on top of them.
2697
2698    For example, lets consider a `preprocessing_fn` which would take as an
2699    input the raw features and returns the processed feature along with
2700    it's label.
2701
2702    ```python
2703    def preprocessing_fn(raw_feature):
2704      # ... the raw_feature is preprocessed as per the use-case
2705      return feature
2706
2707    raw_features = ...  # input batch of BATCH_SIZE elements.
2708    dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
2709              .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
2710              .batch(BATCH_SIZE))
2711
2712    processed_features = dataset.get_single_element()
2713    ```
2714
2715    In the above example, the `raw_features` tensor of length=BATCH_SIZE
2716    was converted to a `tf.data.Dataset`. Next, each of the `raw_feature` was
2717    mapped using the `preprocessing_fn` and the processed features were
2718    grouped into a single batch. The final `dataset` contains only one element
2719    which is a batch of all the processed features.
2720
2721    NOTE: The `dataset` should contain only one element.
2722
2723    Now, instead of creating an iterator for the `dataset` and retrieving the
2724    batch of features, the `tf.data.get_single_element()` function is used
2725    to skip the iterator creation process and directly output the batch of
2726    features.
2727
2728    This can be particularly useful when your tensor transformations are
2729    expressed as `tf.data.Dataset` operations, and you want to use those
2730    transformations while serving your model.
2731
2732    #### Keras
2733
2734    ```python
2735
2736    model = ... # A pre-built or custom model
2737
2738    class PreprocessingModel(tf.keras.Model):
2739      def __init__(self, model):
2740        super().__init__(self)
2741        self.model = model
2742
2743      @tf.function(input_signature=[...])
2744      def serving_fn(self, data):
2745        ds = tf.data.Dataset.from_tensor_slices(data)
2746        ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
2747        ds = ds.batch(batch_size=BATCH_SIZE)
2748        return tf.argmax(self.model(ds.get_single_element()), axis=-1)
2749
2750    preprocessing_model = PreprocessingModel(model)
2751    your_exported_model_dir = ... # save the model to this path.
2752    tf.saved_model.save(preprocessing_model, your_exported_model_dir,
2753                  signatures={'serving_default': preprocessing_model.serving_fn}
2754                  )
2755    ```
2756
2757    #### Estimator
2758
2759    In the case of estimators, you need to generally define a `serving_input_fn`
2760    which would require the features to be processed by the model while
2761    inferencing.
2762
2763    ```python
2764    def serving_input_fn():
2765
2766      raw_feature_spec = ... # Spec for the raw_features
2767      input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
2768          raw_feature_spec, default_batch_size=None)
2769      )
2770      serving_input_receiver = input_fn()
2771      raw_features = serving_input_receiver.features
2772
2773      def preprocessing_fn(raw_feature):
2774        # ... the raw_feature is preprocessed as per the use-case
2775        return feature
2776
2777      dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
2778                .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
2779                .batch(BATCH_SIZE))
2780
2781      processed_features = dataset.get_single_element()
2782
2783      # Please note that the value of `BATCH_SIZE` should be equal to
2784      # the size of the leading dimension of `raw_features`. This ensures
2785      # that `dataset` has only element, which is a pre-requisite for
2786      # using `dataset.get_single_element()`.
2787
2788      return tf.estimator.export.ServingInputReceiver(
2789          processed_features, serving_input_receiver.receiver_tensors)
2790
2791    estimator = ... # A pre-built or custom estimator
2792    estimator.export_saved_model(your_exported_model_dir, serving_input_fn)
2793    ```
2794
2795    Args:
2796      name: (Optional.) A name for the tf.data operation.
2797
2798    Returns:
2799      A nested structure of `tf.Tensor` objects, corresponding to the single
2800      element of `dataset`.
2801
2802    Raises:
2803      InvalidArgumentError: (at runtime) if `dataset` does not contain exactly
2804        one element.
2805    """
2806
2807    metadata = dataset_metadata_pb2.Metadata()
2808    if name:
2809      metadata.name = _validate_and_encode(name)
2810    return structure.from_compatible_tensor_list(
2811        self.element_spec,
2812        gen_dataset_ops.dataset_to_single_element(
2813            self._variant_tensor,
2814            metadata=metadata.SerializeToString(),
2815            **self._flat_structure))  # pylint: disable=protected-access
2816
2817  def unbatch(self, name=None):
2818    """Splits elements of a dataset into multiple elements.
2819
2820    For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
2821    where `B` may vary for each input element, then for each element in the
2822    dataset, the unbatched dataset will contain `B` consecutive elements
2823    of shape `[a0, a1, ...]`.
2824
2825    >>> elements = [ [1, 2, 3], [1, 2], [1, 2, 3, 4] ]
2826    >>> dataset = tf.data.Dataset.from_generator(lambda: elements, tf.int64)
2827    >>> dataset = dataset.unbatch()
2828    >>> list(dataset.as_numpy_iterator())
2829    [1, 2, 3, 1, 2, 1, 2, 3, 4]
2830
2831    Note: `unbatch` requires a data copy to slice up the batched tensor into
2832    smaller, unbatched tensors. When optimizing performance, try to avoid
2833    unnecessary usage of `unbatch`.
2834
2835    Args:
2836      name: (Optional.) A name for the tf.data operation.
2837
2838    Returns:
2839      A `Dataset`.
2840    """
2841    normalized_dataset = normalize_to_dense(self)
2842    return _UnbatchDataset(normalized_dataset, name=name)
2843
2844  def with_options(self, options, name=None):
2845    """Returns a new `tf.data.Dataset` with the given options set.
2846
2847    The options are "global" in the sense they apply to the entire dataset.
2848    If options are set multiple times, they are merged as long as different
2849    options do not use different non-default values.
2850
2851    >>> ds = tf.data.Dataset.range(5)
2852    >>> ds = ds.interleave(lambda x: tf.data.Dataset.range(5),
2853    ...                    cycle_length=3,
2854    ...                    num_parallel_calls=3)
2855    >>> options = tf.data.Options()
2856    >>> # This will make the interleave order non-deterministic.
2857    >>> options.deterministic = False
2858    >>> ds = ds.with_options(options)
2859
2860    Args:
2861      options: A `tf.data.Options` that identifies the options the use.
2862      name: (Optional.) A name for the tf.data operation.
2863
2864    Returns:
2865      Dataset: A `Dataset` with the given options.
2866
2867    Raises:
2868      ValueError: when an option is set more than once to a non-default value
2869    """
2870    return _OptionsDataset(self, options, name=name)
2871
2872  def cardinality(self):
2873    """Returns the cardinality of the dataset, if known.
2874
2875    `cardinality` may return `tf.data.INFINITE_CARDINALITY` if the dataset
2876    contains an infinite number of elements or `tf.data.UNKNOWN_CARDINALITY` if
2877    the analysis fails to determine the number of elements in the dataset
2878    (e.g. when the dataset source is a file).
2879
2880    >>> dataset = tf.data.Dataset.range(42)
2881    >>> print(dataset.cardinality().numpy())
2882    42
2883    >>> dataset = dataset.repeat()
2884    >>> cardinality = dataset.cardinality()
2885    >>> print((cardinality == tf.data.INFINITE_CARDINALITY).numpy())
2886    True
2887    >>> dataset = dataset.filter(lambda x: True)
2888    >>> cardinality = dataset.cardinality()
2889    >>> print((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy())
2890    True
2891
2892    Returns:
2893      A scalar `tf.int64` `Tensor` representing the cardinality of the dataset.
2894      If the cardinality is infinite or unknown, `cardinality` returns the
2895      named constants `tf.data.INFINITE_CARDINALITY` and
2896      `tf.data.UNKNOWN_CARDINALITY` respectively.
2897    """
2898    return gen_dataset_ops.dataset_cardinality(self._variant_tensor)
2899
2900  def group_by_window(self,
2901                      key_func,
2902                      reduce_func,
2903                      window_size=None,
2904                      window_size_func=None,
2905                      name=None):
2906    """Groups windows of elements by key and reduces them.
2907
2908    This transformation maps each consecutive element in a dataset to a key
2909    using `key_func` and groups the elements by key. It then applies
2910    `reduce_func` to at most `window_size_func(key)` elements matching the same
2911    key. All except the final window for each key will contain
2912    `window_size_func(key)` elements; the final window may be smaller.
2913
2914    You may provide either a constant `window_size` or a window size determined
2915    by the key through `window_size_func`.
2916
2917    >>> dataset = tf.data.Dataset.range(10)
2918    >>> window_size = 5
2919    >>> key_func = lambda x: x%2
2920    >>> reduce_func = lambda key, dataset: dataset.batch(window_size)
2921    >>> dataset = dataset.group_by_window(
2922    ...           key_func=key_func,
2923    ...           reduce_func=reduce_func,
2924    ...           window_size=window_size)
2925    >>> for elem in dataset.as_numpy_iterator():
2926    ...   print(elem)
2927    [0 2 4 6 8]
2928    [1 3 5 7 9]
2929
2930    Args:
2931      key_func: A function mapping a nested structure of tensors (having shapes
2932        and types defined by `self.output_shapes` and `self.output_types`) to a
2933        scalar `tf.int64` tensor.
2934      reduce_func: A function mapping a key and a dataset of up to `window_size`
2935        consecutive elements matching that key to another dataset.
2936      window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
2937        consecutive elements matching the same key to combine in a single batch,
2938        which will be passed to `reduce_func`. Mutually exclusive with
2939        `window_size_func`.
2940      window_size_func: A function mapping a key to a `tf.int64` scalar
2941        `tf.Tensor`, representing the number of consecutive elements matching
2942        the same key to combine in a single batch, which will be passed to
2943        `reduce_func`. Mutually exclusive with `window_size`.
2944      name: (Optional.) A name for the tf.data operation.
2945
2946    Returns:
2947      A `Dataset`.
2948
2949    Raises:
2950      ValueError: if neither or both of {`window_size`, `window_size_func`} are
2951        passed.
2952    """
2953    if (window_size is not None and window_size_func or
2954        not (window_size is not None or window_size_func)):
2955      raise ValueError("Either the `window_size` argument or the "
2956                       "`window_size_func` argument must be specified.")
2957
2958    if window_size is not None:
2959
2960      def constant_window_func(unused_key):
2961        return ops.convert_to_tensor(window_size, dtype=dtypes.int64)
2962
2963      window_size_func = constant_window_func
2964
2965    assert window_size_func is not None
2966
2967    return _GroupByWindowDataset(
2968        self, key_func, reduce_func, window_size_func, name=name)
2969
2970  def bucket_by_sequence_length(self,
2971                                element_length_func,
2972                                bucket_boundaries,
2973                                bucket_batch_sizes,
2974                                padded_shapes=None,
2975                                padding_values=None,
2976                                pad_to_bucket_boundary=False,
2977                                no_padding=False,
2978                                drop_remainder=False,
2979                                name=None):
2980    """A transformation that buckets elements in a `Dataset` by length.
2981
2982    Elements of the `Dataset` are grouped together by length and then are padded
2983    and batched.
2984
2985    This is useful for sequence tasks in which the elements have variable
2986    length. Grouping together elements that have similar lengths reduces the
2987    total fraction of padding in a batch which increases training step
2988    efficiency.
2989
2990    Below is an example to bucketize the input data to the 3 buckets
2991    "[0, 3), [3, 5), [5, inf)" based on sequence length, with batch size 2.
2992
2993    >>> elements = [
2994    ...   [0], [1, 2, 3, 4], [5, 6, 7],
2995    ...   [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]]
2996    >>> dataset = tf.data.Dataset.from_generator(
2997    ...     lambda: elements, tf.int64, output_shapes=[None])
2998    >>> dataset = dataset.bucket_by_sequence_length(
2999    ...         element_length_func=lambda elem: tf.shape(elem)[0],
3000    ...         bucket_boundaries=[3, 5],
3001    ...         bucket_batch_sizes=[2, 2, 2])
3002    >>> for elem in dataset.as_numpy_iterator():
3003    ...   print(elem)
3004    [[1 2 3 4]
3005    [5 6 7 0]]
3006    [[ 7  8  9 10 11  0]
3007    [13 14 15 16 19 20]]
3008    [[ 0  0]
3009    [21 22]]
3010
3011    Args:
3012      element_length_func: function from element in `Dataset` to `tf.int32`,
3013        determines the length of the element, which will determine the bucket it
3014        goes into.
3015      bucket_boundaries: `list<int>`, upper length boundaries of the buckets.
3016      bucket_batch_sizes: `list<int>`, batch size per bucket. Length should be
3017        `len(bucket_boundaries) + 1`.
3018      padded_shapes: Nested structure of `tf.TensorShape` to pass to
3019        `tf.data.Dataset.padded_batch`. If not provided, will use
3020        `dataset.output_shapes`, which will result in variable length dimensions
3021        being padded out to the maximum length in each batch.
3022      padding_values: Values to pad with, passed to
3023        `tf.data.Dataset.padded_batch`. Defaults to padding with 0.
3024      pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown
3025        size to maximum length in batch. If `True`, will pad dimensions with
3026        unknown size to bucket boundary minus 1 (i.e., the maximum length in
3027        each bucket), and caller must ensure that the source `Dataset` does not
3028        contain any elements with length longer than `max(bucket_boundaries)`.
3029      no_padding: `bool`, indicates whether to pad the batch features (features
3030        need to be either of type `tf.sparse.SparseTensor` or of same shape).
3031      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
3032        whether the last batch should be dropped in the case it has fewer than
3033        `batch_size` elements; the default behavior is not to drop the smaller
3034        batch.
3035      name: (Optional.) A name for the tf.data operation.
3036
3037    Returns:
3038      A `Dataset`.
3039
3040    Raises:
3041      ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`.
3042    """
3043    if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1):
3044      raise ValueError(
3045          f"`len(bucket_batch_sizes)` must equal `len(bucket_boundaries) + 1` "
3046          f"but `len(bucket_batch_sizes)={len(bucket_batch_sizes)}` and "
3047          f"`len(bucket_boundaries)={len(bucket_boundaries)}`.")
3048
3049    batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64)
3050
3051    def element_to_bucket_id(*args):
3052      """Return int64 id of the length bucket for this element."""
3053      seq_length = element_length_func(*args)
3054
3055      boundaries = list(bucket_boundaries)
3056      buckets_min = [np.iinfo(np.int32).min] + boundaries
3057      buckets_max = boundaries + [np.iinfo(np.int32).max]
3058      conditions_c = math_ops.logical_and(
3059          math_ops.less_equal(buckets_min, seq_length),
3060          math_ops.less(seq_length, buckets_max))
3061      bucket_id = math_ops.reduce_min(array_ops.where(conditions_c))
3062
3063      return bucket_id
3064
3065    def window_size_fn(bucket_id):
3066      # The window size is set to the batch size for this bucket
3067      window_size = batch_sizes[bucket_id]
3068      return window_size
3069
3070    def make_padded_shapes(shapes, none_filler=None):
3071      padded = []
3072      for shape in nest.flatten(shapes):
3073        shape = tensor_shape.TensorShape(shape)
3074        shape = [
3075            none_filler if tensor_shape.dimension_value(d) is None else d
3076            for d in shape
3077        ]
3078        padded.append(shape)
3079      return nest.pack_sequence_as(shapes, padded)
3080
3081    def batching_fn(bucket_id, grouped_dataset):
3082      """Batch elements in dataset."""
3083      batch_size = window_size_fn(bucket_id)
3084      if no_padding:
3085        return grouped_dataset.batch(
3086            batch_size, drop_remainder=drop_remainder, name=name)
3087      none_filler = None
3088      if pad_to_bucket_boundary:
3089        err_msg = ("When pad_to_bucket_boundary=True, elements must have "
3090                   "length < max(bucket_boundaries).")
3091        check = check_ops.assert_less(
3092            bucket_id,
3093            constant_op.constant(
3094                len(bucket_batch_sizes) - 1, dtype=dtypes.int64),
3095            message=err_msg)
3096        with ops.control_dependencies([check]):
3097          boundaries = constant_op.constant(
3098              bucket_boundaries, dtype=dtypes.int64)
3099          bucket_boundary = boundaries[bucket_id]
3100          none_filler = bucket_boundary - 1
3101      input_shapes = get_legacy_output_shapes(grouped_dataset)
3102      shapes = make_padded_shapes(
3103          padded_shapes or input_shapes, none_filler=none_filler)
3104      return grouped_dataset.padded_batch(
3105          batch_size,
3106          shapes,
3107          padding_values,
3108          drop_remainder=drop_remainder,
3109          name=name)
3110
3111    return self.group_by_window(
3112        key_func=element_to_bucket_id,
3113        reduce_func=batching_fn,
3114        window_size_func=window_size_fn,
3115        name=name)
3116
3117  @staticmethod
3118  def random(seed=None, name=None):
3119    """Creates a `Dataset` of pseudorandom values.
3120
3121    The dataset generates a sequence of uniformly distributed integer values.
3122
3123    >>> ds1 = tf.data.Dataset.random(seed=4).take(10)
3124    >>> ds2 = tf.data.Dataset.random(seed=4).take(10)
3125    >>> print(list(ds2.as_numpy_iterator())==list(ds2.as_numpy_iterator()))
3126    True
3127
3128    Args:
3129      seed: (Optional) If specified, the dataset produces a deterministic
3130        sequence of values.
3131      name: (Optional.) A name for the tf.data operation.
3132
3133    Returns:
3134      Dataset: A `Dataset`.
3135    """
3136    return RandomDataset(seed=seed, name=name)
3137
3138  def snapshot(self,
3139               path,
3140               compression="AUTO",
3141               reader_func=None,
3142               shard_func=None,
3143               name=None):
3144    """API to persist the output of the input dataset.
3145
3146    The snapshot API allows users to transparently persist the output of their
3147    preprocessing pipeline to disk, and materialize the pre-processed data on a
3148    different training run.
3149
3150    This API enables repeated preprocessing steps to be consolidated, and allows
3151    re-use of already processed data, trading off disk storage and network
3152    bandwidth for freeing up more valuable CPU resources and accelerator compute
3153    time.
3154
3155    https://github.com/tensorflow/community/blob/master/rfcs/20200107-tf-data-snapshot.md
3156    has detailed design documentation of this feature.
3157
3158    Users can specify various options to control the behavior of snapshot,
3159    including how snapshots are read from and written to by passing in
3160    user-defined functions to the `reader_func` and `shard_func` parameters.
3161
3162    `shard_func` is a user specified function that maps input elements to
3163    snapshot shards.
3164
3165    Users may want to specify this function to control how snapshot files should
3166    be written to disk. Below is an example of how a potential `shard_func`
3167    could be written.
3168
3169    ```python
3170    dataset = ...
3171    dataset = dataset.enumerate()
3172    dataset = dataset.snapshot("/path/to/snapshot/dir",
3173        shard_func=lambda x, y: x % NUM_SHARDS, ...)
3174    dataset = dataset.map(lambda x, y: y)
3175    ```
3176
3177    `reader_func` is a user specified function that accepts a single argument:
3178    (1) a Dataset of Datasets, each representing a "split" of elements of the
3179    original dataset. The cardinality of the input dataset matches the
3180    number of the shards specified in the `shard_func` (see above). The function
3181    should return a Dataset of elements of the original dataset.
3182
3183    Users may want specify this function to control how snapshot files should be
3184    read from disk, including the amount of shuffling and parallelism.
3185
3186    Here is an example of a standard reader function a user can define. This
3187    function enables both dataset shuffling and parallel reading of datasets:
3188
3189    ```python
3190    def user_reader_func(datasets):
3191      # shuffle the datasets splits
3192      datasets = datasets.shuffle(NUM_CORES)
3193      # read datasets in parallel and interleave their elements
3194      return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)
3195
3196    dataset = dataset.snapshot("/path/to/snapshot/dir",
3197        reader_func=user_reader_func)
3198    ```
3199
3200    By default, snapshot parallelizes reads by the number of cores available on
3201    the system, but will not attempt to shuffle the data.
3202
3203    Args:
3204      path: Required. A directory to use for storing / loading the snapshot to /
3205        from.
3206      compression: Optional. The type of compression to apply to the snapshot
3207        written to disk. Supported options are `GZIP`, `SNAPPY`, `AUTO` or None.
3208        Defaults to `AUTO`, which attempts to pick an appropriate compression
3209        algorithm for the dataset.
3210      reader_func: Optional. A function to control how to read data from
3211        snapshot shards.
3212      shard_func: Optional. A function to control how to shard data when writing
3213        a snapshot.
3214      name: (Optional.) A name for the tf.data operation.
3215
3216    Returns:
3217      A `Dataset`.
3218    """
3219
3220    project_func = None
3221    input_dataset = self
3222    if shard_func is None:
3223      input_dataset = input_dataset.enumerate(name=name)
3224      # This sets the amount of parallelism based on the number of CPU cores on
3225      # the machine where this Python code is executed, which may differ from
3226      # the number of CPU cores where the input pipeline graph is actually
3227      # executed (e.g. remote Cloud TPU workers).
3228      local_shard_func = lambda index, _: index % multiprocessing.cpu_count()
3229      project_func = lambda _, elem: elem
3230    else:
3231      local_shard_func = shard_func
3232    dataset = _SnapshotDataset(
3233        input_dataset=input_dataset,
3234        path=path,
3235        compression=compression,
3236        reader_func=reader_func,
3237        # This will not do the right thing where the graph is built on a
3238        # different machine than the executor (e.g. Cloud TPUs).
3239        shard_func=local_shard_func,
3240        name=name)
3241    if project_func is not None:
3242      dataset = dataset.map(project_func, name=name)
3243    return dataset
3244
3245  def scan(self, initial_state, scan_func, name=None):
3246    """A transformation that scans a function across an input dataset.
3247
3248    This transformation is a stateful relative of `tf.data.Dataset.map`.
3249    In addition to mapping `scan_func` across the elements of the input dataset,
3250    `scan()` accumulates one or more state tensors, whose initial values are
3251    `initial_state`.
3252
3253    >>> dataset = tf.data.Dataset.range(10)
3254    >>> initial_state = tf.constant(0, dtype=tf.int64)
3255    >>> scan_func = lambda state, i: (state + i, state + i)
3256    >>> dataset = dataset.scan(initial_state=initial_state, scan_func=scan_func)
3257    >>> list(dataset.as_numpy_iterator())
3258    [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]
3259
3260    Args:
3261      initial_state: A nested structure of tensors, representing the initial
3262        state of the accumulator.
3263      scan_func: A function that maps `(old_state, input_element)` to
3264        `(new_state, output_element)`. It must take two arguments and return a
3265        pair of nested structures of tensors. The `new_state` must match the
3266        structure of `initial_state`.
3267      name: (Optional.) A name for the tf.data operation.
3268
3269    Returns:
3270      A `Dataset`.
3271    """
3272
3273    return _ScanDataset(
3274        self, initial_state=initial_state, scan_func=scan_func, name=name)
3275
3276  def take_while(self, predicate, name=None):
3277    """A transformation that stops dataset iteration based on a `predicate`.
3278
3279    >>> dataset = tf.data.Dataset.range(10)
3280    >>> dataset = dataset.take_while(lambda x: x < 5)
3281    >>> list(dataset.as_numpy_iterator())
3282    [0, 1, 2, 3, 4]
3283
3284    Args:
3285      predicate: A function that maps a nested structure of tensors (having
3286        shapes and types defined by `self.output_shapes` and
3287        `self.output_types`) to a scalar `tf.bool` tensor.
3288      name: (Optional.) A name for the tf.data operation.
3289
3290    Returns:
3291      A `Dataset`.
3292    """
3293
3294    return _TakeWhileDataset(self, predicate, name=name)
3295
3296  def unique(self, name=None):
3297    """A transformation that discards duplicate elements of a `Dataset`.
3298
3299    Use this transformation to produce a dataset that contains one instance of
3300    each unique element in the input. For example:
3301
3302    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])
3303    >>> dataset = dataset.unique()
3304    >>> sorted(list(dataset.as_numpy_iterator()))
3305    [1, 2, 37]
3306
3307    Note: This transformation only supports datasets which fit into memory
3308    and have elements of either `tf.int32`, `tf.int64` or `tf.string` type.
3309
3310    Args:
3311      name: (Optional.) A name for the tf.data operation.
3312
3313    Returns:
3314      A `Dataset`.
3315    """
3316
3317    return _UniqueDataset(self, name=name)
3318
3319  def rejection_resample(self,
3320                         class_func,
3321                         target_dist,
3322                         initial_dist=None,
3323                         seed=None,
3324                         name=None):
3325    """A transformation that resamples a dataset to a target distribution.
3326
3327    Lets consider the following example where a dataset with an initial data
3328    distribution of `init_dist` needs to be resampled into a dataset with
3329    `target_dist` distribution.
3330
3331    >>> initial_dist = [0.6, 0.4]
3332    >>> num_classes = len(initial_dist)
3333    >>> num_samples = 1000
3334    >>> data_np = np.random.choice(num_classes, num_samples, p=initial_dist)
3335    >>> dataset = tf.data.Dataset.from_tensor_slices(data_np)
3336
3337    The value of `x` will be close to `{0: 50000, 1: 50000}` as per the
3338    `initial_dist` distribution.
3339
3340    >>> target_dist = [0.5, 0.5]
3341    >>> resampled_dataset = dataset.rejection_resample(
3342    ...    class_func=lambda x: x,
3343    ...    target_dist=target_dist,
3344    ...    initial_dist=initial_dist)
3345    >>> resampled_dataset = resampled_dataset.map(
3346    ...     lambda class_func_result, data: data)
3347
3348
3349    The value distribution of classes in the resampled_distribution will be now
3350    be close to the target distribution.
3351
3352    Args:
3353      class_func: A function mapping an element of the input dataset to a scalar
3354        `tf.int32` tensor. Values should be in `[0, num_classes)`.
3355      target_dist: A floating point type tensor, shaped `[num_classes]`.
3356      initial_dist: (Optional.)  A floating point type tensor, shaped
3357        `[num_classes]`.  If not provided, the true class distribution is
3358        estimated live in a streaming fashion.
3359      seed: (Optional.) Python integer seed for the resampler.
3360      name: (Optional.) A name for the tf.data operation.
3361
3362    Returns:
3363      A `Dataset`
3364    """
3365
3366    target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
3367    target_dist_t = math_ops.cast(target_dist_t, dtypes.float32)
3368
3369    # Get initial distribution.
3370    if initial_dist is not None:
3371      initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist")
3372      initial_dist_t = math_ops.cast(initial_dist_t, dtypes.float32)
3373      acceptance_dist, prob_of_original = (
3374          _calculate_acceptance_probs_with_mixing(initial_dist_t,
3375                                                  target_dist_t))
3376      initial_dist_ds = DatasetV2.from_tensors(
3377          initial_dist_t, name=name).repeat(name=name)
3378      acceptance_dist_ds = DatasetV2.from_tensors(
3379          acceptance_dist, name=name).repeat(name=name)
3380      prob_of_original_ds = DatasetV2.from_tensors(
3381          prob_of_original, name=name).repeat(name=name)
3382    else:
3383      initial_dist_ds = _estimate_initial_dist_ds(
3384          target_dist_t, self.map(class_func, name=name), name=name)
3385      acceptance_and_original_prob_ds = initial_dist_ds.map(
3386          lambda initial: _calculate_acceptance_probs_with_mixing(  # pylint: disable=g-long-lambda
3387              initial, target_dist_t),
3388          name=name)
3389      acceptance_dist_ds = acceptance_and_original_prob_ds.map(
3390          lambda accept_prob, _: accept_prob, name=name)
3391      prob_of_original_ds = acceptance_and_original_prob_ds.map(
3392          lambda _, prob_original: prob_original, name=name)
3393    filtered_ds = _filter_ds(self, acceptance_dist_ds, initial_dist_ds,
3394                             class_func, seed)
3395    # Prefetch filtered dataset for speed.
3396    filtered_ds = filtered_ds.prefetch(3, name=name)
3397
3398    prob_original_static = _get_prob_original_static(
3399        initial_dist_t, target_dist_t) if initial_dist is not None else None
3400
3401    def add_class_value(*x):
3402      if len(x) == 1:
3403        return class_func(*x), x[0]
3404      else:
3405        return class_func(*x), x
3406
3407    if prob_original_static == 1:
3408      return self.map(add_class_value, name=name)
3409    elif prob_original_static == 0:
3410      return filtered_ds
3411    else:
3412      return Dataset.sample_from_datasets(
3413          [self.map(add_class_value), filtered_ds],
3414          weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]),
3415          seed=seed,
3416          stop_on_empty_dataset=True)
3417
3418  @staticmethod
3419  def sample_from_datasets(datasets,
3420                           weights=None,
3421                           seed=None,
3422                           stop_on_empty_dataset=False):
3423    """Samples elements at random from the datasets in `datasets`.
3424
3425    Creates a dataset by interleaving elements of `datasets` with `weight[i]`
3426    probability of picking an element from dataset `i`. Sampling is done without
3427    replacement. For example, suppose we have 2 datasets:
3428
3429    ```python
3430    dataset1 = tf.data.Dataset.range(0, 3)
3431    dataset2 = tf.data.Dataset.range(100, 103)
3432    ```
3433
3434    Suppose that we sample from these 2 datasets with the following weights:
3435
3436    ```python
3437    sample_dataset = tf.data.Dataset.sample_from_datasets(
3438        [dataset1, dataset2], weights=[0.5, 0.5])
3439    ```
3440
3441    One possible outcome of elements in sample_dataset is:
3442
3443    ```
3444    print(list(sample_dataset.as_numpy_iterator()))
3445    # [100, 0, 1, 101, 2, 102]
3446    ```
3447
3448    Args:
3449      datasets: A non-empty list of `tf.data.Dataset` objects with compatible
3450        structure.
3451      weights: (Optional.) A list or Tensor of `len(datasets)` floating-point
3452        values where `weights[i]` represents the probability to sample from
3453        `datasets[i]`, or a `tf.data.Dataset` object where each element is such
3454        a list. Defaults to a uniform distribution across `datasets`.
3455      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
3456        seed that will be used to create the distribution. See
3457        `tf.random.set_seed` for behavior.
3458      stop_on_empty_dataset: If `True`, sampling stops if it encounters an empty
3459        dataset. If `False`, it continues sampling and skips any empty datasets.
3460        It is recommended to set it to `True`. Otherwise, the distribution of
3461        samples starts off as the user intends, but may change as input datasets
3462        become empty. This can be difficult to detect since the dataset starts
3463        off looking correct. Default to `False` for backward compatibility.
3464
3465    Returns:
3466      A dataset that interleaves elements from `datasets` at random, according
3467      to `weights` if provided, otherwise with uniform probability.
3468
3469    Raises:
3470      TypeError: If the `datasets` or `weights` arguments have the wrong type.
3471      ValueError:
3472        - If `datasets` is empty, or
3473        - If `weights` is specified and does not match the length of `datasets`.
3474    """
3475
3476    def _skip_datasets_with_zero_weight(datasets, weights):
3477      datasets_and_weights = [(dataset, weight)
3478                              for (dataset, weight) in zip(datasets, weights)
3479                              if weight > 0]
3480      return (zip(*datasets_and_weights) if datasets_and_weights else
3481              ([datasets[0].take(0)], [1.]))
3482
3483    if not datasets:
3484      raise ValueError("Invalid `datasets`. `datasets` should not be empty.")
3485
3486    if not isinstance(weights, DatasetV2):
3487      if weights is None:
3488        # Select inputs with uniform probability.
3489        logits = [[1.0] * len(datasets)]
3490
3491      else:
3492        if isinstance(weights, ops.Tensor):
3493          if not weights.shape.is_compatible_with([len(datasets)]):
3494            raise ValueError(f"Invalid `weights`. The shape of `weights` "
3495                             f"should be compatible with `[len(datasets)]` "
3496                             f"but is {weights.shape}.")
3497        else:
3498          if len(datasets) != len(weights):
3499            raise ValueError(f"Invalid `weights`. `weights` should have the "
3500                             f"same length as `datasets` but got "
3501                             f"`len(weights)={len(weights)}` vs. "
3502                             f"`len(datasets)={len(datasets)}`.")
3503
3504        # Use the given `weights` as the probability of choosing the respective
3505        # input.
3506        if not isinstance(weights, ops.Tensor):
3507          datasets, weights = _skip_datasets_with_zero_weight(datasets, weights)
3508        weights = ops.convert_to_tensor(weights, name="weights")
3509        if weights.dtype not in (dtypes.float32, dtypes.float64):
3510          raise TypeError(f"Invalid `weights`. `weights` type must be either "
3511                          f"`tf.float32` or `tf.float64` but is "
3512                          f"{weights.dtype}.")
3513
3514        # The `stateless_multinomial()` op expects log-probabilities, as opposed
3515        # to weights.
3516        logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)
3517
3518      # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When
3519      # it is a `Dataset`, it is possible that evaluating it has a side effect
3520      # the user depends on.
3521      if len(datasets) == 1:
3522        return datasets[0]
3523
3524      def select_dataset_constant_logits(seed):
3525        return array_ops.squeeze(
3526            gen_stateless_random_ops.stateless_multinomial(
3527                logits, 1, seed=seed),
3528            axis=[0, 1])
3529
3530      selector_input = MapDataset(
3531          RandomDataset(seed).batch(2),
3532          select_dataset_constant_logits,
3533          use_inter_op_parallelism=False)
3534
3535    else:
3536      # Use each element of the given `weights` dataset as the probability of
3537      # choosing the respective input.
3538      #
3539      # The `stateless_multinomial()` op expects log-probabilities, as opposed
3540      # to weights.
3541      logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))
3542
3543      def select_dataset_varying_logits(logits, seed):
3544        return array_ops.squeeze(
3545            gen_stateless_random_ops.stateless_multinomial(
3546                logits, 1, seed=seed),
3547            axis=[0, 1])
3548
3549      logits_and_seeds = Dataset.zip((logits_ds, RandomDataset(seed).batch(2)))
3550      selector_input = MapDataset(
3551          logits_and_seeds,
3552          select_dataset_varying_logits,
3553          use_inter_op_parallelism=False)
3554
3555    return _DirectedInterleaveDataset(selector_input, datasets,
3556                                      stop_on_empty_dataset)
3557
3558  @staticmethod
3559  def choose_from_datasets(datasets,
3560                           choice_dataset,
3561                           stop_on_empty_dataset=True):
3562    """Creates a dataset that deterministically chooses elements from `datasets`.
3563
3564    For example, given the following datasets:
3565
3566    ```python
3567    datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
3568                tf.data.Dataset.from_tensors("bar").repeat(),
3569                tf.data.Dataset.from_tensors("baz").repeat()]
3570
3571    # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
3572    choice_dataset = tf.data.Dataset.range(3).repeat(3)
3573
3574    result = tf.data.Dataset.choose_from_datasets(datasets, choice_dataset)
3575    ```
3576
3577    The elements of `result` will be:
3578
3579    ```
3580    "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
3581    ```
3582
3583    Args:
3584      datasets: A non-empty list of `tf.data.Dataset` objects with compatible
3585        structure.
3586      choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
3587        `0` and `len(datasets) - 1`.
3588      stop_on_empty_dataset: If `True`, selection stops if it encounters an
3589        empty dataset. If `False`, it skips empty datasets. It is recommended to
3590        set it to `True`. Otherwise, the selected elements start off as the user
3591        intends, but may change as input datasets become empty. This can be
3592        difficult to detect since the dataset starts off looking correct.
3593        Defaults to `True`.
3594
3595    Returns:
3596      A dataset that interleaves elements from `datasets` according to the
3597      values of `choice_dataset`.
3598
3599    Raises:
3600      TypeError: If `datasets` or `choice_dataset` has the wrong type.
3601      ValueError: If `datasets` is empty.
3602    """
3603    if not datasets:
3604      raise ValueError("Invalid `datasets`. `datasets` should not be empty.")
3605    if not isinstance(choice_dataset, DatasetV2):
3606      raise TypeError(f"Invalid `choice_dataset`. `choice_dataset` should be a "
3607                      f"`tf.data.Dataset` but is {type(choice_dataset)}.")
3608    if not structure.are_compatible(choice_dataset.element_spec,
3609                                    tensor_spec.TensorSpec([], dtypes.int64)):
3610      raise TypeError(f"Invalid `choice_dataset`. Elements of `choice_dataset` "
3611                      f"must be scalar `tf.int64` tensors but are "
3612                      f"{choice_dataset.element_spec}.")
3613    # Replicate the `choice_dataset` component so that each split makes choices
3614    # independently. This avoids the need for prohibitively expensive
3615    # cross-split coordination.
3616    choice_dataset = _apply_rewrite(choice_dataset, "replicate_on_split")
3617    # pylint: disable=protected-access
3618    return _DirectedInterleaveDataset(choice_dataset, datasets,
3619                                      stop_on_empty_dataset)
3620
3621
3622@tf_export(v1=["data.Dataset"])
3623class DatasetV1(DatasetV2):
3624  """Represents a potentially large set of elements.
3625
3626  A `Dataset` can be used to represent an input pipeline as a
3627  collection of elements and a "logical plan" of transformations that act on
3628  those elements.
3629  """
3630
3631  def __init__(self):
3632    try:
3633      variant_tensor = self._as_variant_tensor()
3634    except AttributeError as e:
3635      if "_as_variant_tensor" in str(e):
3636        raise AttributeError("Please use `_variant_tensor` instead of "
3637                             "`_as_variant_tensor()` to obtain the variant "
3638                             "associated with a dataset.")
3639      raise AttributeError("{}: A likely cause of this error is that the super "
3640                           "call for this dataset is not the last line of the "
3641                           "`__init__` method. The base class invokes the "
3642                           "`_as_variant_tensor()` method in its constructor "
3643                           "and if that method uses attributes defined in the "
3644                           "`__init__` method, those attributes need to be "
3645                           "defined before the super call.".format(e))
3646    super(DatasetV1, self).__init__(variant_tensor)
3647
3648  @abc.abstractmethod
3649  def _as_variant_tensor(self):
3650    """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset.
3651
3652    Returns:
3653      A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset.
3654    """
3655    raise NotImplementedError(f"{type(self)}.as_variant_tensor()")
3656
3657  @deprecation.deprecated(
3658      None, "This is a deprecated API that should only be used in TF 1 graph "
3659      "mode and legacy TF 2 graph mode available through `tf.compat.v1`. In "
3660      "all other situations -- namely, eager mode and inside `tf.function` -- "
3661      "you can consume dataset elements using `for elem in dataset: ...` or "
3662      "by explicitly creating iterator via `iterator = iter(dataset)` and "
3663      "fetching its elements via `values = next(iterator)`. Furthermore, "
3664      "this API is not available in TF 2. During the transition from TF 1 "
3665      "to TF 2 you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)` "
3666      "to create a TF 1 graph mode style iterator for a dataset created "
3667      "through TF 2 APIs. Note that this should be a transient state of your "
3668      "code base as there are in general no guarantees about the "
3669      "interoperability of TF 1 and TF 2 code.")
3670  def make_one_shot_iterator(self):
3671    """Creates an iterator for elements of this dataset.
3672
3673    Note: The returned iterator will be initialized automatically.
3674    A "one-shot" iterator does not currently support re-initialization. For
3675    that see `make_initializable_iterator`.
3676
3677    Example:
3678
3679    ```python
3680    # Building graph ...
3681    dataset = ...
3682    next_value = dataset.make_one_shot_iterator().get_next()
3683
3684    # ... from within a session ...
3685    try:
3686      while True:
3687        value = sess.run(next_value)
3688        ...
3689    except tf.errors.OutOfRangeError:
3690        pass
3691    ```
3692
3693    Returns:
3694      An `tf.data.Iterator` for elements of this dataset.
3695    """
3696    return self._make_one_shot_iterator()
3697
3698  def _make_one_shot_iterator(self):  # pylint: disable=missing-docstring
3699    if context.executing_eagerly():
3700      with ops.colocate_with(self._variant_tensor):
3701        return iterator_ops.OwnedIterator(self)
3702
3703    _ensure_same_dataset_graph(self)
3704    # Some ops (e.g. dataset ops) are marked as stateful but are stil safe to
3705    # to capture by value. We must allowlist these ops so that the capturing
3706    # logic captures the ops instead of raising an exception.
3707    allowlisted_stateful_ops = traverse.obtain_capture_by_value_ops(self)
3708    graph_level_seed, op_level_seed = core_random_seed.get_seed(None)
3709
3710    # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
3711    # a 0-argument function.
3712    @function.Defun(
3713        capture_by_value=True,
3714        allowlisted_stateful_ops=allowlisted_stateful_ops)
3715    def _make_dataset():
3716      """Factory function for a dataset."""
3717      # NOTE(mrry): `Defun` does not capture the graph-level seed from the
3718      # enclosing graph, so if a graph-level seed is present we set the local
3719      # graph seed based on a combination of the graph- and op-level seeds.
3720      if graph_level_seed is not None:
3721        assert op_level_seed is not None
3722        core_random_seed.set_random_seed(
3723            (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1))
3724
3725      dataset = self._apply_debug_options()
3726      return dataset._variant_tensor  # pylint: disable=protected-access
3727
3728    try:
3729      _make_dataset.add_to_graph(ops.get_default_graph())
3730    except ValueError as err:
3731      if "Cannot capture a stateful node" in str(err):
3732        raise ValueError(
3733            "{}: A likely cause of this error is that the dataset for which "
3734            "you are calling `make_one_shot_iterator()` captures a stateful "
3735            "object, such as a `tf.Variable` or `tf.lookup.StaticHashTable`, "
3736            "which is not supported. Use `make_initializable_iterator()` "
3737            "instead.".format(err)) from None
3738      else:
3739        raise
3740
3741    with ops.colocate_with(self._variant_tensor):
3742      # pylint: disable=protected-access
3743      return iterator_ops.Iterator(
3744          gen_dataset_ops.one_shot_iterator(
3745              dataset_factory=_make_dataset, **self._flat_structure), None,
3746          get_legacy_output_types(self), get_legacy_output_shapes(self),
3747          get_legacy_output_classes(self))
3748
3749  @deprecation.deprecated(
3750      None, "This is a deprecated API that should only be used in TF 1 graph "
3751      "mode and legacy TF 2 graph mode available through `tf.compat.v1`. "
3752      "In all other situations -- namely, eager mode and inside `tf.function` "
3753      "-- you can consume dataset elements using `for elem in dataset: ...` "
3754      "or by explicitly creating iterator via `iterator = iter(dataset)` "
3755      "and fetching its elements via `values = next(iterator)`. "
3756      "Furthermore, this API is not available in TF 2. During the transition "
3757      "from TF 1 to TF 2 you can use "
3758      "`tf.compat.v1.data.make_initializable_iterator(dataset)` to create a TF "
3759      "1 graph mode style iterator for a dataset created through TF 2 APIs. "
3760      "Note that this should be a transient state of your code base as there "
3761      "are in general no guarantees about the interoperability of TF 1 and TF "
3762      "2 code.")
3763  def make_initializable_iterator(self, shared_name=None):
3764    """Creates an iterator for elements of this dataset.
3765
3766    Note: The returned iterator will be in an uninitialized state,
3767    and you must run the `iterator.initializer` operation before using it:
3768
3769    ```python
3770    # Building graph ...
3771    dataset = ...
3772    iterator = dataset.make_initializable_iterator()
3773    next_value = iterator.get_next()  # This is a Tensor.
3774
3775    # ... from within a session ...
3776    sess.run(iterator.initializer)
3777    try:
3778      while True:
3779        value = sess.run(next_value)
3780        ...
3781    except tf.errors.OutOfRangeError:
3782        pass
3783    ```
3784
3785    Args:
3786      shared_name: (Optional.) If non-empty, the returned iterator will be
3787        shared under the given name across multiple sessions that share the same
3788        devices (e.g. when using a remote server).
3789
3790    Returns:
3791      A `tf.data.Iterator` for elements of this dataset.
3792
3793    Raises:
3794      RuntimeError: If eager execution is enabled.
3795    """
3796    return self._make_initializable_iterator(shared_name)
3797
3798  def _make_initializable_iterator(self, shared_name=None):  # pylint: disable=missing-docstring
3799    if context.executing_eagerly():
3800      raise RuntimeError("`make_initializable_iterator()` is not supported in "
3801                         "eager mode. Use Python-style iteration instead.")
3802    _ensure_same_dataset_graph(self)
3803    dataset = self._apply_debug_options()
3804    if shared_name is None:
3805      shared_name = ""
3806
3807    with ops.colocate_with(self._variant_tensor):
3808      iterator_resource = gen_dataset_ops.iterator_v2(
3809          container="", shared_name=shared_name, **self._flat_structure)
3810
3811      initializer = gen_dataset_ops.make_iterator(
3812          dataset._variant_tensor,  # pylint: disable=protected-access
3813          iterator_resource)
3814
3815      # pylint: disable=protected-access
3816      return iterator_ops.Iterator(iterator_resource, initializer,
3817                                   get_legacy_output_types(dataset),
3818                                   get_legacy_output_shapes(dataset),
3819                                   get_legacy_output_classes(dataset))
3820
3821  @property
3822  @deprecation.deprecated(
3823      None, "Use `tf.compat.v1.data.get_output_classes(dataset)`.")
3824  def output_classes(self):
3825    """Returns the class of each component of an element of this dataset.
3826
3827    Returns:
3828      A (nested) structure of Python `type` objects corresponding to each
3829      component of an element of this dataset.
3830    """
3831    return nest.map_structure(
3832        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
3833        self.element_spec)
3834
3835  @property
3836  @deprecation.deprecated(
3837      None, "Use `tf.compat.v1.data.get_output_shapes(dataset)`.")
3838  def output_shapes(self):
3839    """Returns the shape of each component of an element of this dataset.
3840
3841    Returns:
3842      A (nested) structure of `tf.TensorShape` objects corresponding to each
3843      component of an element of this dataset.
3844    """
3845    return nest.map_structure(
3846        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
3847        self.element_spec)
3848
3849  @property
3850  @deprecation.deprecated(
3851      None, "Use `tf.compat.v1.data.get_output_types(dataset)`.")
3852  def output_types(self):
3853    """Returns the type of each component of an element of this dataset.
3854
3855    Returns:
3856      A (nested) structure of `tf.DType` objects corresponding to each component
3857      of an element of this dataset.
3858    """
3859    return nest.map_structure(
3860        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
3861        self.element_spec)
3862
3863  @property
3864  def element_spec(self):
3865    # TODO(b/110122868): Remove this override once all `Dataset` instances
3866    # implement `element_structure`.
3867    return structure.convert_legacy_structure(
3868        self.output_types, self.output_shapes, self.output_classes)
3869
3870  @staticmethod
3871  @functools.wraps(DatasetV2.from_tensors)
3872  def from_tensors(tensors, name=None):
3873    return DatasetV1Adapter(DatasetV2.from_tensors(tensors, name=name))
3874
3875  @staticmethod
3876  @functools.wraps(DatasetV2.from_tensor_slices)
3877  def from_tensor_slices(tensors, name=None):
3878    return DatasetV1Adapter(DatasetV2.from_tensor_slices(tensors, name=name))
3879
3880  @staticmethod
3881  @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.")
3882  def from_sparse_tensor_slices(sparse_tensor):
3883    """Splits each rank-N `tf.sparse.SparseTensor` in this dataset row-wise.
3884
3885    Args:
3886      sparse_tensor: A `tf.sparse.SparseTensor`.
3887
3888    Returns:
3889      Dataset: A `Dataset` of rank-(N-1) sparse tensors.
3890    """
3891    return DatasetV1Adapter(SparseTensorSliceDataset(sparse_tensor))
3892
3893  @staticmethod
3894  @functools.wraps(DatasetV2.from_generator)
3895  @deprecation.deprecated_args(None, "Use output_signature instead",
3896                               "output_types", "output_shapes")
3897  def from_generator(generator,
3898                     output_types=None,
3899                     output_shapes=None,
3900                     args=None,
3901                     output_signature=None,
3902                     name=None):
3903    # Calling DatasetV2.from_generator with output_shapes or output_types is
3904    # deprecated, but this is already checked by the decorator on this function.
3905    with deprecation.silence():
3906      return DatasetV1Adapter(
3907          DatasetV2.from_generator(
3908              generator,
3909              output_types,
3910              output_shapes,
3911              args,
3912              output_signature,
3913              name=name))
3914
3915  @staticmethod
3916  @functools.wraps(DatasetV2.range)
3917  def range(*args, **kwargs):
3918    return DatasetV1Adapter(DatasetV2.range(*args, **kwargs))
3919
3920  @staticmethod
3921  @functools.wraps(DatasetV2.zip)
3922  def zip(datasets, name=None):
3923    return DatasetV1Adapter(DatasetV2.zip(datasets, name=name))
3924
3925  @functools.wraps(DatasetV2.concatenate)
3926  def concatenate(self, dataset, name=None):
3927    return DatasetV1Adapter(
3928        super(DatasetV1, self).concatenate(dataset, name=name))
3929
3930  @functools.wraps(DatasetV2.prefetch)
3931  def prefetch(self, buffer_size, name=None):
3932    return DatasetV1Adapter(
3933        super(DatasetV1, self).prefetch(buffer_size, name=name))
3934
3935  @staticmethod
3936  @functools.wraps(DatasetV2.list_files)
3937  def list_files(file_pattern, shuffle=None, seed=None, name=None):
3938    return DatasetV1Adapter(
3939        DatasetV2.list_files(file_pattern, shuffle, seed, name=name))
3940
3941  @functools.wraps(DatasetV2.repeat)
3942  def repeat(self, count=None, name=None):
3943    return DatasetV1Adapter(super(DatasetV1, self).repeat(count, name=name))
3944
3945  @functools.wraps(DatasetV2.shuffle)
3946  def shuffle(self,
3947              buffer_size,
3948              seed=None,
3949              reshuffle_each_iteration=None,
3950              name=None):
3951    return DatasetV1Adapter(
3952        super(DatasetV1, self).shuffle(
3953            buffer_size, seed, reshuffle_each_iteration, name=name))
3954
3955  @functools.wraps(DatasetV2.cache)
3956  def cache(self, filename="", name=None):
3957    return DatasetV1Adapter(super(DatasetV1, self).cache(filename, name=name))
3958
3959  @functools.wraps(DatasetV2.take)
3960  def take(self, count, name=None):
3961    return DatasetV1Adapter(super(DatasetV1, self).take(count, name=name))
3962
3963  @functools.wraps(DatasetV2.skip)
3964  def skip(self, count, name=None):
3965    return DatasetV1Adapter(super(DatasetV1, self).skip(count, name=name))
3966
3967  @functools.wraps(DatasetV2.shard)
3968  def shard(self, num_shards, index, name=None):
3969    return DatasetV1Adapter(
3970        super(DatasetV1, self).shard(num_shards, index, name=name))
3971
3972  @functools.wraps(DatasetV2.batch)
3973  def batch(self,
3974            batch_size,
3975            drop_remainder=False,
3976            num_parallel_calls=None,
3977            deterministic=None,
3978            name=None):
3979    return DatasetV1Adapter(
3980        super(DatasetV1, self).batch(
3981            batch_size,
3982            drop_remainder,
3983            num_parallel_calls,
3984            deterministic,
3985            name=name))
3986
3987  @functools.wraps(DatasetV2.padded_batch)
3988  def padded_batch(self,
3989                   batch_size,
3990                   padded_shapes=None,
3991                   padding_values=None,
3992                   drop_remainder=False,
3993                   name=None):
3994    return DatasetV1Adapter(
3995        super(DatasetV1, self).padded_batch(
3996            batch_size,
3997            padded_shapes,
3998            padding_values,
3999            drop_remainder,
4000            name=name))
4001
4002  @functools.wraps(DatasetV2.map)
4003  def map(self,
4004          map_func,
4005          num_parallel_calls=None,
4006          deterministic=None,
4007          name=None):
4008    if num_parallel_calls is None or DEBUG_MODE:
4009      return DatasetV1Adapter(
4010          MapDataset(self, map_func, preserve_cardinality=False))
4011    else:
4012      return DatasetV1Adapter(
4013          ParallelMapDataset(
4014              self,
4015              map_func,
4016              num_parallel_calls,
4017              deterministic,
4018              preserve_cardinality=False))
4019
4020  @deprecation.deprecated(None, "Use `tf.data.Dataset.map()")
4021  def map_with_legacy_function(self,
4022                               map_func,
4023                               num_parallel_calls=None,
4024                               deterministic=None):
4025    """Maps `map_func` across the elements of this dataset.
4026
4027    Note: This is an escape hatch for existing uses of `map` that do not work
4028    with V2 functions. New uses are strongly discouraged and existing uses
4029    should migrate to `map` as this method will be removed in V2.
4030
4031    Args:
4032      map_func: A function mapping a (nested) structure of tensors (having
4033        shapes and types defined by `self.output_shapes` and
4034        `self.output_types`) to another (nested) structure of tensors.
4035      num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
4036        representing the number elements to process asynchronously in parallel.
4037        If not specified, elements will be processed sequentially. If the value
4038        `tf.data.AUTOTUNE` is used, then the number of parallel calls is set
4039        dynamically based on available CPU.
4040      deterministic: (Optional.) When `num_parallel_calls` is specified, this
4041        boolean controls the order in which the transformation produces
4042        elements. If set to `False`, the transformation is allowed to yield
4043        elements out of order to trade determinism for performance. If not
4044        specified, the `tf.data.Options.deterministic` option (`True` by
4045        default) controls the behavior.
4046
4047    Returns:
4048      Dataset: A `Dataset`.
4049    """
4050    if num_parallel_calls is None:
4051      if deterministic is not None:
4052        warnings.warn("The `deterministic` argument has no effect unless the "
4053                      "`num_parallel_calls` argument is specified.")
4054      return DatasetV1Adapter(
4055          MapDataset(
4056              self,
4057              map_func,
4058              preserve_cardinality=False,
4059              use_legacy_function=True))
4060    else:
4061      return DatasetV1Adapter(
4062          ParallelMapDataset(
4063              self,
4064              map_func,
4065              num_parallel_calls,
4066              deterministic,
4067              preserve_cardinality=False,
4068              use_legacy_function=True))
4069
4070  @functools.wraps(DatasetV2.flat_map)
4071  def flat_map(self, map_func, name=None):
4072    return DatasetV1Adapter(
4073        super(DatasetV1, self).flat_map(map_func, name=name))
4074
4075  @functools.wraps(DatasetV2.interleave)
4076  def interleave(self,
4077                 map_func,
4078                 cycle_length=None,
4079                 block_length=None,
4080                 num_parallel_calls=None,
4081                 deterministic=None,
4082                 name=None):
4083    return DatasetV1Adapter(
4084        super(DatasetV1, self).interleave(
4085            map_func,
4086            cycle_length,
4087            block_length,
4088            num_parallel_calls,
4089            deterministic,
4090            name=name))
4091
4092  @functools.wraps(DatasetV2.filter)
4093  def filter(self, predicate, name=None):
4094    return DatasetV1Adapter(super(DatasetV1, self).filter(predicate, name=name))
4095
4096  @deprecation.deprecated(None, "Use `tf.data.Dataset.filter()")
4097  def filter_with_legacy_function(self, predicate):
4098    """Filters this dataset according to `predicate`.
4099
4100    Note: This is an escape hatch for existing uses of `filter` that do not work
4101    with V2 functions. New uses are strongly discouraged and existing uses
4102    should migrate to `filter` as this method will be removed in V2.
4103
4104    Args:
4105      predicate: A function mapping a (nested) structure of tensors (having
4106        shapes and types defined by `self.output_shapes` and
4107        `self.output_types`) to a scalar `tf.bool` tensor.
4108
4109    Returns:
4110      Dataset: The `Dataset` containing the elements of this dataset for which
4111          `predicate` is `True`.
4112    """
4113    return FilterDataset(self, predicate, use_legacy_function=True)
4114
4115  @functools.wraps(DatasetV2.apply)
4116  def apply(self, transformation_func):
4117    return DatasetV1Adapter(super(DatasetV1, self).apply(transformation_func))
4118
4119  @functools.wraps(DatasetV2.window)
4120  def window(self, size, shift=None, stride=1, drop_remainder=False, name=None):
4121    return DatasetV1Adapter(
4122        super(DatasetV1,
4123              self).window(size, shift, stride, drop_remainder, name=name))
4124
4125  @functools.wraps(DatasetV2.unbatch)
4126  def unbatch(self, name=None):
4127    return DatasetV1Adapter(super(DatasetV1, self).unbatch(name=name))
4128
4129  @functools.wraps(DatasetV2.with_options)
4130  def with_options(self, options, name=None):
4131    return DatasetV1Adapter(
4132        super(DatasetV1, self).with_options(options, name=name))
4133
4134
4135if tf2.enabled():
4136  Dataset = DatasetV2
4137else:
4138  Dataset = DatasetV1
4139
4140
4141class DatasetV1Adapter(DatasetV1):
4142  """Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API."""
4143
4144  def __init__(self, dataset):
4145    self._dataset = dataset
4146    super(DatasetV1Adapter, self).__init__()
4147
4148  def _as_variant_tensor(self):
4149    return self._dataset._variant_tensor  # pylint: disable=protected-access
4150
4151  def _inputs(self):
4152    return self._dataset._inputs()  # pylint: disable=protected-access
4153
4154  def _functions(self):
4155    return self._dataset._functions()  # pylint: disable=protected-access
4156
4157  def options(self):
4158    return self._dataset.options()
4159
4160  @property
4161  def element_spec(self):
4162    return self._dataset.element_spec  # pylint: disable=protected-access
4163
4164  def __iter__(self):
4165    return iter(self._dataset)
4166
4167
4168def _ensure_same_dataset_graph(dataset):
4169  """Walks the dataset graph to ensure all datasets come from the same graph."""
4170  # pylint: disable=protected-access
4171  current_graph = ops.get_default_graph()
4172  bfs_q = queue.Queue()
4173  bfs_q.put(dataset)
4174  visited = []
4175  while not bfs_q.empty():
4176    ds = bfs_q.get()
4177    visited.append(ds)
4178    ds_graph = ds._graph
4179    if current_graph != ds_graph:
4180      raise ValueError(
4181          f"The graph {current_graph} of the iterator is different from the "
4182          f"graph {ds_graph} the dataset: {ds._variant_tensor} was created in. "
4183          f"If you are using the Estimator API, make sure that no part of the "
4184          f"dataset returned by the `input_fn` function is defined outside the "
4185          f"`input_fn` function. Otherwise, make sure that the dataset is "
4186          f"created in the same graph as the iterator.")
4187    for input_ds in ds._inputs():
4188      if input_ds not in visited:
4189        bfs_q.put(input_ds)
4190
4191
4192@tf_export(v1=["data.make_one_shot_iterator"])
4193def make_one_shot_iterator(dataset):
4194  """Creates an iterator for elements of `dataset`.
4195
4196  Note: The returned iterator will be initialized automatically.
4197  A "one-shot" iterator does not support re-initialization.
4198
4199  Args:
4200    dataset: A `tf.data.Dataset`.
4201
4202  Returns:
4203    A `tf.data.Iterator` for elements of `dataset`.
4204
4205  @compatibility(TF2)
4206  This is a legacy API for consuming dataset elements and should only be used
4207  during transition from TF 1 to TF 2. Note that using this API should be
4208  a transient state of your code base as there are in general no guarantees
4209  about the interoperability of TF 1 and TF 2 code.
4210
4211  In TF 2 datasets are Python iterables which means you can consume their
4212  elements using `for elem in dataset: ...` or by explicitly creating iterator
4213  via `iterator = iter(dataset)` and fetching its elements via
4214  `values = next(iterator)`.
4215  @end_compatibility
4216  """
4217  try:
4218    # Call the defined `_make_one_shot_iterator()` if there is one, because some
4219    # datasets (e.g. for prefetching) override its behavior.
4220    return dataset._make_one_shot_iterator()  # pylint: disable=protected-access
4221  except AttributeError:
4222    return DatasetV1Adapter(dataset)._make_one_shot_iterator()  # pylint: disable=protected-access
4223
4224
4225@tf_export(v1=["data.make_initializable_iterator"])
4226def make_initializable_iterator(dataset, shared_name=None):
4227  """Creates an iterator for elements of `dataset`.
4228
4229  Note: The returned iterator will be in an uninitialized state,
4230  and you must run the `iterator.initializer` operation before using it:
4231
4232  ```python
4233  dataset = ...
4234  iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
4235  # ...
4236  sess.run(iterator.initializer)
4237  ```
4238
4239  Args:
4240    dataset: A `tf.data.Dataset`.
4241    shared_name: (Optional.) If non-empty, the returned iterator will be shared
4242      under the given name across multiple sessions that share the same devices
4243      (e.g. when using a remote server).
4244
4245  Returns:
4246    A `tf.data.Iterator` for elements of `dataset`.
4247
4248  Raises:
4249    RuntimeError: If eager execution is enabled.
4250
4251  @compatibility(TF2)
4252  This is a legacy API for consuming dataset elements and should only be used
4253  during transition from TF 1 to TF 2. Note that using this API should be
4254  a transient state of your code base as there are in general no guarantees
4255  about the interoperability of TF 1 and TF 2 code.
4256
4257  In TF 2 datasets are Python iterables which means you can consume their
4258  elements using `for elem in dataset: ...` or by explicitly creating iterator
4259  via `iterator = iter(dataset)` and fetching its elements via
4260  `values = next(iterator)`.
4261  @end_compatibility
4262  """
4263  try:
4264    # Call the defined `_make_initializable_iterator()` if there is one, because
4265    # some datasets (e.g. for prefetching) override its behavior.
4266    return dataset._make_initializable_iterator(shared_name)  # pylint: disable=protected-access
4267  except AttributeError:
4268    return DatasetV1Adapter(dataset)._make_initializable_iterator(shared_name)  # pylint: disable=protected-access
4269
4270
4271@tf_export("data.experimental.get_structure")
4272def get_structure(dataset_or_iterator):
4273  """Returns the type signature for elements of the input dataset / iterator.
4274
4275  For example, to get the structure of a `tf.data.Dataset`:
4276
4277  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
4278  >>> tf.data.experimental.get_structure(dataset)
4279  TensorSpec(shape=(), dtype=tf.int32, name=None)
4280
4281  >>> dataset = tf.data.experimental.from_list([(1, 'a'), (2, 'b'), (3, 'c')])
4282  >>> tf.data.experimental.get_structure(dataset)
4283  (TensorSpec(shape=(), dtype=tf.int32, name=None),
4284   TensorSpec(shape=(), dtype=tf.string, name=None))
4285
4286  To get the structure of an `tf.data.Iterator`:
4287
4288  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
4289  >>> tf.data.experimental.get_structure(iter(dataset))
4290  TensorSpec(shape=(), dtype=tf.int32, name=None)
4291
4292  Args:
4293    dataset_or_iterator: A `tf.data.Dataset` or an `tf.data.Iterator`.
4294
4295  Returns:
4296    A (nested) structure of `tf.TypeSpec` objects matching the structure of an
4297    element of `dataset_or_iterator` and specifying the type of individual
4298    components.
4299
4300  Raises:
4301    TypeError: If input is not a `tf.data.Dataset` or an `tf.data.Iterator`
4302      object.
4303  """
4304  try:
4305    return dataset_or_iterator.element_spec  # pylint: disable=protected-access
4306  except AttributeError:
4307    raise TypeError(f"Invalid `dataset_or_iterator`. `dataset_or_iterator` "
4308                    f"must be a `tf.data.Dataset` or tf.data.Iterator object, "
4309                    f"but got {type(dataset_or_iterator)}.")
4310
4311
4312@tf_export(v1=["data.get_output_classes"])
4313def get_legacy_output_classes(dataset_or_iterator):
4314  """Returns the output classes for elements of the input dataset / iterator.
4315
4316  Args:
4317    dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
4318
4319  Returns:
4320    A (nested) structure of Python `type` objects matching the structure of the
4321    dataset / iterator elements and specifying the class of the individual
4322    components.
4323
4324  @compatibility(TF2)
4325  This is a legacy API for inspecting the type signature of dataset elements. In
4326  TF 2, you should use the `tf.data.Dataset.element_spec` attribute instead.
4327  @end_compatibility
4328  """
4329  return nest.map_structure(
4330      lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
4331      get_structure(dataset_or_iterator))
4332
4333
4334@tf_export(v1=["data.get_output_shapes"])
4335def get_legacy_output_shapes(dataset_or_iterator):
4336  """Returns the output shapes for elements of the input dataset / iterator.
4337
4338  Args:
4339    dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
4340
4341  Returns:
4342    A (nested) structure of `tf.TensorShape` objects matching the structure of
4343    the dataset / iterator elements and specifying the shape of the individual
4344    components.
4345
4346  @compatibility(TF2)
4347  This is a legacy API for inspecting the type signature of dataset elements. In
4348  TF 2, you should use the `tf.data.Dataset.element_spec` attribute instead.
4349  @end_compatibility
4350  """
4351  return nest.map_structure(
4352      lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
4353      get_structure(dataset_or_iterator))
4354
4355
4356@tf_export(v1=["data.get_output_types"])
4357def get_legacy_output_types(dataset_or_iterator):
4358  """Returns the output shapes for elements of the input dataset / iterator.
4359
4360  Args:
4361    dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
4362
4363  Returns:
4364    A (nested) structure of `tf.DType` objects matching the structure of
4365    dataset / iterator elements and specifying the shape of the individual
4366    components.
4367
4368  @compatibility(TF2)
4369  This is a legacy API for inspecting the type signature of dataset elements. In
4370  TF 2, you should use the `tf.data.Dataset.element_spec` attribute instead.
4371  @end_compatibility
4372  """
4373  return nest.map_structure(
4374      lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
4375      get_structure(dataset_or_iterator))
4376
4377
4378class DatasetSource(DatasetV2):
4379  """Abstract class representing a dataset with no inputs."""
4380
4381  def _inputs(self):
4382    return []
4383
4384
4385class UnaryDataset(DatasetV2):
4386  """Abstract class representing a dataset with one input."""
4387
4388  def __init__(self, input_dataset, variant_tensor):
4389    self._input_dataset = input_dataset
4390    super(UnaryDataset, self).__init__(variant_tensor)
4391
4392  def _inputs(self):
4393    return [self._input_dataset]
4394
4395
4396class UnaryUnchangedStructureDataset(UnaryDataset):
4397  """Represents a unary dataset with the same input and output structure."""
4398
4399  def __init__(self, input_dataset, variant_tensor):
4400    self._input_dataset = input_dataset
4401    super(UnaryUnchangedStructureDataset, self).__init__(
4402        input_dataset, variant_tensor)
4403
4404  @property
4405  def element_spec(self):
4406    return self._input_dataset.element_spec
4407
4408
4409class _VariantDataset(DatasetV2):
4410  """A Dataset wrapper around a `tf.variant`-typed function argument."""
4411
4412  def __init__(self, dataset_variant, element_spec):
4413    self._element_spec = element_spec
4414    super(_VariantDataset, self).__init__(dataset_variant)
4415
4416  def _inputs(self):
4417    return []
4418
4419  @property
4420  def element_spec(self):
4421    return self._element_spec
4422
4423
4424class _NestedVariant(composite_tensor.CompositeTensor):
4425
4426  def __init__(self, variant_tensor, element_spec, dataset_shape):
4427    self._variant_tensor = variant_tensor
4428    self._element_spec = element_spec
4429    self._dataset_shape = dataset_shape
4430
4431  @property
4432  def _type_spec(self):
4433    return DatasetSpec(self._element_spec, self._dataset_shape)
4434
4435
4436@tf_export("data.experimental.from_variant")
4437def from_variant(variant, structure):
4438  """Constructs a dataset from the given variant and (nested) structure.
4439
4440  Args:
4441    variant: A scalar `tf.variant` tensor representing a dataset.
4442    structure: A (nested) structure of `tf.TypeSpec` objects representing the
4443      structure of each element in the dataset.
4444
4445  Returns:
4446    A `tf.data.Dataset` instance.
4447  """
4448  return _VariantDataset(variant, structure)  # pylint: disable=protected-access
4449
4450
4451@tf_export("data.experimental.to_variant")
4452def to_variant(dataset):
4453  """Returns a variant representing the given dataset.
4454
4455  Args:
4456    dataset: A `tf.data.Dataset`.
4457
4458  Returns:
4459    A scalar `tf.variant` tensor representing the given dataset.
4460  """
4461  return dataset._variant_tensor  # pylint: disable=protected-access
4462
4463
4464@tf_export(
4465    "data.DatasetSpec",
4466    v1=["data.DatasetSpec", "data.experimental.DatasetStructure"])
4467class DatasetSpec(type_spec.BatchableTypeSpec):
4468  """Type specification for `tf.data.Dataset`.
4469
4470  See `tf.TypeSpec` for more information about TensorFlow type specifications.
4471
4472  >>> dataset = tf.data.Dataset.range(3)
4473  >>> tf.data.DatasetSpec.from_value(dataset)
4474  DatasetSpec(TensorSpec(shape=(), dtype=tf.int64, name=None), TensorShape([]))
4475  """
4476
4477  __slots__ = ["_element_spec", "_dataset_shape"]
4478
4479  def __init__(self, element_spec, dataset_shape=()):
4480    self._element_spec = element_spec
4481    self._dataset_shape = tensor_shape.as_shape(dataset_shape)
4482
4483  @property
4484  def value_type(self):
4485    return Dataset
4486
4487  @property
4488  def element_spec(self):
4489    """The inner element spec."""
4490    return self._element_spec
4491
4492  def is_subtype_of(self, other):
4493    """See base class."""
4494    if type(self) is not type(other):
4495      return False
4496
4497    # TODO(b/220385675): _element_spec should always be a TypeSpec.
4498    try:
4499      tf_nest.assert_same_structure(self.element_spec, other.element_spec)
4500    except (TypeError, ValueError):
4501      return False
4502
4503    self_elements = tf_nest.flatten(self.element_spec)
4504    other_elements = tf_nest.flatten(other.element_spec)
4505
4506    def is_subtype_or_equal(a, b):
4507      if isinstance(a, trace.TraceType):
4508        return a.is_subtype_of(b)
4509      else:
4510        return a == b
4511
4512    for self_element, other_element in zip(self_elements, other_elements):
4513      if not is_subtype_or_equal(self_element, other_element):
4514        return False
4515
4516    return self._dataset_shape.is_subtype_of(other._dataset_shape)  # pylint: disable=protected-access
4517
4518  def most_specific_common_supertype(self, others):
4519    """See base class."""
4520    if not all(type(self) is type(other) for other in others):
4521      return None
4522
4523    try:
4524      for other in others:
4525        tf_nest.assert_same_structure(self.element_spec, other.element_spec)
4526    except (TypeError, ValueError):
4527      return None
4528
4529    self_components = tf_nest.flatten(self.element_spec)
4530    others_components = [
4531        tf_nest.flatten(other.element_spec) for other in others
4532    ]
4533    common_components = [None] * len(self_components)
4534
4535    def common_supertype_or_equal(a, bs):
4536      if isinstance(a, trace.TraceType):
4537        return a.most_specific_common_supertype(bs)
4538      else:
4539        return a if all(a == b for b in bs) else None
4540
4541    for i, self_component in enumerate(self_components):
4542      common_components[i] = common_supertype_or_equal(
4543          self_component,
4544          [other_components[i] for other_components in others_components])
4545      if self_component is not None and common_components[i] is None:
4546        return None
4547    common_element_spec = tf_nest.pack_sequence_as(self._element_spec,
4548                                                   common_components)
4549
4550    common_dataset_shape = self._dataset_shape.most_specific_common_supertype(
4551        [other._dataset_shape for other in others])  # pylint: disable=protected-access
4552    if common_dataset_shape is None:
4553      return None
4554
4555    return DatasetSpec(common_element_spec, common_dataset_shape)
4556
4557  # TODO(b/220385675): Once _element_spec is guaranteed to be TypeSpec, the
4558  # following functions do not need to be overloaded: is_subtype_of,
4559  # most_specific_common_supertype, __hash__ and __eq__
4560  def _serialize(self):
4561    return (self._element_spec, self._dataset_shape)
4562
4563  @property
4564  def _component_specs(self):
4565    return tensor_spec.TensorSpec(self._dataset_shape, dtypes.variant)
4566
4567  def _to_components(self, value):
4568    return value._variant_tensor  # pylint: disable=protected-access
4569
4570  def _from_components(self, components):
4571    # pylint: disable=protected-access
4572    if self._dataset_shape.ndims == 0:
4573      return _VariantDataset(components, self._element_spec)
4574    else:
4575      return _NestedVariant(components, self._element_spec, self._dataset_shape)
4576
4577  def _to_tensor_list(self, value):
4578    return [
4579        ops.convert_to_tensor(
4580            tf_nest.map_structure(lambda x: x._variant_tensor, value))  # pylint: disable=protected-access
4581    ]
4582
4583  @staticmethod
4584  def from_value(value):
4585    """Creates a `DatasetSpec` for the given `tf.data.Dataset` value."""
4586    return DatasetSpec(value.element_spec)  # pylint: disable=protected-access
4587
4588  def _batch(self, batch_size):
4589    return DatasetSpec(
4590        self._element_spec,
4591        tensor_shape.TensorShape([batch_size]).concatenate(self._dataset_shape))
4592
4593  def _unbatch(self):
4594    if self._dataset_shape.ndims == 0:
4595      raise ValueError("Slicing dataset elements is not supported for rank 0.")
4596    return DatasetSpec(self._element_spec, self._dataset_shape[1:])
4597
4598  def _to_batched_tensor_list(self, value):
4599    if self._dataset_shape.ndims == 0:
4600      raise ValueError("Slicing dataset elements is not supported for rank 0.")
4601    return self._to_tensor_list(value)
4602
4603  def _to_legacy_output_types(self):
4604    return self
4605
4606  def _to_legacy_output_shapes(self):
4607    return self
4608
4609  def _to_legacy_output_classes(self):
4610    return self
4611
4612  def __hash__(self):
4613    # TODO(b/220385675): attributes can be dicts and hence unhashable.
4614    return hash(DatasetSpec)
4615
4616  def __eq__(self, other):
4617    return (isinstance(other, DatasetSpec) and
4618            self._element_spec == other._element_spec and
4619            self._dataset_shape == other._dataset_shape)
4620
4621
4622class _NumpyIterator:
4623  """Iterator over a dataset with elements converted to numpy."""
4624
4625  __slots__ = ["_iterator"]
4626
4627  def __init__(self, dataset):
4628    self._iterator = iter(dataset)
4629
4630  def __iter__(self):
4631    return self
4632
4633  def __next__(self):
4634
4635    def to_numpy(x):
4636      numpy = x._numpy()  # pylint: disable=protected-access
4637      if isinstance(numpy, np.ndarray):
4638        # `numpy` shares the same underlying buffer as the `x` Tensor.
4639        # Tensors are expected to be immutable, so we disable writes.
4640        numpy.setflags(write=False)
4641      return numpy
4642
4643    return nest.map_structure(to_numpy, next(self._iterator))
4644
4645  def next(self):
4646    return self.__next__()
4647
4648
4649class _VariantTracker(resource_lib.CapturableResource):
4650  """Allows export of functions capturing a Dataset in SavedModels.
4651
4652  When saving a SavedModel, `tf.saved_model.save` traverses the object
4653  graph. Since Datasets reference _VariantTracker objects, that traversal will
4654  find a _VariantTracker for each Dataset and so know how to save and restore
4655  functions which reference the Dataset's variant Tensor.
4656  """
4657
4658  def __init__(self, variant_tensor, resource_creator):
4659    """Record that `variant_tensor` is associated with `resource_creator`.
4660
4661    Args:
4662      variant_tensor: The variant-dtype Tensor associated with the Dataset. This
4663        Tensor will be a captured input to functions which use the Dataset, and
4664        is used by saving code to identify the corresponding _VariantTracker.
4665      resource_creator: A zero-argument function which creates a new
4666        variant-dtype Tensor. This function will be included in SavedModels and
4667        run to re-create the Dataset's variant Tensor on restore.
4668    """
4669    super(_VariantTracker, self).__init__(device="CPU")
4670    self._resource_handle = variant_tensor
4671    if not isinstance(resource_creator, def_function.Function):
4672      # Internal validation -- _VariantTracker assumes that resource creator is
4673      # already a tf.function.
4674      raise TypeError("Resource creator should already be a tf.function.")
4675    self._create_resource = resource_creator
4676
4677  def _trackable_children(self,
4678                          save_type=tracking_base.SaveType.CHECKPOINT,
4679                          **kwargs):
4680    if save_type != tracking_base.SaveType.SAVEDMODEL:
4681      return {}
4682
4683    children = super(_VariantTracker,
4684                     self)._trackable_children(save_type, **kwargs)
4685    # Overwrite the _create_resource function, since `self._create_resource`
4686    # is already a tf.function.
4687    children["_create_resource"] = self._create_resource
4688    return children
4689
4690
4691class TensorDataset(DatasetSource):
4692  """A `Dataset` with a single element."""
4693
4694  def __init__(self, element, name=None):
4695    """See `Dataset.from_tensors()` for details."""
4696    element = structure.normalize_element(element)
4697    self._structure = structure.type_spec_from_value(element)
4698    self._tensors = structure.to_tensor_list(self._structure, element)
4699    self._name = name
4700    variant_tensor = gen_dataset_ops.tensor_dataset(
4701        self._tensors,
4702        output_shapes=structure.get_flat_tensor_shapes(self._structure),
4703        metadata=self._metadata.SerializeToString())
4704    super(TensorDataset, self).__init__(variant_tensor)
4705
4706  @property
4707  def element_spec(self):
4708    return self._structure
4709
4710
4711class TensorSliceDataset(DatasetSource):
4712  """A `Dataset` of slices from a dataset element."""
4713
4714  def __init__(self, element, is_files=False, name=None):
4715    """See `Dataset.from_tensor_slices()` for details."""
4716    element = structure.normalize_element(element)
4717    batched_spec = structure.type_spec_from_value(element)
4718    self._tensors = structure.to_batched_tensor_list(batched_spec, element)
4719    if not self._tensors:
4720      raise ValueError("Invalid `element`. `element` should not be empty.")
4721    self._structure = nest.map_structure(
4722        lambda component_spec: component_spec._unbatch(), batched_spec)  # pylint: disable=protected-access
4723    self._name = name
4724
4725    batch_dim = tensor_shape.Dimension(
4726        tensor_shape.dimension_value(self._tensors[0].get_shape()[0]))
4727    for t in self._tensors[1:]:
4728      batch_dim.assert_is_compatible_with(
4729          tensor_shape.Dimension(
4730              tensor_shape.dimension_value(t.get_shape()[0])))
4731
4732    variant_tensor = gen_dataset_ops.tensor_slice_dataset(
4733        self._tensors,
4734        output_shapes=structure.get_flat_tensor_shapes(self._structure),
4735        is_files=is_files,
4736        metadata=self._metadata.SerializeToString())
4737    super(TensorSliceDataset, self).__init__(variant_tensor)
4738
4739  @property
4740  def element_spec(self):
4741    return self._structure
4742
4743
4744class SparseTensorSliceDataset(DatasetSource):
4745  """A `Dataset` that splits a rank-N `tf.sparse.SparseTensor` into its rows."""
4746
4747  def __init__(self, sparse_tensor):
4748    """See `Dataset.from_sparse_tensor_slices()` for details."""
4749    if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor):
4750      raise TypeError(f"Invalid `sparse_tensor`. `sparse_tensor` must be a "
4751                      f"`tf.sparse.SparseTensor`. Got {type(sparse_tensor)}.")
4752    self._sparse_tensor = sparse_tensor
4753
4754    indices_shape = self._sparse_tensor.indices.get_shape()
4755    shape_shape = self._sparse_tensor.dense_shape.get_shape()
4756    rank = (indices_shape.dims[1] - 1).merge_with(shape_shape.dims[0] - 1)
4757    self._structure = (tensor_spec.TensorSpec([None, rank], dtypes.int64),
4758                       tensor_spec.TensorSpec([None],
4759                                              self._sparse_tensor.dtype),
4760                       tensor_spec.TensorSpec([rank], dtypes.int64))
4761
4762    variant_tensor = gen_dataset_ops.sparse_tensor_slice_dataset(
4763        self._sparse_tensor.indices, self._sparse_tensor.values,
4764        self._sparse_tensor.dense_shape)
4765    super(SparseTensorSliceDataset, self).__init__(variant_tensor)
4766
4767  @property
4768  def element_spec(self):
4769    return self._structure
4770
4771
4772class _GeneratorDataset(DatasetSource):
4773  """A `Dataset` that generates elements by invoking a function."""
4774
4775  def __init__(self,
4776               init_args,
4777               init_func,
4778               next_func,
4779               finalize_func,
4780               output_signature,
4781               name=None):
4782    """Constructs a `_GeneratorDataset`.
4783
4784    Args:
4785      init_args: A (nested) structure representing the arguments to `init_func`.
4786      init_func: A TensorFlow function that will be called on `init_args` each
4787        time a C++ iterator over this dataset is constructed. Returns a (nested)
4788        structure representing the "state" of the dataset.
4789      next_func: A TensorFlow function that will be called on the result of
4790        `init_func` to produce each element, and that raises `OutOfRangeError`
4791        to terminate iteration.
4792      finalize_func: A TensorFlow function that will be called on the result of
4793        `init_func` immediately before a C++ iterator over this dataset is
4794        destroyed. The return value is ignored.
4795      output_signature: A (nested) structure of `tf.TypeSpec` objects describing
4796        the output of `next_func`.
4797      name: Optional. A name for the tf.data transformation.
4798    """
4799    self._init_args = init_args
4800
4801    self._init_structure = structure.type_spec_from_value(init_args)
4802
4803    self._init_func = structured_function.StructuredFunctionWrapper(
4804        init_func,
4805        self._transformation_name(),
4806        input_structure=self._init_structure)
4807
4808    self._next_func = structured_function.StructuredFunctionWrapper(
4809        next_func,
4810        self._transformation_name(),
4811        input_structure=self._init_func.output_structure)
4812
4813    self._finalize_func = structured_function.StructuredFunctionWrapper(
4814        finalize_func,
4815        self._transformation_name(),
4816        input_structure=self._init_func.output_structure)
4817
4818    self._output_signature = output_signature
4819
4820    self._name = name
4821
4822    variant_tensor = gen_dataset_ops.generator_dataset(
4823        structure.to_tensor_list(self._init_structure, self._init_args) +
4824        self._init_func.function.captured_inputs,
4825        self._next_func.function.captured_inputs,
4826        self._finalize_func.function.captured_inputs,
4827        init_func=self._init_func.function,
4828        next_func=self._next_func.function,
4829        finalize_func=self._finalize_func.function,
4830        **self._common_args)
4831    super(_GeneratorDataset, self).__init__(variant_tensor)
4832
4833  @property
4834  def element_spec(self):
4835    return self._output_signature
4836
4837  def _transformation_name(self):
4838    return "Dataset.from_generator()"
4839
4840
4841class ZipDataset(DatasetV2):
4842  """A `Dataset` that zips its inputs together."""
4843
4844  def __init__(self, datasets, name=None):
4845    """See `Dataset.zip()` for details."""
4846    for ds in nest.flatten(datasets):
4847      if not isinstance(ds, DatasetV2):
4848        if isinstance(ds, list):
4849          raise TypeError("Invalid `datasets`. `datasets` is expected to be a "
4850                          "(nested) structure of `tf.data.Dataset` objects. "
4851                          "Python `list` is not supported and you should use "
4852                          "`tuple` instead.")
4853        else:
4854          raise TypeError(f"Invalid `datasets`. `datasets` is expected to be a "
4855                          f"(nested) structure of `tf.data.Dataset` objects "
4856                          f"but encountered object of type {type(ds)}.")
4857    self._datasets = datasets
4858    self._structure = nest.pack_sequence_as(
4859        self._datasets,
4860        [ds.element_spec for ds in nest.flatten(self._datasets)])
4861    self._name = name
4862    variant_tensor = gen_dataset_ops.zip_dataset(
4863        [ds._variant_tensor for ds in nest.flatten(self._datasets)],
4864        **self._common_args)
4865    super(ZipDataset, self).__init__(variant_tensor)
4866
4867  def _inputs(self):
4868    return nest.flatten(self._datasets)
4869
4870  @property
4871  def element_spec(self):
4872    return self._structure
4873
4874
4875class ConcatenateDataset(DatasetV2):
4876  """A `Dataset` that concatenates its input with given dataset."""
4877
4878  def __init__(self, input_dataset, dataset_to_concatenate, name=None):
4879    """See `Dataset.concatenate()` for details."""
4880    self._input_dataset = input_dataset
4881    self._dataset_to_concatenate = dataset_to_concatenate
4882
4883    def common_supertype(a, b):
4884      result = a.most_specific_common_supertype([b])
4885      if result is None:
4886        raise TypeError(f"No common supertype of {a} and {b}.")
4887      return result
4888
4889    try:
4890      self._structure = tf_nest.map_structure(
4891          common_supertype, input_dataset.element_spec,
4892          dataset_to_concatenate.element_spec)
4893    except (TypeError, ValueError) as e:
4894      raise TypeError(
4895          f"Incompatible dataset elements:\n"
4896          f"  {input_dataset.element_spec} vs. "
4897          f"  {dataset_to_concatenate.element_spec}") from e
4898
4899    self._input_datasets = [input_dataset, dataset_to_concatenate]
4900    self._name = name
4901    # pylint: disable=protected-access
4902    variant_tensor = gen_dataset_ops.concatenate_dataset(
4903        input_dataset._variant_tensor, dataset_to_concatenate._variant_tensor,
4904        **self._common_args)
4905    # pylint: enable=protected-access
4906    super(ConcatenateDataset, self).__init__(variant_tensor)
4907
4908  def _inputs(self):
4909    return self._input_datasets
4910
4911  @property
4912  def element_spec(self):
4913    return self._structure
4914
4915
4916class RepeatDataset(UnaryUnchangedStructureDataset):
4917  """A `Dataset` that repeats its input several times."""
4918
4919  def __init__(self, input_dataset, count, name=None):
4920    """See `Dataset.repeat()` for details."""
4921    self._input_dataset = input_dataset
4922    if count is None:
4923      self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
4924    else:
4925      self._count = ops.convert_to_tensor(
4926          count, dtype=dtypes.int64, name="count")
4927    self._name = name
4928    variant_tensor = gen_dataset_ops.repeat_dataset(
4929        input_dataset._variant_tensor,  # pylint: disable=protected-access
4930        count=self._count,
4931        **self._common_args)
4932    super(RepeatDataset, self).__init__(input_dataset, variant_tensor)
4933
4934
4935class RangeDataset(DatasetSource):
4936  """A `Dataset` of a step separated range of values."""
4937
4938  def __init__(self, *args, **kwargs):
4939    """See `Dataset.range()` for details."""
4940    self._parse_args(*args, **kwargs)
4941    self._structure = tensor_spec.TensorSpec([], self._output_type)
4942    variant_tensor = gen_dataset_ops.range_dataset(
4943        start=self._start,
4944        stop=self._stop,
4945        step=self._step,
4946        **self._common_args)
4947    super(RangeDataset, self).__init__(variant_tensor)
4948
4949  def _parse_args(self, *args, **kwargs):
4950    """Parse arguments according to the same rules as the `range()` builtin."""
4951    if len(args) == 1:
4952      self._start = self._build_tensor(0, "start")
4953      self._stop = self._build_tensor(args[0], "stop")
4954      self._step = self._build_tensor(1, "step")
4955    elif len(args) == 2:
4956      self._start = self._build_tensor(args[0], "start")
4957      self._stop = self._build_tensor(args[1], "stop")
4958      self._step = self._build_tensor(1, "step")
4959    elif len(args) == 3:
4960      self._start = self._build_tensor(args[0], "start")
4961      self._stop = self._build_tensor(args[1], "stop")
4962      self._step = self._build_tensor(args[2], "step")
4963    else:
4964      raise ValueError(f"Invalid `args`. The lenght of `args` should be "
4965                       f"between 1 and 3 but was {len(args)}.")
4966    if "output_type" in kwargs:
4967      self._output_type = kwargs["output_type"]
4968    else:
4969      self._output_type = dtypes.int64
4970    self._name = kwargs["name"] if "name" in kwargs else None
4971
4972  def _build_tensor(self, int64_value, name):
4973    return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)
4974
4975  @property
4976  def element_spec(self):
4977    return self._structure
4978
4979
4980class CacheDataset(UnaryUnchangedStructureDataset):
4981  """A `Dataset` that caches elements of its input."""
4982
4983  def __init__(self, input_dataset, filename, name=None):
4984    """See `Dataset.cache()` for details."""
4985    self._input_dataset = input_dataset
4986    self._filename = ops.convert_to_tensor(
4987        filename, dtype=dtypes.string, name="filename")
4988    self._name = name
4989    if tf2.enabled() and (context.executing_eagerly() or ops.inside_function()):
4990      variant_tensor = gen_dataset_ops.cache_dataset_v2(
4991          input_dataset._variant_tensor,  # pylint: disable=protected-access
4992          filename=self._filename,
4993          cache=gen_dataset_ops.dummy_memory_cache(),
4994          **self._common_args)
4995    else:
4996      variant_tensor = gen_dataset_ops.cache_dataset(
4997          input_dataset._variant_tensor,  # pylint: disable=protected-access
4998          filename=self._filename,
4999          **self._common_args)
5000    super(CacheDataset, self).__init__(input_dataset, variant_tensor)
5001
5002
5003class ShuffleDataset(UnaryUnchangedStructureDataset):
5004  """A `Dataset` that randomly shuffles the elements of its input."""
5005
5006  def __init__(self,
5007               input_dataset,
5008               buffer_size,
5009               seed=None,
5010               reshuffle_each_iteration=None,
5011               name=None):
5012    """See `Dataset.shuffle()` for details."""
5013    self._input_dataset = input_dataset
5014    self._buffer_size = ops.convert_to_tensor(
5015        buffer_size, dtype=dtypes.int64, name="buffer_size")
5016    self._seed, self._seed2 = random_seed.get_seed(seed)
5017    if reshuffle_each_iteration is None:
5018      reshuffle_each_iteration = True
5019    self._reshuffle_each_iteration = reshuffle_each_iteration
5020    self._name = name
5021
5022    if (tf2.enabled() and
5023        (context.executing_eagerly() or ops.inside_function())):
5024      variant_tensor = gen_dataset_ops.shuffle_dataset_v3(
5025          input_dataset._variant_tensor,  # pylint: disable=protected-access
5026          buffer_size=self._buffer_size,
5027          seed=self._seed,
5028          seed2=self._seed2,
5029          seed_generator=gen_dataset_ops.dummy_seed_generator(),
5030          reshuffle_each_iteration=self._reshuffle_each_iteration,
5031          **self._common_args)
5032    else:
5033      variant_tensor = gen_dataset_ops.shuffle_dataset(
5034          input_dataset._variant_tensor,  # pylint: disable=protected-access
5035          buffer_size=self._buffer_size,
5036          seed=self._seed,
5037          seed2=self._seed2,
5038          reshuffle_each_iteration=self._reshuffle_each_iteration,
5039          **self._common_args)
5040    super(ShuffleDataset, self).__init__(input_dataset, variant_tensor)
5041
5042
5043class TakeDataset(UnaryUnchangedStructureDataset):
5044  """A `Dataset` containing the first `count` elements from its input."""
5045
5046  def __init__(self, input_dataset, count, name=None):
5047    """See `Dataset.take()` for details."""
5048    self._input_dataset = input_dataset
5049    self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
5050    self._name = name
5051    variant_tensor = gen_dataset_ops.take_dataset(
5052        input_dataset._variant_tensor,  # pylint: disable=protected-access
5053        count=self._count,
5054        **self._common_args)
5055    super(TakeDataset, self).__init__(input_dataset, variant_tensor)
5056
5057
5058class SkipDataset(UnaryUnchangedStructureDataset):
5059  """A `Dataset` skipping the first `count` elements from its input."""
5060
5061  def __init__(self, input_dataset, count, name=None):
5062    """See `Dataset.skip()` for details."""
5063    self._input_dataset = input_dataset
5064    self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
5065    self._name = name
5066    variant_tensor = gen_dataset_ops.skip_dataset(
5067        input_dataset._variant_tensor,  # pylint: disable=protected-access
5068        count=self._count,
5069        **self._common_args)
5070    super(SkipDataset, self).__init__(input_dataset, variant_tensor)
5071
5072
5073class ShardDataset(UnaryUnchangedStructureDataset):
5074  """A `Dataset` for sharding its input."""
5075
5076  def __init__(self, input_dataset, num_shards, index, name=None):
5077    """See `Dataset.shard()` for details."""
5078    self._input_dataset = input_dataset
5079    self._num_shards = ops.convert_to_tensor(
5080        num_shards, dtype=dtypes.int64, name="num_shards")
5081    self._index = ops.convert_to_tensor(index, dtype=dtypes.int64, name="index")
5082    self._name = name
5083    variant_tensor = gen_dataset_ops.shard_dataset(
5084        input_dataset._variant_tensor,  # pylint: disable=protected-access
5085        num_shards=self._num_shards,
5086        index=self._index,
5087        **self._common_args)
5088    super(ShardDataset, self).__init__(input_dataset, variant_tensor)
5089
5090
5091class BatchDataset(UnaryDataset):
5092  """A `Dataset` that batches contiguous elements from its input."""
5093
5094  def __init__(self, input_dataset, batch_size, drop_remainder, name=None):
5095    """See `Dataset.batch()` for details."""
5096    self._input_dataset = input_dataset
5097    self._batch_size = ops.convert_to_tensor(
5098        batch_size, dtype=dtypes.int64, name="batch_size")
5099    self._drop_remainder = ops.convert_to_tensor(
5100        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
5101
5102    constant_drop_remainder = tensor_util.constant_value(self._drop_remainder)
5103    # pylint: disable=protected-access
5104    if constant_drop_remainder:
5105      # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
5106      # or `False` (explicitly retaining the remainder).
5107      # pylint: disable=g-long-lambda
5108      constant_batch_size = tensor_util.constant_value(self._batch_size)
5109      self._structure = nest.map_structure(
5110          lambda component_spec: component_spec._batch(constant_batch_size),
5111          input_dataset.element_spec)
5112    else:
5113      self._structure = nest.map_structure(
5114          lambda component_spec: component_spec._batch(None),
5115          input_dataset.element_spec)
5116
5117    self._name = name
5118    variant_tensor = gen_dataset_ops.batch_dataset_v2(
5119        input_dataset._variant_tensor,
5120        batch_size=self._batch_size,
5121        drop_remainder=self._drop_remainder,
5122        **self._common_args)
5123    super(BatchDataset, self).__init__(input_dataset, variant_tensor)
5124
5125  @property
5126  def element_spec(self):
5127    return self._structure
5128
5129
5130class ParallelBatchDataset(UnaryDataset):
5131  """A `Dataset` that batches contiguous elements from its input in parallel."""
5132
5133  def __init__(self,
5134               input_dataset,
5135               batch_size,
5136               drop_remainder,
5137               num_parallel_calls,
5138               deterministic,
5139               name=None):
5140    """See `Dataset.batch()` for details."""
5141    self._input_dataset = input_dataset
5142    self._batch_size = ops.convert_to_tensor(
5143        batch_size, dtype=dtypes.int64, name="batch_size")
5144    self._drop_remainder = ops.convert_to_tensor(
5145        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
5146    self._num_parallel_calls = ops.convert_to_tensor(
5147        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
5148    if deterministic is None:
5149      self._deterministic = "default"
5150    elif deterministic:
5151      self._deterministic = "true"
5152    else:
5153      self._deterministic = "false"
5154
5155    constant_drop_remainder = tensor_util.constant_value(self._drop_remainder)
5156    # pylint: disable=protected-access
5157    if constant_drop_remainder:
5158      # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
5159      # or `False` (explicitly retaining the remainder).
5160      # pylint: disable=g-long-lambda
5161      constant_batch_size = tensor_util.constant_value(self._batch_size)
5162      self._structure = nest.map_structure(
5163          lambda component_spec: component_spec._batch(constant_batch_size),
5164          input_dataset.element_spec)
5165    else:
5166      self._structure = nest.map_structure(
5167          lambda component_spec: component_spec._batch(None),
5168          input_dataset.element_spec)
5169
5170    self._name = name
5171    variant_tensor = gen_dataset_ops.parallel_batch_dataset(
5172        input_dataset._variant_tensor,
5173        batch_size=self._batch_size,
5174        num_parallel_calls=self._num_parallel_calls,
5175        drop_remainder=self._drop_remainder,
5176        deterministic=self._deterministic,
5177        **self._common_args)
5178
5179    super(ParallelBatchDataset, self).__init__(input_dataset, variant_tensor)
5180
5181  @property
5182  def element_spec(self):
5183    return self._structure
5184
5185
5186def _is_padded_shape_compatible_with(padded_shape, input_component_shape):
5187  """Returns `True` if `input_component_shape` can be padded to `padded_shape`.
5188
5189  Args:
5190    padded_shape: A `tf.TensorShape`.
5191    input_component_shape: A `tf.TensorShape`.
5192
5193  Returns:
5194    `True` if `input_component_shape` can be padded to `padded_shape`, otherwise
5195    `False`.
5196  """
5197
5198  if padded_shape.dims is None or input_component_shape.dims is None:
5199    return True
5200  if len(padded_shape.dims) != len(input_component_shape.dims):
5201    return False
5202  for padded_dim, input_dim in zip(
5203      padded_shape.dims, input_component_shape.dims):
5204    if (padded_dim.value is not None and input_dim.value is not None
5205        and padded_dim.value < input_dim.value):
5206      return False
5207  return True
5208
5209
5210def _padded_shape_to_tensor(padded_shape, input_component_shape):
5211  """Converts `padded_shape` to a `tf.Tensor` representing that shape.
5212
5213  Args:
5214    padded_shape: A shape-like object, which may be a `tf.TensorShape`, a Python
5215      sequence, or a 1-D `tf.Tensor` of `tf.int64` elements.
5216    input_component_shape: A `tf.TensorShape`, with which `padded_shape` must
5217      be compatible.
5218
5219  Returns:
5220    A 1-D `tf.Tensor` of `tf.int64` elements, representing `padded_shape`.
5221
5222  Raises:
5223    ValueError: If `padded_shape` is not a shape or not compatible with
5224      `input_component_shape`.
5225    TypeError: If `padded_shape` is not convertible to a `tf.int64` tensor.
5226  """
5227  try:
5228    # Try to convert the `padded_shape` to a `tf.TensorShape`
5229    padded_shape_as_shape = tensor_shape.as_shape(padded_shape)
5230    # We will return the "canonical" tensor representation, which uses
5231    # `-1` in place of `None`.
5232    ret = ops.convert_to_tensor(
5233        [dim if dim is not None else -1
5234         for dim in padded_shape_as_shape.as_list()], dtype=dtypes.int64)
5235  except (TypeError, ValueError) as e:
5236    # The argument was not trivially convertible to a
5237    # `tf.TensorShape`, so fall back on the conversion to tensor
5238    # machinery.
5239    ret = ops.convert_to_tensor(padded_shape, preferred_dtype=dtypes.int64)
5240    if ret.shape.dims is not None and len(ret.shape.dims) != 1:
5241      raise ValueError(
5242          f"Padded shape {padded_shape} must be a `tf.int64` vector tensor, "
5243          f"but its shape was {ret.shape}.") from e
5244    if ret.dtype != dtypes.int64:
5245      raise TypeError(
5246          f"Padded shape {padded_shape} must be a `tf.int64` vector "
5247          f"tensor, but its element type was {ret.dtype.name}.") from e
5248    padded_shape_as_shape = tensor_util.constant_value_as_shape(ret)
5249
5250  if not _is_padded_shape_compatible_with(padded_shape_as_shape,
5251                                          input_component_shape):
5252    raise ValueError(f"The padded shape {padded_shape_as_shape} is not "
5253                     f"compatible with the shape {input_component_shape} of "
5254                     f"the corresponding input component.")
5255
5256  return ret
5257
5258
5259def _padding_value_to_tensor(value, output_type):
5260  """Converts the padding value to a tensor.
5261
5262  Args:
5263    value: The padding value.
5264    output_type: Its expected dtype.
5265
5266  Returns:
5267    A scalar `Tensor`.
5268
5269  Raises:
5270    ValueError: if the padding value is not a scalar.
5271    TypeError: if the padding value's type does not match `output_type`.
5272  """
5273  value = ops.convert_to_tensor(value, name="padding_value")
5274  if not value.shape.is_compatible_with(tensor_shape.TensorShape([])):
5275    raise ValueError(f"Invalid `padding_values`. `padding_values` values "
5276                     f"should be scalars, but got {value.shape}.")
5277  if value.dtype != output_type:
5278    raise TypeError(f"Invalid `padding_values`. `padding_values` values "
5279                    f"type {value.dtype} does not match type {output_type} "
5280                    f"of the corresponding input component.")
5281  return value
5282
5283
5284def _padding_values_or_default(padding_values, input_dataset):
5285  """Returns padding values with None elements replaced with default values."""
5286
5287  def make_zero(t):
5288    if t.base_dtype == dtypes.string:
5289      return ""
5290    elif t.base_dtype == dtypes.variant:
5291      raise TypeError("Unable to create default padding value for a component "
5292                      "of type 'variant'.")
5293    elif t.base_dtype == dtypes.bfloat16:
5294      # Special case `bfloat16` because it is not supported by NumPy.
5295      return constant_op.constant(0, dtype=dtypes.bfloat16)
5296    else:
5297      return np.zeros_like(t.as_numpy_dtype())
5298
5299  def value_or_default(value, default):
5300    return default if value is None else value
5301
5302  default_padding = nest.map_structure(
5303      make_zero,
5304      get_legacy_output_types(input_dataset))
5305  return nest.map_structure_up_to(padding_values, value_or_default,
5306                                  padding_values, default_padding)
5307
5308
5309class PaddedBatchDataset(UnaryDataset):
5310  """A `Dataset` that batches and pads contiguous elements from its input."""
5311
5312  def __init__(self,
5313               input_dataset,
5314               batch_size,
5315               padded_shapes,
5316               padding_values,
5317               drop_remainder,
5318               name=None):
5319    """See `Dataset.batch()` for details."""
5320    self._input_dataset = input_dataset
5321
5322    def check_types(component_spec):
5323      if not isinstance(component_spec, tensor_spec.TensorSpec):
5324        raise TypeError(f"`padded_batch` is only supported for datasets that "
5325                        f"produce tensor elements but the input dataset "
5326                        f"produces elements of unsupported type "
5327                        f"{component_spec.value_type()}.")
5328
5329    nest.map_structure(check_types, input_dataset.element_spec)
5330    self._input_dataset = input_dataset
5331    self._batch_size = ops.convert_to_tensor(
5332        batch_size, dtype=dtypes.int64, name="batch_size")
5333    padding_values = _padding_values_or_default(padding_values, input_dataset)
5334
5335    input_shapes = get_legacy_output_shapes(input_dataset)
5336    flat_padded_shapes = nest.flatten_up_to(input_shapes, padded_shapes)
5337
5338    flat_padded_shapes_as_tensors = []
5339
5340    for input_component_shape, padded_shape in zip(
5341        nest.flatten(input_shapes), flat_padded_shapes):
5342      flat_padded_shapes_as_tensors.append(
5343          _padded_shape_to_tensor(padded_shape, input_component_shape))
5344
5345    self._padded_shapes = nest.pack_sequence_as(input_shapes,
5346                                                flat_padded_shapes_as_tensors)
5347
5348    # If padding_values is a single element and input_shapes is a structure,
5349    # "broadcast" padding_values to the same structure as input_shapes.
5350    if nest.is_nested(input_shapes) and not nest.is_nested(padding_values):
5351      padding_values = nest.map_structure(lambda _: padding_values,
5352                                          input_shapes)
5353
5354    self._padding_values = nest.map_structure_up_to(
5355        input_shapes, _padding_value_to_tensor, padding_values,
5356        get_legacy_output_types(input_dataset))
5357    self._drop_remainder = ops.convert_to_tensor(
5358        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
5359
5360    def _padded_shape_to_batch_shape(s):
5361      return tensor_shape.TensorShape([
5362          tensor_util.constant_value(self._batch_size)
5363          if smart_cond.smart_constant_value(self._drop_remainder) else None
5364      ]).concatenate(tensor_util.constant_value_as_shape(s))
5365
5366    output_shapes = nest.map_structure(
5367        _padded_shape_to_batch_shape, self._padded_shapes)
5368    self._structure = structure.convert_legacy_structure(
5369        get_legacy_output_types(self._input_dataset), output_shapes,
5370        get_legacy_output_classes(self._input_dataset))
5371
5372    self._name = name
5373    # pylint: disable=protected-access
5374    variant_tensor = gen_dataset_ops.padded_batch_dataset_v2(
5375        input_dataset._variant_tensor,  # pylint: disable=protected-access
5376        batch_size=self._batch_size,
5377        padded_shapes=[
5378            ops.convert_to_tensor(s, dtype=dtypes.int64)
5379            for s in nest.flatten(self._padded_shapes)
5380        ],
5381        padding_values=nest.flatten(self._padding_values),
5382        drop_remainder=self._drop_remainder,
5383        output_shapes=structure.get_flat_tensor_shapes(self._structure),
5384        metadata=self._metadata.SerializeToString())
5385    super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor)
5386
5387  @property
5388  def element_spec(self):
5389    return self._structure
5390
5391
5392class MapDataset(UnaryDataset):
5393  """A `Dataset` that maps a function over elements in its input."""
5394
5395  def __init__(self,
5396               input_dataset,
5397               map_func,
5398               use_inter_op_parallelism=True,
5399               preserve_cardinality=False,
5400               use_legacy_function=False,
5401               name=None):
5402    """See `Dataset.map()` for details."""
5403    self._input_dataset = input_dataset
5404    self._use_inter_op_parallelism = use_inter_op_parallelism
5405    self._preserve_cardinality = preserve_cardinality
5406    self._map_func = structured_function.StructuredFunctionWrapper(
5407        map_func,
5408        self._transformation_name(),
5409        dataset=input_dataset,
5410        use_legacy_function=use_legacy_function)
5411    self._name = name
5412    variant_tensor = gen_dataset_ops.map_dataset(
5413        input_dataset._variant_tensor,  # pylint: disable=protected-access
5414        self._map_func.function.captured_inputs,
5415        f=self._map_func.function,
5416        use_inter_op_parallelism=self._use_inter_op_parallelism,
5417        preserve_cardinality=self._preserve_cardinality,
5418        **self._common_args)
5419    super(MapDataset, self).__init__(input_dataset, variant_tensor)
5420
5421  def _functions(self):
5422    return [self._map_func]
5423
5424  @property
5425  def element_spec(self):
5426    return self._map_func.output_structure
5427
5428  def _transformation_name(self):
5429    return "Dataset.map()"
5430
5431
5432class ParallelMapDataset(UnaryDataset):
5433  """A `Dataset` that maps a function over elements in its input in parallel."""
5434
5435  def __init__(self,
5436               input_dataset,
5437               map_func,
5438               num_parallel_calls,
5439               deterministic,
5440               use_inter_op_parallelism=True,
5441               preserve_cardinality=False,
5442               use_legacy_function=False,
5443               name=None):
5444    """See `Dataset.map()` for details."""
5445    self._input_dataset = input_dataset
5446    self._use_inter_op_parallelism = use_inter_op_parallelism
5447    self._map_func = structured_function.StructuredFunctionWrapper(
5448        map_func,
5449        self._transformation_name(),
5450        dataset=input_dataset,
5451        use_legacy_function=use_legacy_function)
5452    if deterministic is None:
5453      self._deterministic = "default"
5454    elif deterministic:
5455      self._deterministic = "true"
5456    else:
5457      self._deterministic = "false"
5458    self._preserve_cardinality = preserve_cardinality
5459    self._num_parallel_calls = ops.convert_to_tensor(
5460        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
5461    self._name = name
5462    variant_tensor = gen_dataset_ops.parallel_map_dataset_v2(
5463        input_dataset._variant_tensor,  # pylint: disable=protected-access
5464        self._map_func.function.captured_inputs,
5465        f=self._map_func.function,
5466        num_parallel_calls=self._num_parallel_calls,
5467        deterministic=self._deterministic,
5468        use_inter_op_parallelism=self._use_inter_op_parallelism,
5469        preserve_cardinality=self._preserve_cardinality,
5470        **self._common_args)
5471    super(ParallelMapDataset, self).__init__(input_dataset, variant_tensor)
5472
5473  def _functions(self):
5474    return [self._map_func]
5475
5476  @property
5477  def element_spec(self):
5478    return self._map_func.output_structure
5479
5480  def _transformation_name(self):
5481    return "Dataset.map()"
5482
5483
5484class FlatMapDataset(UnaryDataset):
5485  """A `Dataset` that maps a function over its input and flattens the result."""
5486
5487  def __init__(self, input_dataset, map_func, name=None):
5488    """See `Dataset.flat_map()` for details."""
5489    self._input_dataset = input_dataset
5490    self._map_func = structured_function.StructuredFunctionWrapper(
5491        map_func, self._transformation_name(), dataset=input_dataset)
5492    if not isinstance(self._map_func.output_structure, DatasetSpec):
5493      raise TypeError(
5494          "The `map_func` argument must return a `Dataset` object. Got "
5495          f"{_get_type(self._map_func.output_structure)!r}.")
5496    self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
5497    self._name = name
5498    variant_tensor = gen_dataset_ops.flat_map_dataset(
5499        input_dataset._variant_tensor,  # pylint: disable=protected-access
5500        self._map_func.function.captured_inputs,
5501        f=self._map_func.function,
5502        **self._common_args)
5503    super(FlatMapDataset, self).__init__(input_dataset, variant_tensor)
5504
5505  def _functions(self):
5506    return [self._map_func]
5507
5508  @property
5509  def element_spec(self):
5510    return self._structure
5511
5512  def _transformation_name(self):
5513    return "Dataset.flat_map()"
5514
5515
5516class InterleaveDataset(UnaryDataset):
5517  """A `Dataset` that interleaves the result of transformed inputs."""
5518
5519  def __init__(self,
5520               input_dataset,
5521               map_func,
5522               cycle_length,
5523               block_length,
5524               name=None):
5525    """See `Dataset.interleave()` for details."""
5526
5527    self._input_dataset = input_dataset
5528    self._map_func = structured_function.StructuredFunctionWrapper(
5529        map_func, self._transformation_name(), dataset=input_dataset)
5530    if not isinstance(self._map_func.output_structure, DatasetSpec):
5531      raise TypeError(
5532          "The `map_func` argument must return a `Dataset` object. Got "
5533          f"{_get_type(self._map_func.output_structure)!r}.")
5534    self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
5535    self._cycle_length = ops.convert_to_tensor(
5536        cycle_length, dtype=dtypes.int64, name="cycle_length")
5537    self._block_length = ops.convert_to_tensor(
5538        block_length, dtype=dtypes.int64, name="block_length")
5539    self._name = name
5540    variant_tensor = gen_dataset_ops.interleave_dataset(
5541        input_dataset._variant_tensor,  # pylint: disable=protected-access
5542        self._map_func.function.captured_inputs,  # pylint: disable=protected-access
5543        self._cycle_length,
5544        self._block_length,
5545        f=self._map_func.function,
5546        **self._common_args)
5547    super(InterleaveDataset, self).__init__(input_dataset, variant_tensor)
5548
5549  def _functions(self):
5550    return [self._map_func]
5551
5552  @property
5553  def element_spec(self):
5554    return self._structure
5555
5556  def _transformation_name(self):
5557    return "Dataset.interleave()"
5558
5559
5560class ParallelInterleaveDataset(UnaryDataset):
5561  """A `Dataset` that maps a function over its input and interleaves the result."""
5562
5563  def __init__(self,
5564               input_dataset,
5565               map_func,
5566               cycle_length,
5567               block_length,
5568               num_parallel_calls,
5569               buffer_output_elements=AUTOTUNE,
5570               prefetch_input_elements=AUTOTUNE,
5571               deterministic=None,
5572               name=None):
5573    """See `Dataset.interleave()` for details."""
5574    self._input_dataset = input_dataset
5575    self._map_func = structured_function.StructuredFunctionWrapper(
5576        map_func, self._transformation_name(), dataset=input_dataset)
5577    if not isinstance(self._map_func.output_structure, DatasetSpec):
5578      raise TypeError(
5579          "The `map_func` argument must return a `Dataset` object. Got "
5580          f"{_get_type(self._map_func.output_structure)!r}.")
5581    self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
5582    self._cycle_length = ops.convert_to_tensor(
5583        cycle_length, dtype=dtypes.int64, name="cycle_length")
5584    self._block_length = ops.convert_to_tensor(
5585        block_length, dtype=dtypes.int64, name="block_length")
5586    self._buffer_output_elements = ops.convert_to_tensor(
5587        buffer_output_elements,
5588        dtype=dtypes.int64,
5589        name="buffer_output_elements")
5590    self._prefetch_input_elements = ops.convert_to_tensor(
5591        prefetch_input_elements,
5592        dtype=dtypes.int64,
5593        name="prefetch_input_elements")
5594
5595    self._num_parallel_calls = ops.convert_to_tensor(
5596        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
5597    if deterministic is None:
5598      deterministic_string = "default"
5599    elif deterministic:
5600      deterministic_string = "true"
5601    else:
5602      deterministic_string = "false"
5603
5604    self._name = name
5605    variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v4(
5606        input_dataset._variant_tensor,  # pylint: disable=protected-access
5607        self._map_func.function.captured_inputs,  # pylint: disable=protected-access
5608        self._cycle_length,
5609        self._block_length,
5610        self._buffer_output_elements,
5611        self._prefetch_input_elements,
5612        self._num_parallel_calls,
5613        f=self._map_func.function,
5614        deterministic=deterministic_string,
5615        **self._common_args)
5616    super(ParallelInterleaveDataset, self).__init__(input_dataset,
5617                                                    variant_tensor)
5618
5619  def _functions(self):
5620    return [self._map_func]
5621
5622  @property
5623  def element_spec(self):
5624    return self._structure
5625
5626  def _transformation_name(self):
5627    return "Dataset.interleave()"
5628
5629
5630class FilterDataset(UnaryUnchangedStructureDataset):
5631  """A `Dataset` that filters its input according to a predicate function."""
5632
5633  def __init__(self,
5634               input_dataset,
5635               predicate,
5636               use_legacy_function=False,
5637               name=None):
5638    """See `Dataset.filter()` for details."""
5639    self._input_dataset = input_dataset
5640    wrapped_func = structured_function.StructuredFunctionWrapper(
5641        predicate,
5642        self._transformation_name(),
5643        dataset=input_dataset,
5644        use_legacy_function=use_legacy_function)
5645    if not wrapped_func.output_structure.is_compatible_with(
5646        tensor_spec.TensorSpec([], dtypes.bool)):
5647      raise ValueError(f"Invalid `predicate`. `predicate` must return a "
5648                       f"`tf.bool` scalar tensor, but its return type is "
5649                       f"{wrapped_func.output_structure}.")
5650    self._predicate = wrapped_func
5651    self._name = name
5652    variant_tensor = gen_dataset_ops.filter_dataset(
5653        input_dataset._variant_tensor,  # pylint: disable=protected-access
5654        other_arguments=self._predicate.function.captured_inputs,
5655        predicate=self._predicate.function,
5656        **self._common_args)
5657    super(FilterDataset, self).__init__(input_dataset, variant_tensor)
5658
5659  def _functions(self):
5660    return [self._predicate]
5661
5662  def _transformation_name(self):
5663    return "Dataset.filter()"
5664
5665
5666class PrefetchDataset(UnaryUnchangedStructureDataset):
5667  """A `Dataset` that asynchronously prefetches its input."""
5668
5669  def __init__(self, input_dataset, buffer_size, slack_period=None, name=None):
5670    """See `Dataset.prefetch()` for details."""
5671    self._input_dataset = input_dataset
5672    if buffer_size is None:
5673      buffer_size = AUTOTUNE
5674    self._buffer_size = ops.convert_to_tensor(
5675        buffer_size, dtype=dtypes.int64, name="buffer_size")
5676    self._name = name
5677    # pylint: disable=protected-access
5678    # We colocate the prefetch dataset with its input as this collocation only
5679    # happens automatically in graph mode.
5680    with ops.colocate_with(input_dataset._variant_tensor):
5681      variant_tensor = gen_dataset_ops.prefetch_dataset(
5682          input_dataset._variant_tensor,
5683          buffer_size=self._buffer_size,
5684          slack_period=slack_period,
5685          **self._common_args)
5686    super(PrefetchDataset, self).__init__(input_dataset, variant_tensor)
5687
5688
5689class WindowDataset(UnaryDataset):
5690  """A dataset that creates window datasets from the input elements."""
5691
5692  def __init__(self,
5693               input_dataset,
5694               size,
5695               shift,
5696               stride,
5697               drop_remainder,
5698               name=None):
5699    """See `window()` for more details."""
5700    self._input_dataset = input_dataset
5701    self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
5702    self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
5703    self._stride = ops.convert_to_tensor(
5704        stride, dtype=dtypes.int64, name="stride")
5705    self._drop_remainder = ops.convert_to_tensor(
5706        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
5707    self._structure = nest.pack_sequence_as(
5708        get_legacy_output_classes(input_dataset), [
5709            DatasetSpec(  # pylint: disable=g-complex-comprehension
5710                structure.convert_legacy_structure(
5711                    output_type, output_shape, output_class))
5712            for output_class, output_shape, output_type in zip(
5713                nest.flatten(get_legacy_output_classes(input_dataset)),
5714                nest.flatten(get_legacy_output_shapes(input_dataset)),
5715                nest.flatten(get_legacy_output_types(input_dataset)))
5716        ])
5717    self._name = name
5718    variant_tensor = gen_dataset_ops.window_dataset(
5719        input_dataset._variant_tensor,  # pylint: disable=protected-access
5720        size=self._size,
5721        shift=self._shift,
5722        stride=self._stride,
5723        drop_remainder=self._drop_remainder,
5724        **self._common_args)
5725    super(WindowDataset, self).__init__(input_dataset, variant_tensor)
5726
5727  @property
5728  def element_spec(self):
5729    return self._structure
5730
5731
5732class _OptionsDataset(UnaryUnchangedStructureDataset):
5733  """An identity `Dataset` that stores options."""
5734
5735  def __init__(self, input_dataset, options, name=None):
5736    # pylint: disable=protected-access
5737    self._input_dataset = input_dataset
5738    options_pb = dataset_options_pb2.Options()
5739    options_pb.CopyFrom(options._to_proto())
5740    self._name = name
5741    with ops.colocate_with(input_dataset._variant_tensor):
5742      variant_tensor = gen_dataset_ops.options_dataset(
5743          input_dataset._variant_tensor, options_pb.SerializeToString(),
5744          **self._common_args)
5745    super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)
5746
5747    if self._options_attr:
5748      self._options_attr._set_mutable(True)
5749      self._options_attr = self._options_attr.merge(options)
5750    else:
5751      self._options_attr = options
5752    self._options_attr._set_mutable(False)
5753
5754
5755def normalize_to_dense(dataset):
5756  """Normalizes non-tensor components in a dataset to dense representations.
5757
5758  This is necessary for dataset transformations that slice along the batch
5759  dimension and are oblivious to non-tensors, e.g. `unbatch`, `rebatch`.
5760
5761  Args:
5762    dataset: Dataset to normalize.
5763
5764  Returns:
5765    A dataset whose sparse and ragged tensors have been normalized to their
5766    dense representations.
5767  """
5768
5769  # NOTE(mrry): This leads to a somewhat inefficient re-encoding step for all
5770  # non-tensor components.
5771  #
5772  # TODO(mrry): Consider optimizing this if it turns out to be a bottleneck.
5773  if structured_function._should_unpack(dataset.element_spec):  # pylint: disable=protected-access
5774
5775    def normalize(*args):
5776      return structure.to_batched_tensor_list(dataset.element_spec, tuple(args))
5777  else:
5778    def normalize(arg):
5779      return structure.to_batched_tensor_list(dataset.element_spec, arg)
5780
5781  normalized_dataset = dataset.map(normalize)
5782
5783  # NOTE(mrry): Our `map()` has lost information about the structure of
5784  # non-tensor components, so re-apply the structure of the original dataset.
5785  return _RestructuredDataset(normalized_dataset, dataset.element_spec)
5786
5787
5788class _RestructuredDataset(UnaryDataset):
5789  """An internal helper for changing the element spec of a dataset."""
5790
5791  def __init__(self, dataset, element_spec):
5792    self._input_dataset = dataset
5793    self._element_spec = element_spec
5794
5795    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
5796    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
5797
5798  @property
5799  def element_spec(self):
5800    return self._element_spec
5801
5802
5803class _UnbatchDataset(UnaryDataset):
5804  """A dataset that splits the elements of its input into multiple elements."""
5805
5806  def __init__(self, input_dataset, name=None):
5807    """See `unbatch()` for more details."""
5808    flat_shapes = input_dataset._flat_shapes  # pylint: disable=protected-access
5809    if any(s.ndims == 0 for s in flat_shapes):
5810      raise ValueError("Cannot unbatch an input with scalar components.")
5811    known_batch_dim = tensor_shape.Dimension(None)
5812    for s in flat_shapes:
5813      try:
5814        known_batch_dim = known_batch_dim.merge_with(s[0])
5815      except ValueError:
5816        raise ValueError(f"`unbatch()` is only supported for datasets of "
5817                         f"elements whose components have a matching leading "
5818                         f"dimension. Encountered both {known_batch_dim} and "
5819                         f"{s[0]}.")
5820    self._input_dataset = input_dataset
5821    self._structure = nest.map_structure(
5822        lambda component_spec: component_spec._unbatch(),  # pylint: disable=protected-access
5823        get_structure(input_dataset))
5824    self._name = name
5825    variant_tensor = ged_ops.unbatch_dataset(
5826        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
5827        **self._common_args)
5828    super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
5829
5830  @property
5831  def element_spec(self):
5832    return self._structure
5833
5834
5835class _GroupByWindowDataset(UnaryDataset):
5836  """A `Dataset` that groups its input and performs a windowed reduction."""
5837
5838  def __init__(self,
5839               input_dataset,
5840               key_func,
5841               reduce_func,
5842               window_size_func,
5843               name=None):
5844    """See `group_by_window()` for details."""
5845    self._input_dataset = input_dataset
5846    self._make_key_func(key_func, input_dataset)
5847    self._make_reduce_func(reduce_func, input_dataset)
5848    self._make_window_size_func(window_size_func)
5849    self._name = name
5850    variant_tensor = ged_ops.group_by_window_dataset(
5851        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
5852        self._key_func.function.captured_inputs,
5853        self._reduce_func.function.captured_inputs,
5854        self._window_size_func.function.captured_inputs,
5855        key_func=self._key_func.function,
5856        reduce_func=self._reduce_func.function,
5857        window_size_func=self._window_size_func.function,
5858        **self._common_args)
5859    super(_GroupByWindowDataset, self).__init__(input_dataset, variant_tensor)
5860
5861  def _make_window_size_func(self, window_size_func):
5862    """Make wrapping defun for window_size_func."""
5863
5864    def window_size_func_wrapper(key):
5865      return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)
5866
5867    self._window_size_func = structured_function.StructuredFunctionWrapper(
5868        window_size_func_wrapper,
5869        self._transformation_name(),
5870        input_structure=tensor_spec.TensorSpec([], dtypes.int64))
5871    if not self._window_size_func.output_structure.is_compatible_with(
5872        tensor_spec.TensorSpec([], dtypes.int64)):
5873      raise ValueError(f"Invalid `window_size_func`. `window_size_func` must "
5874                       f"return a single `tf.int64` scalar tensor but its "
5875                       f"return type is "
5876                       f"{self._window_size_func.output_structure}.")
5877
5878  def _make_key_func(self, key_func, input_dataset):
5879    """Make wrapping defun for key_func."""
5880
5881    def key_func_wrapper(*args):
5882      return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)
5883
5884    self._key_func = structured_function.StructuredFunctionWrapper(
5885        key_func_wrapper, self._transformation_name(), dataset=input_dataset)
5886    if not self._key_func.output_structure.is_compatible_with(
5887        tensor_spec.TensorSpec([], dtypes.int64)):
5888      raise ValueError(f"Invalid `key_func`. `key_func` must return a single "
5889                       f"`tf.int64` scalar tensor but its return type is "
5890                       f"{self._key_func.output_structure}.")
5891
5892  def _make_reduce_func(self, reduce_func, input_dataset):
5893    """Make wrapping defun for reduce_func."""
5894    nested_dataset = DatasetSpec(input_dataset.element_spec)
5895    input_structure = (tensor_spec.TensorSpec([], dtypes.int64), nested_dataset)
5896    self._reduce_func = structured_function.StructuredFunctionWrapper(
5897        reduce_func,
5898        self._transformation_name(),
5899        input_structure=input_structure)
5900    if not isinstance(self._reduce_func.output_structure, DatasetSpec):
5901      raise TypeError(f"Invalid `reduce_func`. `reduce_func` must return a "
5902                      f"single `tf.data.Dataset` object but its return type "
5903                      f"is {self._reduce_func.output_structure}.")
5904    # pylint: disable=protected-access
5905    self._element_spec = (self._reduce_func.output_structure._element_spec)
5906
5907  @property
5908  def element_spec(self):
5909    return self._element_spec
5910
5911  def _functions(self):
5912    return [self._key_func, self._reduce_func, self._window_size_func]
5913
5914  def _transformation_name(self):
5915    return "Dataset.group_by_window()"
5916
5917
5918class RandomDataset(DatasetSource):
5919  """A `Dataset` of pseudorandom values."""
5920
5921  def __init__(self, seed=None, name=None):
5922    """A `Dataset` of pseudorandom values."""
5923    self._seed, self._seed2 = random_seed.get_seed(seed)
5924    self._name = name
5925    variant_tensor = ged_ops.random_dataset(
5926        seed=self._seed, seed2=self._seed2, **self._common_args)
5927    super(RandomDataset, self).__init__(variant_tensor)
5928
5929  @property
5930  def element_spec(self):
5931    return tensor_spec.TensorSpec([], dtypes.int64)
5932
5933
5934def _get_prob_original_static(initial_dist_t, target_dist_t):
5935  """Returns the static probability of sampling from the original.
5936
5937  `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters
5938  an Op that it isn't defined for. We have some custom logic to avoid this.
5939
5940  Args:
5941    initial_dist_t: A tensor of the initial distribution.
5942    target_dist_t: A tensor of the target distribution.
5943
5944  Returns:
5945    The probability of sampling from the original distribution as a constant,
5946    if it is a constant, or `None`.
5947  """
5948  init_static = tensor_util.constant_value(initial_dist_t)
5949  target_static = tensor_util.constant_value(target_dist_t)
5950
5951  if init_static is None or target_static is None:
5952    return None
5953  else:
5954    return np.min(target_static / init_static)
5955
5956
5957def _filter_ds(dataset,
5958               acceptance_dist_ds,
5959               initial_dist_ds,
5960               class_func,
5961               seed,
5962               name=None):
5963  """Filters a dataset based on per-class acceptance probabilities.
5964
5965  Args:
5966    dataset: The dataset to be filtered.
5967    acceptance_dist_ds: A dataset of acceptance probabilities.
5968    initial_dist_ds: A dataset of the initial probability distribution, given or
5969      estimated.
5970    class_func: A function mapping an element of the input dataset to a scalar
5971      `tf.int32` tensor. Values should be in `[0, num_classes)`.
5972    seed: (Optional.) Python integer seed for the resampler.
5973    name: (Optional.) A name for the tf.data operation.
5974
5975  Returns:
5976    A dataset of (class value, data) after filtering.
5977  """
5978
5979  def maybe_warn_on_large_rejection(accept_dist, initial_dist):
5980    proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist)
5981    return control_flow_ops.cond(
5982        math_ops.less(proportion_rejected, .5),
5983        lambda: accept_dist,
5984        lambda: logging_ops.Print(  # pylint: disable=g-long-lambda
5985            accept_dist, [proportion_rejected, initial_dist, accept_dist],
5986            message="Proportion of examples rejected by sampler is high: ",
5987            summarize=100,
5988            first_n=10))
5989
5990  acceptance_dist_ds = (
5991      DatasetV2.zip((acceptance_dist_ds, initial_dist_ds),
5992                    name=name).map(maybe_warn_on_large_rejection, name=name))
5993
5994  def _gather_and_copy(acceptance_prob, data):
5995    if isinstance(data, tuple):
5996      class_val = class_func(*data)
5997    else:
5998      class_val = class_func(data)
5999    return class_val, array_ops.gather(acceptance_prob, class_val), data
6000
6001  current_probabilities_and_class_and_data_ds = DatasetV2.zip(
6002      (acceptance_dist_ds, dataset), name=name).map(
6003          _gather_and_copy, name=name)
6004
6005  def _reject(unused_class_val, p, unused_data):
6006    return random_ops.random_uniform([], seed=seed, dtype=p.dtype) < p
6007
6008  filtered_ds = current_probabilities_and_class_and_data_ds.filter(
6009      _reject, name=name)
6010  return filtered_ds.map(
6011      lambda class_value, _, data: (class_value, data), name=name)
6012
6013
6014# pylint: disable=missing-function-docstring
6015def _estimate_initial_dist_ds(target_dist_t,
6016                              class_values_ds,
6017                              dist_estimation_batch_size=32,
6018                              smoothing_constant=10,
6019                              name=None):
6020  num_classes = (target_dist_t.shape[0] or array_ops.shape(target_dist_t)[0])
6021  initial_examples_per_class_seen = array_ops.fill([num_classes],
6022                                                   np.int64(smoothing_constant))
6023
6024  def update_estimate_and_tile(num_examples_per_class_seen, c):
6025    updated_examples_per_class_seen, dist = _estimate_data_distribution(
6026        c, num_examples_per_class_seen)
6027    tiled_dist = array_ops.tile(
6028        array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1])
6029    return updated_examples_per_class_seen, tiled_dist
6030
6031  initial_dist_ds = (
6032      class_values_ds.batch(dist_estimation_batch_size, name=name).scan(
6033          initial_examples_per_class_seen, update_estimate_and_tile,
6034          name=name).unbatch(name=name))
6035
6036  return initial_dist_ds
6037
6038
6039def _get_target_to_initial_ratio(initial_probs, target_probs):
6040  # Add tiny to initial_probs to avoid divide by zero.
6041  denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny)
6042  return target_probs / denom
6043
6044
6045def _estimate_data_distribution(c, num_examples_per_class_seen):
6046  """Estimate data distribution as labels are seen.
6047
6048  Args:
6049    c: The class labels.  Type `int32`, shape `[batch_size]`.
6050    num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, containing
6051      counts.
6052
6053  Returns:
6054    num_examples_per_lass_seen: Updated counts.  Type `int64`, shape
6055      `[num_classes]`.
6056    dist: The updated distribution.  Type `float32`, shape `[num_classes]`.
6057  """
6058  num_classes = num_examples_per_class_seen.get_shape()[0]
6059  # Update the class-count based on what labels are seen in batch.
6060  num_examples_per_class_seen = math_ops.add(
6061      num_examples_per_class_seen,
6062      math_ops.reduce_sum(
6063          array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
6064  init_prob_estimate = math_ops.truediv(
6065      num_examples_per_class_seen,
6066      math_ops.reduce_sum(num_examples_per_class_seen))
6067  dist = math_ops.cast(init_prob_estimate, dtypes.float32)
6068  return num_examples_per_class_seen, dist
6069
6070
6071def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
6072  """Calculates the acceptance probabilities and mixing ratio.
6073
6074  In this case, we assume that we can *either* sample from the original data
6075  distribution with probability `m`, or sample from a reshaped distribution
6076  that comes from rejection sampling on the original distribution. This
6077  rejection sampling is done on a per-class basis, with `a_i` representing the
6078  probability of accepting data from class `i`.
6079
6080  This method is based on solving the following analysis for the reshaped
6081  distribution:
6082
6083  Let F be the probability of a rejection (on any example).
6084  Let p_i be the proportion of examples in the data in class i (init_probs)
6085  Let a_i is the rate the rejection sampler should *accept* class i
6086  Let t_i is the target proportion in the minibatches for class i (target_probs)
6087
6088  ```
6089  F = sum_i(p_i * (1-a_i))
6090    = 1 - sum_i(p_i * a_i)     using sum_i(p_i) = 1
6091  ```
6092
6093  An example with class `i` will be accepted if `k` rejections occur, then an
6094  example with class `i` is seen by the rejector, and it is accepted. This can
6095  be written as follows:
6096
6097  ```
6098  t_i = sum_k=0^inf(F^k * p_i * a_i)
6099      = p_i * a_j / (1 - F)    using geometric series identity, since 0 <= F < 1
6100      = p_i * a_i / sum_j(p_j * a_j)        using F from above
6101  ```
6102
6103  Note that the following constraints hold:
6104  ```
6105  0 <= p_i <= 1, sum_i(p_i) = 1
6106  0 <= a_i <= 1
6107  0 <= t_i <= 1, sum_i(t_i) = 1
6108  ```
6109
6110  A solution for a_i in terms of the other variables is the following:
6111    ```a_i = (t_i / p_i) / max_i[t_i / p_i]```
6112
6113  If we try to minimize the amount of data rejected, we get the following:
6114
6115  M_max = max_i [ t_i / p_i ]
6116  M_min = min_i [ t_i / p_i ]
6117
6118  The desired probability of accepting data if it comes from class `i`:
6119
6120  a_i = (t_i/p_i - m) / (M_max - m)
6121
6122  The desired probability of pulling a data element from the original dataset,
6123  rather than the filtered one:
6124
6125  m = M_min
6126
6127  Args:
6128    initial_probs: A Tensor of the initial probability distribution, given or
6129      estimated.
6130    target_probs: A Tensor of the corresponding classes.
6131
6132  Returns:
6133    (A 1D Tensor with the per-class acceptance probabilities, the desired
6134    probability of pull from the original distribution.)
6135  """
6136  ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs)
6137  max_ratio = math_ops.reduce_max(ratio_l)
6138  min_ratio = math_ops.reduce_min(ratio_l)
6139
6140  # Target prob to sample from original distribution.
6141  m = min_ratio
6142
6143  # TODO(joelshor): Simplify fraction, if possible.
6144  a_i = (ratio_l - m) / (max_ratio - m)
6145  return a_i, m
6146
6147
6148class _TakeWhileDataset(UnaryUnchangedStructureDataset):
6149  """A dataset that stops iteration when `predicate` returns false."""
6150
6151  def __init__(self, input_dataset, predicate, name=None):
6152    """See `take_while()` for details."""
6153
6154    self._input_dataset = input_dataset
6155    wrapped_func = structured_function.StructuredFunctionWrapper(
6156        predicate, self._transformation_name(), dataset=self._input_dataset)
6157
6158    if not wrapped_func.output_structure.is_compatible_with(
6159        tensor_spec.TensorSpec([], dtypes.bool)):
6160      raise ValueError(f"Invalid `predicate`. `predicate` must return a "
6161                       f"`tf.bool` scalar tensor but its return type is"
6162                       f"{wrapped_func.output_structure}.")
6163
6164    self._predicate = wrapped_func
6165    self._name = name
6166    variant_tensor = ged_ops.take_while_dataset(
6167        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
6168        other_arguments=self._predicate.function.captured_inputs,
6169        predicate=self._predicate.function,
6170        **self._common_args)
6171    super(_TakeWhileDataset, self).__init__(input_dataset, variant_tensor)
6172
6173  def _functions(self):
6174    return [self._predicate]
6175
6176  def _transformation_name(self):
6177    return "Dataset.take_while()"
6178
6179
6180class _UniqueDataset(UnaryUnchangedStructureDataset):
6181  """A `Dataset` contains the unique elements from its input."""
6182
6183  def __init__(self, input_dataset, name=None):
6184    """See `unique()` for details."""
6185    self._input_dataset = input_dataset
6186    for ty in nest.flatten(get_legacy_output_types(input_dataset)):
6187      if ty not in (dtypes.int32, dtypes.int64, dtypes.string):
6188        raise TypeError(
6189            f"`unique()` does not support type {ty}, only `tf.int32`, "
6190            f"`tf.int64`, and `tf.string` are supported.")
6191    self._name = name
6192    variant_tensor = ged_ops.unique_dataset(
6193        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
6194        **self._common_args)
6195    super(_UniqueDataset, self).__init__(input_dataset, variant_tensor)
6196
6197
6198class _SnapshotDataset(UnaryUnchangedStructureDataset):
6199  """A dataset that allows saving and re-use of already processed data."""
6200
6201  def __init__(self,
6202               input_dataset,
6203               path,
6204               shard_func,
6205               compression=None,
6206               reader_func=None,
6207               pending_snapshot_expiry_seconds=None,
6208               use_legacy_function=False,
6209               name=None):
6210
6211    if reader_func is None:
6212      reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
6213          lambda x: x,
6214          cycle_length=multiprocessing.cpu_count(),
6215          num_parallel_calls=AUTOTUNE)
6216
6217    self._input_dataset = input_dataset
6218    self._path = path
6219    self._compression = compression
6220
6221    self._reader_func = structured_function.StructuredFunctionWrapper(
6222        reader_func,
6223        self._transformation_name() + ".reader_func",
6224        # Dataset of datasets of input elements
6225        input_structure=DatasetSpec(DatasetSpec(input_dataset.element_spec)),
6226        use_legacy_function=use_legacy_function)
6227    self._shard_func = structured_function.StructuredFunctionWrapper(
6228        shard_func,
6229        self._transformation_name() + ".shard_func",
6230        dataset=input_dataset,
6231        use_legacy_function=use_legacy_function)
6232
6233    if ((not self._shard_func.output_structure.is_compatible_with(
6234        tensor_spec.TensorSpec([], dtypes.int32))) and
6235        (not self._shard_func.output_structure.is_compatible_with(
6236            tensor_spec.TensorSpec([], dtypes.int64)))):
6237      raise TypeError(f"Invalid `shard_func`. `shard_func` must return "
6238                      f"`tf.int64` scalar tensor but its return type is "
6239                      f"{self._shard_func.output_structure}.")
6240
6241    self._name = name
6242    variant_tensor = ged_ops.snapshot_dataset_v2(
6243        input_dataset._variant_tensor,  # pylint: disable=protected-access
6244        path,
6245        self._reader_func.function.captured_inputs,
6246        self._shard_func.function.captured_inputs,
6247        compression=compression,
6248        reader_func=self._reader_func.function,
6249        shard_func=self._shard_func.function,
6250        **self._common_args)
6251    super(_SnapshotDataset, self).__init__(input_dataset, variant_tensor)
6252
6253  def _functions(self):
6254    return [self._reader_func, self._shard_func]
6255
6256  def _transformation_name(self):
6257    return "Dataset.snapshot()"
6258
6259
6260class _ScanDataset(UnaryDataset):
6261  """A dataset that scans a function across its input."""
6262
6263  def __init__(self,
6264               input_dataset,
6265               initial_state,
6266               scan_func,
6267               use_default_device=None,
6268               name=None):
6269    """See `scan()` for details."""
6270    self._input_dataset = input_dataset
6271    self._initial_state = structure.normalize_element(initial_state)
6272
6273    # Compute initial values for the state classes, shapes and types based on
6274    # the initial state. The shapes may be refined by running `tf_scan_func` one
6275    # or more times below.
6276    self._state_structure = structure.type_spec_from_value(self._initial_state)
6277
6278    # Iteratively rerun the scan function until reaching a fixed point on
6279    # `self._state_shapes`.
6280    need_to_rerun = True
6281    while need_to_rerun:
6282
6283      wrapped_func = structured_function.StructuredFunctionWrapper(
6284          scan_func,
6285          self._transformation_name(),
6286          input_structure=(self._state_structure, input_dataset.element_spec),
6287          add_to_graph=False)
6288      if not (isinstance(wrapped_func.output_types, collections_abc.Sequence)
6289              and len(wrapped_func.output_types) == 2):
6290        raise TypeError(f"Invalid `scan_func`. `scan_func` should return a "
6291                        f"pair consisting of new state and the output value "
6292                        f"but its return type is "
6293                        f"{wrapped_func.output_structure}.")
6294
6295      new_state_classes, self._output_classes = wrapped_func.output_classes
6296
6297      # Extract and validate class information from the returned values.
6298      new_state_classes, output_classes = wrapped_func.output_classes
6299      old_state_classes = nest.map_structure(
6300          lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
6301          self._state_structure)
6302      for new_state_class, old_state_class in zip(
6303          nest.flatten(new_state_classes), nest.flatten(old_state_classes)):
6304        if not issubclass(new_state_class, old_state_class):
6305          raise TypeError(f"Invalid `scan_func`. The element classes for the "
6306                          f"new state must match the initial state. Expected "
6307                          f"{old_state_classes}, got {new_state_classes}.")
6308
6309      # Extract and validate type information from the returned values.
6310      new_state_types, output_types = wrapped_func.output_types
6311      old_state_types = nest.map_structure(
6312          lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
6313          self._state_structure)
6314      for new_state_type, old_state_type in zip(
6315          nest.flatten(new_state_types), nest.flatten(old_state_types)):
6316        if new_state_type != old_state_type:
6317          raise TypeError(f"Invalid `scan_func`. The element types for the "
6318                          f"new state must match the initial state. Expected "
6319                          f"{old_state_types}, got {new_state_types}.")
6320
6321      # Extract shape information from the returned values.
6322      new_state_shapes, output_shapes = wrapped_func.output_shapes
6323      old_state_shapes = nest.map_structure(
6324          lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
6325          self._state_structure)
6326      self._element_spec = structure.convert_legacy_structure(
6327          output_types, output_shapes, output_classes)
6328
6329      flat_state_shapes = nest.flatten(old_state_shapes)
6330      flat_new_state_shapes = nest.flatten(new_state_shapes)
6331      weakened_state_shapes = [
6332          original.most_specific_compatible_shape(new)
6333          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
6334      ]
6335
6336      need_to_rerun = False
6337      for original_shape, weakened_shape in zip(flat_state_shapes,
6338                                                weakened_state_shapes):
6339        if original_shape.ndims is not None and (
6340            weakened_shape.ndims is None or
6341            original_shape.as_list() != weakened_shape.as_list()):
6342          need_to_rerun = True
6343          break
6344
6345      if need_to_rerun:
6346        # TODO(b/110122868): Support a "most specific compatible structure"
6347        # method for combining structures, to avoid using legacy structures
6348        # in this method.
6349        self._state_structure = structure.convert_legacy_structure(
6350            old_state_types,
6351            nest.pack_sequence_as(old_state_shapes, weakened_state_shapes),
6352            old_state_classes)
6353
6354    self._scan_func = wrapped_func
6355    self._scan_func.function.add_to_graph(ops.get_default_graph())
6356
6357    self._name = name
6358    # pylint: disable=protected-access
6359    if use_default_device is not None:
6360      variant_tensor = ged_ops.scan_dataset(
6361          self._input_dataset._variant_tensor,
6362          structure.to_tensor_list(self._state_structure, self._initial_state),
6363          self._scan_func.function.captured_inputs,
6364          f=self._scan_func.function,
6365          preserve_cardinality=True,
6366          use_default_device=use_default_device,
6367          **self._common_args)
6368    else:
6369      variant_tensor = ged_ops.scan_dataset(
6370          self._input_dataset._variant_tensor,
6371          structure.to_tensor_list(self._state_structure, self._initial_state),
6372          self._scan_func.function.captured_inputs,
6373          f=self._scan_func.function,
6374          preserve_cardinality=True,
6375          **self._common_args)
6376    super(_ScanDataset, self).__init__(input_dataset, variant_tensor)
6377
6378  def _functions(self):
6379    return [self._scan_func]
6380
6381  @property
6382  def element_spec(self):
6383    return self._element_spec
6384
6385  def _transformation_name(self):
6386    return "Dataset.scan()"
6387
6388
6389class _DirectedInterleaveDataset(DatasetV2):
6390  """A substitute for `Dataset.interleave()` on a fixed list of datasets."""
6391
6392  def __init__(self, selector_input, data_inputs, stop_on_empty_dataset=False):
6393    self._selector_input = selector_input
6394    self._data_inputs = list(data_inputs)
6395    self._stop_on_empty_dataset = stop_on_empty_dataset
6396
6397    spec = self._data_inputs[0].element_spec
6398    for i, data_input in enumerate(self._data_inputs[1:]):
6399      def common_supertype(a, b):
6400        result = a.most_specific_common_supertype([b])
6401        if result is None:
6402          raise TypeError(f"No common supertype of {a} and {b}.")
6403        return result
6404
6405      try:
6406        spec = nest.map_structure(common_supertype, spec,
6407                                  data_input.element_spec)
6408      except (TypeError, ValueError) as e:
6409        raise TypeError(f"Invalid `datasets`. `datasets` must have compatible "
6410                        f"element specs.\n Dataset 0 "
6411                        f"element_spec={data_inputs[0].element_spec}.\n"
6412                        f"Dataset {i+1} "
6413                        f"element_spec={data_input.element_spec}.") from e
6414    self._element_spec = spec
6415
6416    # pylint: disable=protected-access
6417    variant_tensor = (
6418        ged_ops.directed_interleave_dataset(
6419            self._selector_input._variant_tensor,
6420            [data_input._variant_tensor for data_input in self._data_inputs],
6421            stop_on_empty_dataset=self._stop_on_empty_dataset,
6422            **self._flat_structure))
6423
6424    super(_DirectedInterleaveDataset, self).__init__(variant_tensor)
6425
6426  def _inputs(self):
6427    return [self._selector_input] + self._data_inputs
6428
6429  @property
6430  def element_spec(self):
6431    return self._element_spec
6432
6433
6434def _apply_rewrite(dataset, rewrite):
6435  # pylint: disable=protected-access
6436  return _VariantDataset(
6437      gen_dataset_ops.rewrite_dataset(dataset._variant_tensor, rewrite,
6438                                      **dataset._flat_structure),
6439      dataset.element_spec)
6440
6441
6442def _collect_resource_inputs(op):
6443  """Collects resource inputs for the given ops (and its variant inputs)."""
6444
6445  def _process(op_queue, seen_ops):
6446    """Processes the next element of the op queue.
6447
6448    Args:
6449      op_queue: Queue of Dataset operations to process.
6450      seen_ops: Already processed set of Operations.
6451
6452    Returns:
6453      A 2-tuple containing sets of resource handles. The first tuple entry
6454      contains read-only handles and the second entry contains read-write
6455      handles.
6456    """
6457
6458    reads = []
6459    writes = []
6460    op = op_queue.pop()
6461    if op in seen_ops:
6462      return reads, writes
6463    seen_ops.add(op)
6464    # TODO(b/150139257): All resource inputs are in writes right now since we
6465    # have not updated the functional ops to set the special attribute that ACD
6466    # uses to figure out which of the op's inputs are read-only.
6467    reads, writes = acd_utils.get_read_write_resource_inputs(op)
6468    # Conservatively assume that any variant inputs are datasets.
6469    op_queue.extend(t.op for t in op.inputs if t.dtype == dtypes.variant)
6470    return reads, writes
6471
6472  op_queue = [op]
6473  seen_ops = set()
6474  all_reads = []
6475  all_writes = []
6476  while op_queue:
6477    reads, writes = _process(op_queue, seen_ops)
6478    all_reads.extend(reads)
6479    all_writes.extend(writes)
6480
6481  return all_reads, all_writes
6482
6483
6484@auto_control_deps.register_acd_resource_resolver
6485def _resource_resolver(op, resource_reads, resource_writes):
6486  """Updates resource inputs for tf.data ops with indirect dependencies."""
6487
6488  updated = False
6489  if op.type in [
6490      "DatasetToSingleElement", "DatasetToTFRecord", "ReduceDataset"
6491  ]:
6492    reads, writes = _collect_resource_inputs(op)
6493    for inp in reads:
6494      if inp not in resource_reads:
6495        updated = True
6496        resource_reads.add(inp)
6497    for inp in writes:
6498      if inp not in resource_writes:
6499        updated = True
6500        resource_writes.add(inp)
6501
6502  if op.type in [
6503      "IteratorGetNext", "IteratorGetNextSync", "IteratorGetNextAsOptional"
6504  ]:
6505    iterator_resource = op.inputs[0]
6506    make_iterator_ops = [
6507        op for op in iterator_resource.consumers() if op.type == "MakeIterator"
6508    ]
6509
6510    if len(make_iterator_ops) == 1:
6511      reads, writes = _collect_resource_inputs(make_iterator_ops[0])
6512      for inp in reads:
6513        if inp not in resource_reads:
6514          updated = True
6515          resource_reads.add(inp)
6516      for inp in writes:
6517        if inp not in resource_writes:
6518          updated = True
6519          resource_writes.add(inp)
6520
6521  return updated
6522
6523
6524DEBUG_MODE = False
6525
6526
6527@tf_export("data.experimental.enable_debug_mode")
6528def enable_debug_mode():
6529  """Enables debug mode for tf.data.
6530
6531  Example usage with pdb module:
6532  ```
6533  import tensorflow as tf
6534  import pdb
6535
6536  tf.data.experimental.enable_debug_mode()
6537
6538  def func(x):
6539    # Python 3.7 and older requires `pdb.Pdb(nosigint=True).set_trace()`
6540    pdb.set_trace()
6541    x = x + 1
6542    return x
6543
6544  dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
6545  dataset = dataset.map(func)
6546
6547  for item in dataset:
6548    print(item)
6549  ```
6550
6551  The effect of debug mode is two-fold:
6552
6553  1) Any transformations that would introduce asynchrony, parallelism, or
6554  non-determinism to the input pipeline execution will be forced to execute
6555  synchronously, sequentially, and deterministically.
6556
6557  2) Any user-defined functions passed into tf.data transformations such as
6558  `map` will be wrapped in `tf.py_function` so that their body is executed
6559  "eagerly" as a Python function as opposed to a traced TensorFlow graph, which
6560  is the default behavior. Note that even when debug mode is enabled, the
6561  user-defined function is still traced  to infer the shape and type of its
6562  outputs; as a consequence, any `print` statements or breakpoints will be
6563  triggered once during the tracing before the actual execution of the input
6564  pipeline.
6565
6566  NOTE: As the debug mode setting affects the construction of the tf.data input
6567  pipeline, it should be enabled before any tf.data definitions.
6568
6569  Raises:
6570    ValueError: When invoked from graph mode.
6571  """
6572  if context.executing_eagerly():
6573    toggle_debug_mode(True)
6574  else:
6575    raise ValueError("`enable_debug_mode() is only supported in eager mode.")
6576
6577
6578def toggle_debug_mode(debug_mode):
6579  global DEBUG_MODE
6580  DEBUG_MODE = debug_mode
6581