# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for Datasets."""
import abc
import functools
import multiprocessing
import queue
import threading
import warnings

import numpy as np

from tensorflow.core.framework import dataset_metadata_pb2
from tensorflow.core.framework import dataset_options_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python import tf2
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.ops import structured_function
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import random_seed
from tensorflow.python.data.util import structure
from tensorflow.python.data.util import traverse
from tensorflow.python.eager import context
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import auto_control_deps_utils as acd_utils
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed as core_random_seed
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import gen_stateless_random_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.trackable import asset
from tensorflow.python.trackable import base as tracking_base
from tensorflow.python.trackable import resource as resource_lib
from tensorflow.python.types import trace
from tensorflow.python.util import deprecation
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest as tf_nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export

# Symbols forwarded for legacy access through dataset_ops.py. These forwarded
# symbols can be removed once all internal uses are updated.
StructuredFunctionWrapper = structured_function.StructuredFunctionWrapper

# Loaded lazily due to a circular dependency (roughly
# tf.function->wrap_function->dataset->autograph->tf.function).
# TODO(b/133251390): Use a regular import.
wrap_function = lazy_loader.LazyLoader(
    "wrap_function", globals(),
    "tensorflow.python.eager.wrap_function")
# Loaded lazily due to a circular dependency
# dataset_ops->def_function->func_graph->autograph->dataset_ops
# TODO(kathywu): Use a regular import.
def_function = lazy_loader.LazyLoader(
    "def_function", globals(),
    "tensorflow.python.eager.def_function")
# Loaded lazily due to a circular dependency
# dataset_ops->parsing_ops->dataset_ops
# TODO(varshaan): Use a regular import.
parsing_ops = lazy_loader.LazyLoader(
    "parsing_ops", globals(),
    "tensorflow.python.ops.parsing_ops")


ops.NotDifferentiable("ReduceDataset")

# A constant that can be used to enable auto-tuning.
AUTOTUNE = -1
tf_export("data.AUTOTUNE").export_constant(__name__, "AUTOTUNE")
# TODO(b/168128531): Deprecate and remove this symbol.
tf_export("data.experimental.AUTOTUNE").export_constant(__name__, "AUTOTUNE")

# Constants representing infinite and unknown cardinalities.
INFINITE = -1
UNKNOWN = -2
COMPRESSION_GZIP = "GZIP"
COMPRESSION_SNAPPY = "NONE"
DATASET_SPEC_FILENAME = "dataset_spec.pb"
tf_export("data.INFINITE_CARDINALITY").export_constant(__name__, "INFINITE")
tf_export("data.UNKNOWN_CARDINALITY").export_constant(__name__, "UNKNOWN")


def _validate_and_encode(name):
  if not name.isidentifier():
    raise ValueError("Invalid `name`. The argument `name` needs to be a valid "
                     "identifier. Value is considered a valid identifier if it "
                     "only contains alphanumeric characters (a-z), (A-Z), and "
                     "(0-9), or underscores (_). A valid identifier cannot "
                     "start with a number, or contain any spaces.")
  return name.encode("utf-8")


def _get_type(value):
  """Returns the type of `value` if it is a TypeSpec."""

  if isinstance(value, type_spec.TypeSpec):
    return value.value_type()
  else:
    return type(value)


@tf_export("data.Dataset", v1=[])
class DatasetV2(
    collections_abc.Iterable,
    tracking_base.Trackable,
    composite_tensor.CompositeTensor,
    metaclass=abc.ABCMeta):
  """Represents a potentially large set of elements.

  The `tf.data.Dataset` API supports writing descriptive and efficient input
  pipelines. `Dataset` usage follows a common pattern:

  1. Create a source dataset from your input data.
  2. Apply dataset transformations to preprocess the data.
  3. Iterate over the dataset and process the elements.

  Iteration happens in a streaming fashion, so the full dataset does not need
  to fit into memory.

  Source Datasets:

  The simplest way to create a dataset is to create it from a python `list`:

  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> for element in dataset:
  ...   print(element)
  tf.Tensor(1, shape=(), dtype=int32)
  tf.Tensor(2, shape=(), dtype=int32)
  tf.Tensor(3, shape=(), dtype=int32)

  To process lines from files, use `tf.data.TextLineDataset`:

  >>> dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])

  To process records written in the `TFRecord` format, use `TFRecordDataset`:

  >>> dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"])

  To create a dataset of all files matching a pattern, use
  `tf.data.Dataset.list_files`:

  ```python
  dataset = tf.data.Dataset.list_files("/path/*.txt")
  ```

  See `tf.data.FixedLengthRecordDataset` and `tf.data.Dataset.from_generator`
  for more ways to create datasets.

  Transformations:

  Once you have a dataset, you can apply transformations to prepare the data
  for your model:

  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> dataset = dataset.map(lambda x: x*2)
  >>> list(dataset.as_numpy_iterator())
  [2, 4, 6]

  Common Terms:

  **Element**: A single output from calling `next()` on a dataset iterator.
  Elements may be nested structures containing multiple components. For
  example, the element `(1, (3, "apple"))` has one tuple nested in another
  tuple. The components are `1`, `3`, and `"apple"`.

  **Component**: The leaf in the nested structure of an element.

  Supported types:

  Elements can be nested structures of tuples, named tuples, and dictionaries.
  Note that Python lists are *not* treated as nested structures of components.
  Instead, lists are converted to tensors and treated as components. For
  example, the element `(1, [1, 2, 3])` has only two components; the tensor `1`
  and the tensor `[1, 2, 3]`. Element components can be of any type
  representable by `tf.TypeSpec`, including `tf.Tensor`, `tf.data.Dataset`,
  `tf.sparse.SparseTensor`, `tf.RaggedTensor`, and `tf.TensorArray`.

  ```python
  a = 1 # Integer element
  b = 2.0 # Float element
  c = (1, 2) # Tuple element with 2 components
  d = {"a": (2, 2), "b": 3} # Dict element with 3 components
  Point = collections.namedtuple("Point", ["x", "y"])
  e = Point(1, 2) # Named tuple
  f = tf.data.Dataset.range(10) # Dataset element
  ```

  For more information,
  read [this guide](https://www.tensorflow.org/guide/data).
  """

  def __init__(self, variant_tensor):
    """Creates a DatasetV2 object.

    This is a difference between DatasetV1 and DatasetV2. DatasetV1 does not
    take anything in its constructor whereas in the DatasetV2, we expect
    subclasses to create a variant_tensor and pass it in to the super() call.

    Args:
      variant_tensor: A DT_VARIANT tensor that represents the dataset.
    """
    self._variant_tensor_attr = variant_tensor
    self._graph_attr = ops.get_default_graph()

    # Initialize the options for this dataset and its inputs.
    self._options_attr = options_lib.Options()
    for input_dataset in self._inputs():
      input_options = None
      if isinstance(input_dataset, DatasetV1):
        # If the V1 dataset does not have the `_dataset` attribute, we assume
        # it is a dataset source and hence does not have options. Otherwise,
        # we grab the options of the `_dataset` object.
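        # V1 wrappers such as `DatasetV1Adapter` keep the wrapped V2 dataset in
        # `_dataset`, which is where the merged options are stored.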
        if hasattr(input_dataset, "_dataset"):
          if not isinstance(input_dataset._dataset, DatasetV2):
            raise TypeError(
                f"Each input of dataset {type(self)} should be a subclass of "
                f"`tf.data.Dataset` but encountered "
                f"{type(input_dataset._dataset)}.")
          input_options = input_dataset._dataset._options_attr
      elif isinstance(input_dataset, DatasetV2):
        input_options = input_dataset._options_attr
      else:
        raise TypeError(
            f"Each input of dataset {type(self)} should be a subclass of "
            f"`tf.data.Dataset` but encountered {type(input_dataset)}.")
      if input_options is not None:
        self._options_attr = self._options_attr.merge(input_options)
    self._options_attr._set_mutable(False)  # pylint: disable=protected-access

  @property
  def _variant_tensor(self):
    return self._variant_tensor_attr

  @_variant_tensor.setter
  def _variant_tensor(self, _):
    raise ValueError("The `_variant_tensor` property cannot be modified.")

  @deprecation.deprecated_args(None, "Use external_state_policy instead",
                               "allow_stateful")
  def _as_serialized_graph(
      self,
      allow_stateful=None,
      strip_device_assignment=None,
      external_state_policy=options_lib.ExternalStatePolicy.WARN):
    """Produces serialized graph representation of the dataset.

    Args:
      allow_stateful: If true, we allow stateful ops to be present in the graph
        def. In that case, the state in these ops would be thrown away.
      strip_device_assignment: If true, non-local (i.e. job and task) device
        assignment is stripped from ops in the serialized graph.
      external_state_policy: The ExternalStatePolicy enum that determines how
        we handle input pipelines that depend on external state. By default,
        it is set to WARN.

    Returns:
      A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
      serialized graph.
    """
    if external_state_policy:
      policy = external_state_policy.value
      return gen_dataset_ops.dataset_to_graph_v2(
          self._variant_tensor,
          external_state_policy=policy,
          strip_device_assignment=strip_device_assignment)
    if strip_device_assignment:
      return gen_dataset_ops.dataset_to_graph(
          self._variant_tensor,
          allow_stateful=allow_stateful,
          strip_device_assignment=strip_device_assignment)
    return gen_dataset_ops.dataset_to_graph(
        self._variant_tensor, allow_stateful=allow_stateful)

  def _maybe_track_assets(self, graph_def):
    """Finds and tracks nodes in `graph_def` that refer to asset files.

    Args:
      graph_def: Serialized graph representation of this dataset.

    Returns:
      A dictionary mapping the node name of an asset constant to a tracked
      `asset.Asset` object.
    """
    asset_tracker = {}
    for node in graph_def.node:
      if node.name.startswith("FileIdentity"):
        asset_tracker[node.input[0]] = None

    if not asset_tracker:
      return {}

    for node in graph_def.node:
      if node.name in asset_tracker:
        tensor_proto = node.attr["value"].tensor
        with context.eager_mode(), ops.device("CPU"):
          node_value = parsing_ops.parse_tensor(
              tensor_proto.SerializeToString(), dtypes.string).numpy()
        asset_tracker[node.name] = ([
            self._track_trackable(asset.Asset(n),
                                  name=node.name + "_" + str(i), overwrite=True)
            for i, n in enumerate(node_value)
        ])
    return asset_tracker

  def _trackable_children(self,
                          save_type=tracking_base.SaveType.CHECKPOINT,
                          **kwargs):
    if save_type != tracking_base.SaveType.SAVEDMODEL:
      return {}

    # _trace_variant_creation only works when executing eagerly, so we don't
    # want to run it in the object initialization.
    @def_function.function(input_signature=[], autograph=False)
    def _creator():
      resource = self._trace_variant_creation()()  # pylint: disable=protected-access
      return resource
    _creator.get_concrete_function()  # Trigger asset tracking

    children = super(DatasetV2, self)._trackable_children(save_type, **kwargs)
    children["_variant_tracker"] = _VariantTracker(self._variant_tensor,
                                                   _creator)
    return children

  def _trace_variant_creation(self):
    """Traces a function which outputs a variant `tf.Tensor` for this dataset.

    Note that creating this function involves evaluating an op, and is
    currently only supported when executing eagerly.

    Returns:
      A zero-argument `ConcreteFunction` which outputs a variant `tf.Tensor`.
    """
    variant = self._variant_tensor
    if not isinstance(variant, ops.EagerTensor):
      raise NotImplementedError(
          "Constructing a tf.function that reproduces a given dataset is only "
          "supported for datasets created eagerly. Please file a feature "
          "request if this is important to you.")
    with context.eager_mode(), ops.device("CPU"):
      # pylint: disable=protected-access
      graph_def = graph_pb2.GraphDef().FromString(
          self._as_serialized_graph(external_state_policy=options_lib
                                    .ExternalStatePolicy.FAIL).numpy())
    output_node_names = []
    for node in graph_def.node:
      if node.op == "_Retval":
        output_node_names = node.input

    if len(output_node_names) != 1:
      raise AssertionError(
          f"Dataset graph is expected to only have one return value but found "
          f"{len(output_node_names)} return values: {output_node_names}.")

    output_node_name = output_node_names[0]

    file_path_nodes = {}
    # When building a tf.function, track files as `saved_model.Asset`s.
    if ops.get_default_graph().building_function:
      asset_tracker = self._maybe_track_assets(graph_def)
      for key in asset_tracker:
        assets_list = [
            array_ops.expand_dims(asset.asset_path, axis=0)
            for asset in asset_tracker[key]
        ]
        file_path_nodes[key] = array_ops.concat(assets_list, axis=0)

    # Add functions used in this Dataset to the function's graph, since they
    # need to follow it around (and for example be added to a SavedModel which
    # references the dataset).
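    # `function_from_graph_def` rebuilds the serialized dataset graph as a
    # zero-input ConcreteFunction; the `captures` mapping substitutes the asset
    # path tensors collected above for the corresponding file-path nodes.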
    variant_function = wrap_function.function_from_graph_def(
        graph_def,
        inputs=[],
        outputs=output_node_name + ":0",
        captures=file_path_nodes)
    for used_function in self._functions():
      used_function.function.add_to_graph(variant_function.graph)
    return variant_function

  @abc.abstractmethod
  def _inputs(self):
    """Returns a list of the input datasets of the dataset."""

    raise NotImplementedError(f"{type(self)}._inputs()")

  @property
  def _graph(self):
    return self._graph_attr

  @_graph.setter
  def _graph(self, _):
    raise ValueError("The `_graph` property cannot be modified.")

  # TODO(jsimsa): Change this to be the transitive closure of functions used
  # by this dataset and its inputs.
  def _functions(self):
    """Returns a list of functions associated with this dataset.

    Returns:
      A list of `StructuredFunctionWrapper` objects.
    """
    return []

  def _options(self):
    """Returns the options tensor for this dataset."""
    # pylint: disable=protected-access
    return gen_dataset_ops.get_options(self._variant_tensor)

  @classmethod
  def _options_tensor_to_options(cls, serialized_options):
    """Converts options tensor to tf.data.Options object."""
    options = options_lib.Options()
    if tensor_util.constant_value(serialized_options) is not None:
      pb = dataset_options_pb2.Options.FromString(tensor_util.constant_value(
          serialized_options))
      options._from_proto(pb)  # pylint: disable=protected-access
    return options

  def options(self):
    """Returns the options for this dataset and its inputs.

    Returns:
      A `tf.data.Options` object representing the dataset options.
    """
    if context.executing_eagerly():
      options = self._options_tensor_to_options(self._options())
      options._set_mutable(False)  # pylint: disable=protected-access
      return options
    warnings.warn("To make it possible to preserve tf.data options across "
                  "serialization boundaries, their implementation has moved to "
                  "be part of the TensorFlow graph. As a consequence, the "
                  "options value is in general no longer known at graph "
                  "construction time. Invoking this method in graph mode "
                  "retains the legacy behavior of the original implementation, "
                  "but note that the returned value might not reflect the "
                  "actual value of the options.")
    return self._options_attr

  def _apply_debug_options(self):
    if DEBUG_MODE:
      # Disable autotuning and static optimizations that could introduce
      # parallelism or asynchrony.
      options = options_lib.Options()
      options.autotune.enabled = False
      options.experimental_optimization.filter_parallelization = False
      options.experimental_optimization.map_and_batch_fusion = False
      options.experimental_optimization.map_parallelization = False
      dataset = _OptionsDataset(self, options)
    else:
      dataset = self

    return dataset

  def __iter__(self):
    """Creates an iterator for elements of this dataset.

    The returned iterator implements the Python Iterator protocol.

    Returns:
      A `tf.data.Iterator` for the elements of this dataset.

    Raises:
      RuntimeError: If not inside of tf.function and not executing eagerly.
    """
    if context.executing_eagerly() or ops.inside_function():
      with ops.colocate_with(self._variant_tensor):
        return iterator_ops.OwnedIterator(self)
    else:
      raise RuntimeError("`tf.data.Dataset` only supports Python-style "
                         "iteration in eager mode or within tf.function.")

  def __bool__(self):
    return True  # Required as __len__ is defined

  __nonzero__ = __bool__  # Python 2 backward compatibility

  def __len__(self):
    """Returns the length of the dataset if it is known and finite.

    This method requires that you are running in eager mode, and that the
    length of the dataset is known and non-infinite. When the length may be
    unknown or infinite, or if you are running in graph mode, use
    `tf.data.Dataset.cardinality` instead.

    Returns:
      An integer representing the length of the dataset.

    Raises:
      TypeError: If the dataset length is unknown or infinite, or if eager
        execution is not enabled.
    """
    if not context.executing_eagerly():
      raise TypeError("`tf.data.Dataset` only supports `len` in eager mode. "
                      "Use `tf.data.Dataset.cardinality()` instead.")
    length = self.cardinality()
    if length.numpy() == INFINITE:
      raise TypeError("The dataset is infinite.")
    if length.numpy() == UNKNOWN:
      raise TypeError("The dataset length is unknown.")
    return length

  @abc.abstractproperty
  def element_spec(self):
    """The type specification of an element of this dataset.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> dataset.element_spec
    TensorSpec(shape=(), dtype=tf.int32, name=None)

    For more information,
    read [this guide](https://www.tensorflow.org/guide/data#dataset_structure).

    Returns:
      A (nested) structure of `tf.TypeSpec` objects matching the structure of
      an element of this dataset and specifying the type of individual
      components.
    """
    raise NotImplementedError(f"{type(self)}.element_spec()")

  def __repr__(self):
    type_ = type(self._dataset if isinstance(self, DatasetV1Adapter) else self)
    return f"<{type_.__name__} element_spec={self.element_spec}>"

  def __debug_string__(self):
    """Returns a string showing the type of the dataset and its inputs.

    This string is intended only for debugging purposes, and may change without
    warning.
    """
    lines = []
    to_process = [(self, 0)]  # Stack of (dataset, depth) pairs.
    while to_process:
      dataset, depth = to_process.pop()
      lines.append("-" * 2 * depth + repr(dataset))
      to_process.extend([(ds, depth + 1) for ds in dataset._inputs()])  # pylint: disable=protected-access
    return "\n".join(lines)

  def as_numpy_iterator(self):
    """Returns an iterator which converts all elements of the dataset to numpy.

    Use `as_numpy_iterator` to inspect the content of your dataset. To see
    element shapes and types, print dataset elements directly instead of using
    `as_numpy_iterator`.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> for element in dataset:
    ...   print(element)
    tf.Tensor(1, shape=(), dtype=int32)
    tf.Tensor(2, shape=(), dtype=int32)
    tf.Tensor(3, shape=(), dtype=int32)

    This method requires that you are running in eager mode and the dataset's
    element_spec contains only `TensorSpec` components.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> for element in dataset.as_numpy_iterator():
    ...   print(element)
    1
    2
    3

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> print(list(dataset.as_numpy_iterator()))
    [1, 2, 3]

    `as_numpy_iterator()` will preserve the nested structure of dataset
    elements.

    >>> dataset = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]),
    ...                                               'b': [5, 6]})
    >>> list(dataset.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5},
    ...                                       {'a': (2, 4), 'b': 6}]
    True

    Returns:
      An iterable over the elements of the dataset, with their tensors
      converted to numpy arrays.

    Raises:
      TypeError: if an element contains a non-`Tensor` value.
      RuntimeError: if eager execution is not enabled.
    """
    if not context.executing_eagerly():
      raise RuntimeError("`tf.data.Dataset.as_numpy_iterator()` is only "
                         "supported in eager mode.")
    for component_spec in nest.flatten(self.element_spec):
      if not isinstance(component_spec,
                        (tensor_spec.TensorSpec, ragged_tensor.RaggedTensorSpec,
                         structure.NoneTensorSpec)):
        raise TypeError(
            f"`tf.data.Dataset.as_numpy_iterator()` is not supported for "
            f"datasets that produce values of type {component_spec.value_type}")

    return _NumpyIterator(self)

  @property
  def _flat_shapes(self):
    """Returns a list of `tf.TensorShape`s for the element tensor representation.

    Returns:
      A list of `tf.TensorShape`s for the element tensor representation.
    """
    return structure.get_flat_tensor_shapes(self.element_spec)

  @property
  def _flat_types(self):
    """Returns a list of `tf.DType`s for the element tensor representation.

    Returns:
      A list of `tf.DType`s for the element tensor representation.
    """
    return structure.get_flat_tensor_types(self.element_spec)

  @property
  def _flat_structure(self):
    """Helper for setting `output_shapes` and `output_types` attrs of an op.

    Most dataset op constructors expect `output_shapes` and `output_types`
    arguments that represent the flattened structure of an element. This helper
    function generates these attrs as a keyword argument dictionary, allowing
    `Dataset._variant_tensor` implementations to pass `**self._flat_structure`
    to the op constructor.

    Returns:
      A dictionary of keyword arguments that can be passed to a dataset op
      constructor.
    """
    return {
        "output_shapes": self._flat_shapes,
        "output_types": self._flat_types,
    }

  @property
  def _metadata(self):
    """Helper for generating dataset metadata."""
    metadata = dataset_metadata_pb2.Metadata()
    if self._name:
      metadata.name = _validate_and_encode(self._name)
    return metadata

  @property
  def _common_args(self):
    """Helper for generating arguments that are common across most dataset ops.

    Most dataset op constructors expect `output_shapes` and `output_types`
    arguments that represent the flattened structure of an element, as well as
    a `metadata` argument for additional metadata such as a user-defined
    dataset name. This helper function generates common attributes as a keyword
    argument dictionary, allowing `Dataset._variant_tensor` implementations to
    pass `**self._common_args` to the op constructor.

    Returns:
      A dictionary of keyword arguments that can be passed to a dataset op
      constructor.
    """
    return {
        "metadata": self._metadata.SerializeToString(),
        "output_shapes": self._flat_shapes,
        "output_types": self._flat_types,
    }

  @property
  def _type_spec(self):
    return DatasetSpec(self.element_spec)

  @staticmethod
  def from_tensors(tensors, name=None):
    """Creates a `Dataset` with a single element, comprising the given tensors.

    `from_tensors` produces a dataset containing only a single element. To
    slice the input tensor into multiple elements, use `from_tensor_slices`
    instead.

    >>> dataset = tf.data.Dataset.from_tensors([1, 2, 3])
    >>> list(dataset.as_numpy_iterator())
    [array([1, 2, 3], dtype=int32)]
    >>> dataset = tf.data.Dataset.from_tensors(([1, 2, 3], 'A'))
    >>> list(dataset.as_numpy_iterator())
    [(array([1, 2, 3], dtype=int32), b'A')]

    >>> # You can use `from_tensors` to produce a dataset which repeats
    >>> # the same example many times.
    >>> example = tf.constant([1,2,3])
    >>> dataset = tf.data.Dataset.from_tensors(example).repeat(2)
    >>> list(dataset.as_numpy_iterator())
    [array([1, 2, 3], dtype=int32), array([1, 2, 3], dtype=int32)]

    Note that if `tensors` contains a NumPy array, and eager execution is not
    enabled, the values will be embedded in the graph as one or more
    `tf.constant` operations. For large datasets (> 1 GB), this can waste
    memory and run into byte limits of graph serialization. If `tensors`
    contains one or more large NumPy arrays, consider the alternative described
    in [this
    guide](https://tensorflow.org/guide/data#consuming_numpy_arrays).

    Args:
      tensors: A dataset "element". Supported values are documented
        [here](https://www.tensorflow.org/guide/data#dataset_structure).
      name: (Optional.) A name for the tf.data operation.

    Returns:
      Dataset: A `Dataset`.
    """
    return TensorDataset(tensors, name=name)

  @staticmethod
  def from_tensor_slices(tensors, name=None):
    """Creates a `Dataset` whose elements are slices of the given tensors.

    The given tensors are sliced along their first dimension. This operation
    preserves the structure of the input tensors, removing the first dimension
    of each tensor and using it as the dataset dimension. All input tensors
    must have the same size in their first dimensions.

    >>> # Slicing a 1D tensor produces scalar tensor elements.
    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> list(dataset.as_numpy_iterator())
    [1, 2, 3]

    >>> # Slicing a 2D tensor produces 1D tensor elements.
    >>> dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
    >>> list(dataset.as_numpy_iterator())
    [array([1, 2], dtype=int32), array([3, 4], dtype=int32)]

    >>> # Slicing a tuple of 1D tensors produces tuple elements containing
    >>> # scalar tensors.
    >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
    >>> list(dataset.as_numpy_iterator())
    [(1, 3, 5), (2, 4, 6)]

    >>> # Dictionary structure is also preserved.
    >>> dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
    >>> list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3},
    ...                                       {'a': 2, 'b': 4}]
    True

    >>> # Two tensors can be combined into one Dataset object.
    >>> features = tf.constant([[1, 3], [2, 1], [3, 3]])  # ==> 3x2 tensor
    >>> labels = tf.constant(['A', 'B', 'A'])  # ==> 3x1 tensor
    >>> dataset = Dataset.from_tensor_slices((features, labels))
    >>> # Both the features and the labels tensors can be converted
    >>> # to a Dataset object separately and combined after.
    >>> features_dataset = Dataset.from_tensor_slices(features)
    >>> labels_dataset = Dataset.from_tensor_slices(labels)
    >>> dataset = Dataset.zip((features_dataset, labels_dataset))
    >>> # A batched feature and label set can be converted to a Dataset
    >>> # in similar fashion.
    >>> batched_features = tf.constant([[[1, 3], [2, 3]],
    ...                                 [[2, 1], [1, 2]],
    ...                                 [[3, 3], [3, 2]]], shape=(3, 2, 2))
    >>> batched_labels = tf.constant([['A', 'A'],
    ...                               ['B', 'B'],
    ...                               ['A', 'B']], shape=(3, 2, 1))
    >>> dataset = Dataset.from_tensor_slices((batched_features, batched_labels))
    >>> for element in dataset.as_numpy_iterator():
    ...   print(element)
    (array([[1, 3],
           [2, 3]], dtype=int32), array([[b'A'],
           [b'A']], dtype=object))
    (array([[2, 1],
           [1, 2]], dtype=int32), array([[b'B'],
           [b'B']], dtype=object))
    (array([[3, 3],
           [3, 2]], dtype=int32), array([[b'A'],
           [b'B']], dtype=object))

    Note that if `tensors` contains a NumPy array, and eager execution is not
    enabled, the values will be embedded in the graph as one or more
    `tf.constant` operations. For large datasets (> 1 GB), this can waste
    memory and run into byte limits of graph serialization. If `tensors`
    contains one or more large NumPy arrays, consider the alternative described
    in [this guide](
    https://tensorflow.org/guide/data#consuming_numpy_arrays).

    Args:
      tensors: A dataset element, whose components have the same first
        dimension. Supported values are documented
        [here](https://www.tensorflow.org/guide/data#dataset_structure).
      name: (Optional.) A name for the tf.data operation.

    Returns:
      Dataset: A `Dataset`.
    """
    return TensorSliceDataset(tensors, name=name)

  class _GeneratorState:
    """Stores outstanding iterators created from a Python generator.

    This class keeps track of potentially multiple iterators that may have
    been created from a generator, e.g. in the case that the dataset is
    repeated, or nested within a parallel computation.
    """

    def __init__(self, generator):
      self._generator = generator
      self._lock = threading.Lock()
      self._next_id = 0  # GUARDED_BY(self._lock)
      self._args = {}
      self._iterators = {}

    def _normalize_id(self, iterator_id):
      # In debug mode, iterator ids may be eagerly-generated np.arrays instead
      # of Tensors. We convert them to scalars to make them hashable.
      if isinstance(iterator_id, np.ndarray):
        return iterator_id.item()
      return iterator_id

    def get_next_id(self, *args):
      with self._lock:
        ret = self._next_id
        self._next_id += 1
      self._args[ret] = args
      # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
      # casting in `py_func()` will create an array of `np.int32` on Windows,
      # leading to a runtime error.
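      # The returned id is threaded through the dataset graph and later handed
      # back to `get_iterator`, which pops the stored `args` for that id and
      # creates the corresponding Python iterator.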
      return np.array(ret, dtype=np.int64)

    def get_iterator(self, iterator_id):
      iterator_id = self._normalize_id(iterator_id)
      try:
        return self._iterators[iterator_id]
      except KeyError:
        iterator = iter(self._generator(*self._args.pop(iterator_id)))
        self._iterators[iterator_id] = iterator
        return iterator

    def iterator_completed(self, iterator_id):
      del self._iterators[self._normalize_id(iterator_id)]

  @staticmethod
  @deprecation.deprecated_args(None, "Use output_signature instead",
                               "output_types", "output_shapes")
  def from_generator(generator,
                     output_types=None,
                     output_shapes=None,
                     args=None,
                     output_signature=None,
                     name=None):
    """Creates a `Dataset` whose elements are generated by `generator`.

    Note: The current implementation of `Dataset.from_generator()` uses
    `tf.numpy_function` and inherits the same constraints. In particular, it
    requires the dataset and iterator related operations to be placed on a
    device in the same process as the Python program that called
    `Dataset.from_generator()`. Notably, using `from_generator` will preclude
    the use of tf.data service for scaling out dataset processing.
    The body of `generator` will not be serialized in a `GraphDef`, and you
    should not use this method if you need to serialize your model and restore
    it in a different environment.

    The `generator` argument must be a callable object that returns
    an object that supports the `iter()` protocol (e.g. a generator function).

    The elements generated by `generator` must be compatible with either the
    given `output_signature` argument or with the given `output_types` and
    (optionally) `output_shapes` arguments, whichever was specified.

    The recommended way to call `from_generator` is to use the
    `output_signature` argument. In this case the output will be assumed to
    consist of objects with the classes, shapes and types defined by
    `tf.TypeSpec` objects from the `output_signature` argument:

    >>> def gen():
    ...   ragged_tensor = tf.ragged.constant([[1, 2], [3]])
    ...   yield 42, ragged_tensor
    >>>
    >>> dataset = tf.data.Dataset.from_generator(
    ...     gen,
    ...     output_signature=(
    ...         tf.TensorSpec(shape=(), dtype=tf.int32),
    ...         tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32)))
    >>>
    >>> list(dataset.take(1))
    [(<tf.Tensor: shape=(), dtype=int32, numpy=42>,
    <tf.RaggedTensor [[1, 2], [3]]>)]

    There is also a deprecated way to call `from_generator`, using the
    `output_types` argument either alone or together with the `output_shapes`
    argument. In this case the output of the function will be assumed to
    consist of `tf.Tensor` objects with the types defined by `output_types`
    and with the shapes which are either unknown or defined by `output_shapes`.

    Note: If `generator` depends on mutable global variables or other external
    state, be aware that the runtime may invoke `generator` multiple times
    (in order to support repeating the `Dataset`) and at any time
    between the call to `Dataset.from_generator()` and the production of the
    first element from the generator. Mutating global variables or external
    state can cause undefined behavior, and we recommend that you explicitly
    cache any external state in `generator` before calling
    `Dataset.from_generator()`.
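
    As an illustrative sketch (the generator and values here are examples, not
    part of the API), the `args` argument described below can be used to
    parameterize the generator:

    >>> def gen(limit):
    ...   for i in range(limit):
    ...     yield i
    >>>
    >>> dataset = tf.data.Dataset.from_generator(
    ...     gen,
    ...     output_signature=tf.TensorSpec(shape=(), dtype=tf.int64),
    ...     args=(3,))
    >>> list(dataset.as_numpy_iterator())
    [0, 1, 2]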

    Note: While the `output_signature` parameter makes it possible to yield
    `Dataset` elements, the scope of `Dataset.from_generator()` should be
    limited to logic that cannot be expressed through tf.data operations. Using
    tf.data operations within the generator function is an anti-pattern and may
    result in incremental memory growth.

    Args:
      generator: A callable object that returns an object that supports the
        `iter()` protocol. If `args` is not specified, `generator` must take no
        arguments; otherwise it must take as many arguments as there are values
        in `args`.
      output_types: (Optional.) A (nested) structure of `tf.DType` objects
        corresponding to each component of an element yielded by `generator`.
      output_shapes: (Optional.) A (nested) structure of `tf.TensorShape`
        objects corresponding to each component of an element yielded by
        `generator`.
      args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
        and passed to `generator` as NumPy-array arguments.
      output_signature: (Optional.) A (nested) structure of `tf.TypeSpec`
        objects corresponding to each component of an element yielded by
        `generator`.
      name: (Optional.) A name for the tf.data operations used by
        `from_generator`.

    Returns:
      Dataset: A `Dataset`.
    """
    if not callable(generator):
      raise TypeError("`generator` must be a Python callable.")

    if output_signature is not None:
      if output_types is not None:
        raise TypeError("The `output_types` argument can not be used together "
                        "with the `output_signature` argument.")
      if output_shapes is not None:
        raise TypeError("The `output_shapes` argument can not be used together "
                        "with the `output_signature` argument.")
      for spec in nest.flatten(output_signature):
        if not isinstance(spec, type_spec.TypeSpec):
          raise TypeError(f"`output_signature` must contain objects that are "
                          f"subclasses of `tf.TypeSpec` but found {type(spec)} "
                          f"which is not.")
    else:
      if output_types is None:
        raise TypeError("To specify the output signature you need to provide "
                        "either the `output_signature` argument or the "
                        "`output_types` argument.")

    if output_signature is None:
      if output_shapes is None:
        output_shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), output_types)
      else:
        output_shapes = nest.map_structure_up_to(output_types,
                                                 tensor_shape.as_shape,
                                                 output_shapes)
      output_signature = nest.map_structure_up_to(output_types,
                                                  tensor_spec.TensorSpec,
                                                  output_shapes, output_types)
    if all(
        isinstance(x, tensor_spec.TensorSpec)
        for x in nest.flatten(output_signature)):
      output_types = nest.pack_sequence_as(
          output_signature, [x.dtype for x in nest.flatten(output_signature)])
      output_shapes = nest.pack_sequence_as(
          output_signature, [x.shape for x in nest.flatten(output_signature)])

    if args is None:
      args = ()
    else:
      args = tuple(ops.convert_n_to_tensor(args, name="args"))

    generator_state = DatasetV2._GeneratorState(generator)

    def get_iterator_id_fn(unused_dummy):
      """Creates a unique `iterator_id` for each pass over the dataset.

      The returned `iterator_id` disambiguates between multiple concurrently
      existing iterators.

      Args:
        unused_dummy: Ignored value.

      Returns:
        A `tf.int64` tensor whose value uniquely identifies an iterator in
        `generator_state`.
      """
      return script_ops.numpy_function(generator_state.get_next_id, args,
                                       dtypes.int64)

    def generator_next_fn(iterator_id_t):
      """Generates the next element from iterator with ID `iterator_id_t`.

      We map this function across an infinite repetition of the
      `iterator_id_t`, and raise `StopIteration` to terminate the iteration.

      Args:
        iterator_id_t: A `tf.int64` tensor whose value uniquely identifies the
          iterator in `generator_state` from which to generate an element.

      Returns:
        The next element to generate from the iterator.
      """
      if output_types and output_shapes:
        flattened_types = [
            dtypes.as_dtype(dt) for dt in nest.flatten(output_types)
        ]
        flattened_shapes = nest.flatten(output_shapes)

        def generator_py_func(iterator_id):
          """A `py_func` that will be called to invoke the iterator."""
          # `next()` raises `StopIteration` when there are no more
          # elements remaining to be generated.
          values = next(generator_state.get_iterator(iterator_id))

          # Use the same _convert function from the py_func() implementation
          # to convert the returned values to arrays early, so that we can
          # inspect their values.
          try:
            flattened_values = nest.flatten_up_to(output_types, values)
          except (TypeError, ValueError) as e:
            raise TypeError(
                f"`generator` yielded an element that did not match the "
                f"expected structure. The expected structure was "
                f"{output_types}, but the yielded element was {values}.") from e
          ret_arrays = []
          for ret, dtype in zip(flattened_values, flattened_types):
            try:
              ret_arrays.append(
                  script_ops.FuncRegistry._convert(  # pylint: disable=protected-access
                      ret,
                      dtype=dtype.as_numpy_dtype))
            except (TypeError, ValueError) as e:
              raise TypeError(
                  f"`generator` yielded an element that could not be "
                  f"converted to the expected type. The expected type was "
                  f"{dtype.name}, but the yielded element was {ret}.") from e

          # Additional type and shape checking to ensure that the components of
          # the generated element match the `output_types` and `output_shapes`
          # arguments.
          for (ret_array, expected_dtype,
               expected_shape) in zip(ret_arrays, flattened_types,
                                      flattened_shapes):
            if ret_array.dtype != expected_dtype.as_numpy_dtype:
              raise TypeError(
                  f"`generator` yielded an element of type {ret_array.dtype} "
                  f"where an element of type {expected_dtype.as_numpy_dtype} "
                  f"was expected.")
            if not expected_shape.is_compatible_with(ret_array.shape):
              raise TypeError(
                  f"`generator` yielded an element of shape {ret_array.shape} "
                  f"where an element of shape {expected_shape} was expected.")

          return ret_arrays

        flat_values = script_ops.numpy_function(generator_py_func,
                                                [iterator_id_t],
                                                flattened_types)

        # In debug mode the numpy_function will return a scalar if
        # generator_py_func produces only a single value.
        if not isinstance(flat_values, (list, tuple)):
          flat_values = [flat_values]

        # The `py_func()` op drops the inferred shapes, so we add them back in
        # here.
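        # Each `ret_t` below has an unknown static shape; `set_shape` merges in
        # the statically known `output_shapes` without adding a runtime check.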
        if output_shapes is not None:
          for ret_t, shape in zip(flat_values, flattened_shapes):
            ret_t.set_shape(shape)

        return nest.pack_sequence_as(output_types, flat_values)
      else:
        flat_output_types = structure.get_flat_tensor_types(output_signature)

        def generator_py_func(iterator_id):
          """A `py_func` that will be called to invoke the iterator."""
          # `next()` raises `StopIteration` when there are no more
          # elements remaining to be generated.
          values = next(generator_state.get_iterator(iterator_id.numpy()))

          try:
            values = structure.normalize_element(values, output_signature)
          except (TypeError, ValueError) as e:
            raise TypeError(
                f"`generator` yielded an element that did not match the "
                f"expected structure. The expected structure was "
                f"{output_signature}, but the yielded element was "
                f"{values}.") from e

          values_spec = structure.type_spec_from_value(values)

          if not structure.are_compatible(values_spec, output_signature):
            raise TypeError(
                f"`generator` yielded an element of {values_spec} where an "
                f"element of {output_signature} was expected.")

          return structure.to_tensor_list(output_signature, values)

        return script_ops.eager_py_func(
            generator_py_func, inp=[iterator_id_t], Tout=flat_output_types)

    def finalize_fn(iterator_id_t):
      """Releases host-side state for the iterator with ID `iterator_id_t`."""

      def finalize_py_func(iterator_id):
        generator_state.iterator_completed(iterator_id)
        # We return a dummy value so that the `finalize_fn` has a valid
        # signature.
        # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
        # casting in `py_func()` will create an array of `np.int32` on Windows,
        # leading to a runtime error.
        return np.array(0, dtype=np.int64)

      return script_ops.numpy_function(finalize_py_func, [iterator_id_t],
                                       dtypes.int64)

    # This function associates each traversal of `generator` with a unique
    # iterator ID.
    def flat_map_fn(dummy_arg):
      # The `get_iterator_id_fn` gets a unique ID for the current instance of
      # the generator.
      # The `generator_next_fn` gets the next element from the iterator with
      # the given ID, and raises StopIteration when that iterator contains no
      # more elements.
      return _GeneratorDataset(
          dummy_arg,
          get_iterator_id_fn,
          generator_next_fn,
          finalize_fn,
          output_signature,
          name=name)

    # A single-element dataset that, each time it is evaluated, contains a
    # freshly-generated and unique (for the returned dataset) int64
    # ID that will be used to identify the appropriate Python state, which
    # is encapsulated in `generator_state`, and captured in
    # `get_iterator_id_map_fn`.
    dummy = 0
    id_dataset = Dataset.from_tensors(dummy, name=name)

    # A dataset that contains all of the elements generated by a
    # single iterator created from `generator`, identified by the
    # iterator ID contained in `id_dataset`. Lifting the iteration
    # into a flat_map here enables multiple repetitions and/or nested
    # versions of the returned dataset to be created, because it forces
    # the generation of a new ID for each version.
    return id_dataset.flat_map(flat_map_fn, name=name)

  @staticmethod
  def range(*args, **kwargs):
    """Creates a `Dataset` of a step-separated range of values.

    >>> list(Dataset.range(5).as_numpy_iterator())
    [0, 1, 2, 3, 4]
    >>> list(Dataset.range(2, 5).as_numpy_iterator())
    [2, 3, 4]
    >>> list(Dataset.range(1, 5, 2).as_numpy_iterator())
    [1, 3]
    >>> list(Dataset.range(1, 5, -2).as_numpy_iterator())
    []
    >>> list(Dataset.range(5, 1).as_numpy_iterator())
    []
    >>> list(Dataset.range(5, 1, -2).as_numpy_iterator())
    [5, 3]
    >>> list(Dataset.range(2, 5, output_type=tf.int32).as_numpy_iterator())
    [2, 3, 4]
    >>> list(Dataset.range(1, 5, 2, output_type=tf.float32).as_numpy_iterator())
    [1.0, 3.0]

    Args:
      *args: follows the same semantics as python's range.
        len(args) == 1 -> start = 0, stop = args[0], step = 1.
        len(args) == 2 -> start = args[0], stop = args[1], step = 1.
        len(args) == 3 -> start = args[0], stop = args[1], step = args[2].
      **kwargs:
        - output_type: Its expected dtype. (Optional, default: `tf.int64`).
        - name: (Optional.) A name for the tf.data operation.

    Returns:
      Dataset: A `RangeDataset`.

    Raises:
      ValueError: if len(args) == 0.
    """
    return RangeDataset(*args, **kwargs)

  @staticmethod
  def zip(datasets, name=None):
    """Creates a `Dataset` by zipping together the given datasets.

    This method has similar semantics to the built-in `zip()` function
    in Python, with the main difference being that the `datasets`
    argument can be a (nested) structure of `Dataset` objects. The supported
    nesting mechanisms are documented
    [here](https://www.tensorflow.org/guide/data#dataset_structure).

    >>> # The nested structure of the `datasets` argument determines the
    >>> # structure of elements in the resulting dataset.
    >>> a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
    >>> b = tf.data.Dataset.range(4, 7)  # ==> [ 4, 5, 6 ]
    >>> ds = tf.data.Dataset.zip((a, b))
    >>> list(ds.as_numpy_iterator())
    [(1, 4), (2, 5), (3, 6)]
    >>> ds = tf.data.Dataset.zip((b, a))
    >>> list(ds.as_numpy_iterator())
    [(4, 1), (5, 2), (6, 3)]
    >>>
    >>> # The `datasets` argument may contain an arbitrary number of datasets.
    >>> c = tf.data.Dataset.range(7, 13).batch(2)  # ==> [ [7, 8],
    ...                                            #       [9, 10],
    ...                                            #       [11, 12] ]
    >>> ds = tf.data.Dataset.zip((a, b, c))
    >>> for element in ds.as_numpy_iterator():
    ...   print(element)
    (1, 4, array([7, 8]))
    (2, 5, array([ 9, 10]))
    (3, 6, array([11, 12]))
    >>>
    >>> # The number of elements in the resulting dataset is the same as
    >>> # the size of the smallest dataset in `datasets`.
    >>> d = tf.data.Dataset.range(13, 15)  # ==> [ 13, 14 ]
    >>> ds = tf.data.Dataset.zip((a, d))
    >>> list(ds.as_numpy_iterator())
    [(1, 13), (2, 14)]

    Args:
      datasets: A (nested) structure of datasets.
      name: (Optional.) A name for the tf.data operation.

    Returns:
      Dataset: A `Dataset`.
    """
    return ZipDataset(datasets, name=name)

  def concatenate(self, dataset, name=None):
    """Creates a `Dataset` by concatenating the given dataset with this dataset.

    >>> a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
    >>> b = tf.data.Dataset.range(4, 8)  # ==> [ 4, 5, 6, 7 ]
    >>> ds = a.concatenate(b)
    >>> list(ds.as_numpy_iterator())
    [1, 2, 3, 4, 5, 6, 7]
    >>> # The input dataset and dataset to be concatenated should have
    >>> # compatible element specs.
    >>> c = tf.data.Dataset.zip((a, b))
    >>> a.concatenate(c)
    Traceback (most recent call last):
    TypeError: Two datasets to concatenate have different types
    <dtype: 'int64'> and (tf.int64, tf.int64)
    >>> d = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
    >>> a.concatenate(d)
    Traceback (most recent call last):
    TypeError: Two datasets to concatenate have different types
    <dtype: 'int64'> and <dtype: 'string'>

    Args:
      dataset: `Dataset` to be concatenated.
      name: (Optional.) A name for the tf.data operation.

    Returns:
      Dataset: A `Dataset`.
    """
    return ConcatenateDataset(self, dataset, name=name)

  def prefetch(self, buffer_size, name=None):
    """Creates a `Dataset` that prefetches elements from this dataset.

    Most dataset input pipelines should end with a call to `prefetch`. This
    allows later elements to be prepared while the current element is being
    processed. This often improves latency and throughput, at the cost of
    using additional memory to store prefetched elements.

    Note: Like other `Dataset` methods, prefetch operates on the
    elements of the input dataset. It has no concept of examples vs. batches.
    `examples.prefetch(2)` will prefetch two elements (2 examples),
    while `examples.batch(20).prefetch(2)` will prefetch 2 elements
    (2 batches, of 20 examples each).

    >>> dataset = tf.data.Dataset.range(3)
    >>> dataset = dataset.prefetch(2)
    >>> list(dataset.as_numpy_iterator())
    [0, 1, 2]

    Args:
      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the maximum
        number of elements that will be buffered when prefetching. If the value
        `tf.data.AUTOTUNE` is used, then the buffer size is dynamically tuned.
      name: Optional. A name for the tf.data transformation.

    Returns:
      Dataset: A `Dataset`.
    """
    if DEBUG_MODE:
      return self
    return PrefetchDataset(self, buffer_size, name=name)

  @staticmethod
  def list_files(file_pattern, shuffle=None, seed=None, name=None):
    """A dataset of all files matching one or more glob patterns.

    The `file_pattern` argument should be a small number of glob patterns.
    If your filenames have already been globbed, use
    `Dataset.from_tensor_slices(filenames)` instead, as re-globbing every
    filename with `list_files` may result in poor performance with remote
    storage systems.

    Note: The default behavior of this method is to return filenames in
    a non-deterministic random shuffled order. Pass a `seed` or `shuffle=False`
    to get results in a deterministic order.

    Example:
      If we had the following files on our filesystem:

        - /path/to/dir/a.txt
        - /path/to/dir/b.py
        - /path/to/dir/c.py

      If we pass "/path/to/dir/*.py" as the pattern, the dataset
      would produce:

        - /path/to/dir/b.py
        - /path/to/dir/c.py

    Args:
      file_pattern: A string, a list of strings, or a `tf.Tensor` of string
        type (scalar or vector), representing the filename glob (i.e. shell
        wildcard) pattern(s) that will be matched.
      shuffle: (Optional.) If `True`, the file names will be shuffled randomly.
        Defaults to `True`.
      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        random seed that will be used to create the distribution. See
        `tf.random.set_seed` for behavior.
      name: Optional. A name for the tf.data operations used by `list_files`.

    Returns:
      Dataset: A `Dataset` of strings corresponding to file names.
    """
    with ops.name_scope("list_files"):
      if shuffle is None:
        shuffle = True
      file_pattern = ops.convert_to_tensor(
          file_pattern, dtype=dtypes.string, name="file_pattern")
      matching_files = gen_io_ops.matching_files(file_pattern)

      # Raise an exception if `file_pattern` does not match any files.
      condition = math_ops.greater(array_ops.shape(matching_files)[0], 0,
                                   name="match_not_empty")

      message = math_ops.add(
          "No files matched pattern: ",
          string_ops.reduce_join(file_pattern, separator=", "), name="message")

      assert_not_empty = control_flow_ops.Assert(
          condition, [message], summarize=1, name="assert_not_empty")
      with ops.control_dependencies([assert_not_empty]):
        matching_files = array_ops.identity(matching_files)

      dataset = TensorSliceDataset(matching_files, is_files=True, name=name)
      if issubclass(Dataset, DatasetV1):
        dataset = DatasetV1Adapter(dataset)
      if shuffle:
        # NOTE(mrry): The shuffle buffer size must be greater than zero, but the
        # list of files might be empty.
        buffer_size = math_ops.maximum(
            array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
        dataset = dataset.shuffle(buffer_size, seed=seed, name=name)
      return dataset

  def repeat(self, count=None, name=None):
    """Repeats this dataset so each original value is seen `count` times.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> dataset = dataset.repeat(3)
    >>> list(dataset.as_numpy_iterator())
    [1, 2, 3, 1, 2, 3, 1, 2, 3]

    Note: If the input dataset depends on global state (e.g. a random number
    generator) or its output is non-deterministic (e.g. because of upstream
    `shuffle`), then different repetitions may produce different elements.

    Args:
      count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        number of times the dataset should be repeated. The default behavior
        (if `count` is `None` or `-1`) is for the dataset to be repeated
        indefinitely.
      name: (Optional.) A name for the tf.data operation.

    Returns:
      Dataset: A `Dataset`.
    """
    return RepeatDataset(self, count, name=name)

  def enumerate(self, start=0, name=None):
    """Enumerates the elements of this dataset.

    It is similar to Python's `enumerate`.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> dataset = dataset.enumerate(start=5)
    >>> for element in dataset.as_numpy_iterator():
    ...   print(element)
    (5, 1)
    (6, 2)
    (7, 3)

    >>> # The (nested) structure of the input dataset determines the
    >>> # structure of elements in the resulting dataset.
    >>> dataset = tf.data.Dataset.from_tensor_slices([(7, 8), (9, 10)])
    >>> dataset = dataset.enumerate()
    >>> for element in dataset.as_numpy_iterator():
    ...   print(element)
    (0, array([7, 8], dtype=int32))
    (1, array([ 9, 10], dtype=int32))

    Args:
      start: A `tf.int64` scalar `tf.Tensor`, representing the start value for
        enumeration.
      name: Optional. A name for the tf.data operations used by `enumerate`.

    Returns:
      Dataset: A `Dataset`.
1443 """ 1444 1445 max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max 1446 range_dataset = Dataset.range(start, max_value, name=name) 1447 # Replicate the range component so that each split is enumerated 1448 # independently. This avoids the need for prohibitively expensive 1449 # cross-split coordination. 1450 range_dataset = _apply_rewrite(range_dataset, "replicate_on_split") 1451 return Dataset.zip((range_dataset, self), name=name) 1452 1453 def shuffle(self, 1454 buffer_size, 1455 seed=None, 1456 reshuffle_each_iteration=None, 1457 name=None): 1458 """Randomly shuffles the elements of this dataset. 1459 1460 This dataset fills a buffer with `buffer_size` elements, then randomly 1461 samples elements from this buffer, replacing the selected elements with new 1462 elements. For perfect shuffling, a buffer size greater than or equal to the 1463 full size of the dataset is required. 1464 1465 For instance, if your dataset contains 10,000 elements but `buffer_size` is 1466 set to 1,000, then `shuffle` will initially select a random element from 1467 only the first 1,000 elements in the buffer. Once an element is selected, 1468 its space in the buffer is replaced by the next (i.e. 1,001-st) element, 1469 maintaining the 1,000 element buffer. 1470 1471 `reshuffle_each_iteration` controls whether the shuffle order should be 1472 different for each epoch. In TF 1.X, the idiomatic way to create epochs 1473 was through the `repeat` transformation: 1474 1475 ```python 1476 dataset = tf.data.Dataset.range(3) 1477 dataset = dataset.shuffle(3, reshuffle_each_iteration=True) 1478 dataset = dataset.repeat(2) 1479 # [1, 0, 2, 1, 2, 0] 1480 1481 dataset = tf.data.Dataset.range(3) 1482 dataset = dataset.shuffle(3, reshuffle_each_iteration=False) 1483 dataset = dataset.repeat(2) 1484 # [1, 0, 2, 1, 0, 2] 1485 ``` 1486 1487 In TF 2.0, `tf.data.Dataset` objects are Python iterables which makes it 1488 possible to also create epochs through Python iteration: 1489 1490 ```python 1491 dataset = tf.data.Dataset.range(3) 1492 dataset = dataset.shuffle(3, reshuffle_each_iteration=True) 1493 list(dataset.as_numpy_iterator()) 1494 # [1, 0, 2] 1495 list(dataset.as_numpy_iterator()) 1496 # [1, 2, 0] 1497 ``` 1498 1499 ```python 1500 dataset = tf.data.Dataset.range(3) 1501 dataset = dataset.shuffle(3, reshuffle_each_iteration=False) 1502 list(dataset.as_numpy_iterator()) 1503 # [1, 0, 2] 1504 list(dataset.as_numpy_iterator()) 1505 # [1, 0, 2] 1506 ``` 1507 1508 Args: 1509 buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 1510 elements from this dataset from which the new dataset will sample. 1511 seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random 1512 seed that will be used to create the distribution. See 1513 `tf.random.set_seed` for behavior. 1514 reshuffle_each_iteration: (Optional.) A boolean, which if true indicates 1515 that the dataset should be pseudorandomly reshuffled each time it is 1516 iterated over. (Defaults to `True`.) 1517 name: (Optional.) A name for the tf.data operation. 1518 1519 Returns: 1520 Dataset: A `Dataset`. 1521 """ 1522 return ShuffleDataset( 1523 self, buffer_size, seed, reshuffle_each_iteration, name=name) 1524 1525 def cache(self, filename="", name=None): 1526 """Caches the elements in this dataset. 1527 1528 The first time the dataset is iterated over, its elements will be cached 1529 either in the specified file or in memory. Subsequent iterations will 1530 use the cached data. 
1531 1532 Note: To guarantee that the cache gets finalized, the input dataset must be 1533 iterated through in its entirety, until it raises StopIteration. Otherwise, 1534 subsequent iterations may not use cached data. 1535 1536 >>> dataset = tf.data.Dataset.range(5) 1537 >>> dataset = dataset.map(lambda x: x**2) 1538 >>> dataset = dataset.cache() 1539 >>> # The first time reading through the data will generate the data using 1540 >>> # `range` and `map`. 1541 >>> list(dataset.as_numpy_iterator()) 1542 [0, 1, 4, 9, 16] 1543 >>> # Subsequent iterations read from the cache. 1544 >>> list(dataset.as_numpy_iterator()) 1545 [0, 1, 4, 9, 16] 1546 1547 When caching to a file, the cached data will persist across runs. Even the 1548 first iteration through the data will read from the cache file. Changing 1549 the input pipeline before the call to `.cache()` will have no effect until 1550 the cache file is removed or the filename is changed. 1551 1552 ```python 1553 dataset = tf.data.Dataset.range(5) 1554 dataset = dataset.cache("/path/to/file") 1555 list(dataset.as_numpy_iterator()) 1556 # [0, 1, 2, 3, 4] 1557 dataset = tf.data.Dataset.range(10) 1558 dataset = dataset.cache("/path/to/file") # Same file! 1559 list(dataset.as_numpy_iterator()) 1560 # [0, 1, 2, 3, 4] 1561 ``` 1562 1563 Note: `cache` will produce exactly the same elements during each iteration 1564 through the dataset. If you wish to randomize the iteration order, make sure 1565 to call `shuffle` *after* calling `cache`. 1566 1567 Args: 1568 filename: A `tf.string` scalar `tf.Tensor`, representing the name of a 1569 directory on the filesystem to use for caching elements in this Dataset. 1570 If a filename is not provided, the dataset will be cached in memory. 1571 name: (Optional.) A name for the tf.data operation. 1572 1573 Returns: 1574 Dataset: A `Dataset`. 1575 """ 1576 return CacheDataset(self, filename, name=name) 1577 1578 def take(self, count, name=None): 1579 """Creates a `Dataset` with at most `count` elements from this dataset. 1580 1581 >>> dataset = tf.data.Dataset.range(10) 1582 >>> dataset = dataset.take(3) 1583 >>> list(dataset.as_numpy_iterator()) 1584 [0, 1, 2] 1585 1586 Args: 1587 count: A `tf.int64` scalar `tf.Tensor`, representing the number of 1588 elements of this dataset that should be taken to form the new dataset. 1589 If `count` is -1, or if `count` is greater than the size of this 1590 dataset, the new dataset will contain all elements of this dataset. 1591 name: (Optional.) A name for the tf.data operation. 1592 1593 Returns: 1594 Dataset: A `Dataset`. 1595 """ 1596 return TakeDataset(self, count, name=name) 1597 1598 def skip(self, count, name=None): 1599 """Creates a `Dataset` that skips `count` elements from this dataset. 1600 1601 >>> dataset = tf.data.Dataset.range(10) 1602 >>> dataset = dataset.skip(7) 1603 >>> list(dataset.as_numpy_iterator()) 1604 [7, 8, 9] 1605 1606 Args: 1607 count: A `tf.int64` scalar `tf.Tensor`, representing the number of 1608 elements of this dataset that should be skipped to form the new dataset. 1609 If `count` is greater than the size of this dataset, the new dataset 1610 will contain no elements. If `count` is -1, skips the entire dataset. 1611 name: (Optional.) A name for the tf.data operation. 1612 1613 Returns: 1614 Dataset: A `Dataset`. 1615 """ 1616 return SkipDataset(self, count, name=name) 1617 1618 def shard(self, num_shards, index, name=None): 1619 """Creates a `Dataset` that includes only 1/`num_shards` of this dataset. 1620 1621 `shard` is deterministic. 
The Dataset produced by `A.shard(n, i)` will 1622 contain all elements of A whose index mod n = i. 1623 1624 >>> A = tf.data.Dataset.range(10) 1625 >>> B = A.shard(num_shards=3, index=0) 1626 >>> list(B.as_numpy_iterator()) 1627 [0, 3, 6, 9] 1628 >>> C = A.shard(num_shards=3, index=1) 1629 >>> list(C.as_numpy_iterator()) 1630 [1, 4, 7] 1631 >>> D = A.shard(num_shards=3, index=2) 1632 >>> list(D.as_numpy_iterator()) 1633 [2, 5, 8] 1634 1635 This dataset operator is very useful when running distributed training, as 1636 it allows each worker to read a unique subset. 1637 1638 When reading a single input file, you can shard elements as follows: 1639 1640 ```python 1641 d = tf.data.TFRecordDataset(input_file) 1642 d = d.shard(num_workers, worker_index) 1643 d = d.repeat(num_epochs) 1644 d = d.shuffle(shuffle_buffer_size) 1645 d = d.map(parser_fn, num_parallel_calls=num_map_threads) 1646 ``` 1647 1648 Important caveats: 1649 1650 - Be sure to shard before you use any randomizing operator (such as 1651 shuffle). 1652 - Generally it is best if the shard operator is used early in the dataset 1653 pipeline. For example, when reading from a set of TFRecord files, shard 1654 before converting the dataset to input samples. This avoids reading every 1655 file on every worker. The following is an example of an efficient 1656 sharding strategy within a complete pipeline: 1657 1658 ```python 1659 d = Dataset.list_files(pattern) 1660 d = d.shard(num_workers, worker_index) 1661 d = d.repeat(num_epochs) 1662 d = d.shuffle(shuffle_buffer_size) 1663 d = d.interleave(tf.data.TFRecordDataset, 1664 cycle_length=num_readers, block_length=1) 1665 d = d.map(parser_fn, num_parallel_calls=num_map_threads) 1666 ``` 1667 1668 Args: 1669 num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of 1670 shards operating in parallel. 1671 index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. 1672 name: (Optional.) A name for the tf.data operation. 1673 1674 Returns: 1675 Dataset: A `Dataset`. 1676 1677 Raises: 1678 InvalidArgumentError: if `num_shards` or `index` are illegal values. 1679 1680 Note: error checking is done on a best-effort basis, and errors aren't 1681 guaranteed to be caught upon dataset creation. (e.g. providing in a 1682 placeholder tensor bypasses the early checking, and will instead result 1683 in an error during a session.run call.) 1684 """ 1685 return ShardDataset(self, num_shards, index, name=name) 1686 1687 def save(self, 1688 path, 1689 compression=None, 1690 shard_func=None, 1691 checkpoint_args=None): 1692 """Saves the content of the given dataset. 1693 1694 Example usage: 1695 1696 >>> import tempfile 1697 >>> path = os.path.join(tempfile.gettempdir(), "saved_data") 1698 >>> # Save a dataset 1699 >>> dataset = tf.data.Dataset.range(2) 1700 >>> dataset.save(path) 1701 >>> new_dataset = tf.data.Dataset.load(path) 1702 >>> for elem in new_dataset: 1703 ... print(elem) 1704 tf.Tensor(0, shape=(), dtype=int64) 1705 tf.Tensor(1, shape=(), dtype=int64) 1706 1707 The saved dataset is saved in multiple file "shards". By default, the 1708 dataset output is divided to shards in a round-robin fashion but custom 1709 sharding can be specified via the `shard_func` function. 
For example, you
1710 can save the dataset using a single shard as follows:
1711
1712 ```python
1713 dataset = make_dataset()
1714 def custom_shard_func(element):
1715 return np.int64(0)
1716 dataset.save(
1717 path="/path/to/data", ..., shard_func=custom_shard_func)
1718 ```
1719
1720 To enable checkpointing, pass in `checkpoint_args` to the `save` method
1721 as follows:
1722
1723 ```python
1724 dataset = tf.data.Dataset.range(100)
1725 save_dir = "..."
1726 checkpoint_prefix = "..."
1727 step_counter = tf.Variable(0, trainable=False)
1728 checkpoint_args = {
1729 "checkpoint_interval": 50,
1730 "step_counter": step_counter,
1731 "directory": checkpoint_prefix,
1732 "max_to_keep": 20,
1733 }
1734 dataset.save(save_dir, checkpoint_args=checkpoint_args)
1735 ```
1736
1737 NOTE: The directory layout and file format used for saving the dataset is
1738 considered an implementation detail and may change. For this reason,
1739 datasets saved through `tf.data.Dataset.save` should only be consumed
1740 through `tf.data.Dataset.load`, which is guaranteed to be
1741 backwards compatible.
1742
1743 Args:
1744 path: Required. A directory to use for saving the dataset.
1745 compression: Optional. The algorithm to use to compress data when writing
1746 it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
1747 shard_func: Optional. A function to control the mapping of dataset
1748 elements to file shards. The function is expected to map elements of
1749 the input dataset to int64 shard IDs. If present, the function will be
1750 traced and executed as graph computation.
1751 checkpoint_args: Optional args for checkpointing which will be passed into
1752 the `tf.train.CheckpointManager`. If `checkpoint_args` are not
1753 specified, then checkpointing will not be performed. The `save()`
1754 implementation creates a `tf.train.Checkpoint` object internally, so
1755 users should not set the `checkpoint` argument in `checkpoint_args`.
1756
1757 Raises:
1758 ValueError: If `checkpoint` is passed into `checkpoint_args`.
1759 """
1760 # Loaded lazily due to a circular dependency
1761 # dataset_ops->save_ops->dataset_ops
1762 from tensorflow.python.data.ops import save_op # pylint: disable=g-import-not-at-top
1763 save_op.save(self, path, compression, shard_func, checkpoint_args)
1764
1765 @staticmethod
1766 def load(path, element_spec=None, compression=None, reader_func=None):
1767 """Loads a previously saved dataset.
1768
1769 Example usage:
1770
1771 >>> import tempfile
1772 >>> path = os.path.join(tempfile.gettempdir(), "saved_data")
1773 >>> # Save a dataset
1774 >>> dataset = tf.data.Dataset.range(2)
1775 >>> tf.data.Dataset.save(dataset, path)
1776 >>> new_dataset = tf.data.Dataset.load(path)
1777 >>> for elem in new_dataset:
1778 ... print(elem)
1779 tf.Tensor(0, shape=(), dtype=int64)
1780 tf.Tensor(1, shape=(), dtype=int64)
1781
1782
1783 If the default option of sharding the saved dataset was used, the element
1784 order of the saved dataset will be preserved when loading it.
1785
1786 The `reader_func` argument can be used to specify a custom order in which
1787 elements should be loaded from the individual shards. The `reader_func` is
1788 expected to take a single argument -- a dataset of datasets, each containing
1789 elements of one of the shards -- and return a dataset of elements.
For 1790 example, the order of shards can be shuffled when loading them as follows: 1791 1792 ```python 1793 def custom_reader_func(datasets): 1794 datasets = datasets.shuffle(NUM_SHARDS) 1795 return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE) 1796 1797 dataset = tf.data.Dataset.load( 1798 path="/path/to/data", ..., reader_func=custom_reader_func) 1799 ``` 1800 1801 Args: 1802 path: Required. A path pointing to a previously saved dataset. 1803 element_spec: Optional. A nested structure of `tf.TypeSpec` objects 1804 matching the structure of an element of the saved dataset and specifying 1805 the type of individual element components. If not provided, the nested 1806 structure of `tf.TypeSpec` saved with the saved dataset is used. Note 1807 that this argument is required in graph mode. 1808 compression: Optional. The algorithm to use to decompress the data when 1809 reading it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`. 1810 reader_func: Optional. A function to control how to read data from shards. 1811 If present, the function will be traced and executed as graph 1812 computation. 1813 1814 Returns: 1815 A `tf.data.Dataset` instance. 1816 1817 Raises: 1818 FileNotFoundError: If `element_spec` is not specified and the saved nested 1819 structure of `tf.TypeSpec` can not be located with the saved dataset. 1820 ValueError: If `element_spec` is not specified and the method is executed 1821 in graph mode. 1822 """ 1823 # Loaded lazily due to a circular dependency 1824 # dataset_ops->load_ops->dataset_ops 1825 from tensorflow.python.data.ops import load_op # pylint: disable=g-import-not-at-top 1826 return load_op.load( 1827 path=path, 1828 element_spec=element_spec, 1829 compression=compression, 1830 reader_func=reader_func) 1831 1832 def batch(self, 1833 batch_size, 1834 drop_remainder=False, 1835 num_parallel_calls=None, 1836 deterministic=None, 1837 name=None): 1838 """Combines consecutive elements of this dataset into batches. 1839 1840 >>> dataset = tf.data.Dataset.range(8) 1841 >>> dataset = dataset.batch(3) 1842 >>> list(dataset.as_numpy_iterator()) 1843 [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])] 1844 1845 >>> dataset = tf.data.Dataset.range(8) 1846 >>> dataset = dataset.batch(3, drop_remainder=True) 1847 >>> list(dataset.as_numpy_iterator()) 1848 [array([0, 1, 2]), array([3, 4, 5])] 1849 1850 The components of the resulting element will have an additional outer 1851 dimension, which will be `batch_size` (or `N % batch_size` for the last 1852 element if `batch_size` does not divide the number of input elements `N` 1853 evenly and `drop_remainder` is `False`). If your program depends on the 1854 batches having the same outer dimension, you should set the `drop_remainder` 1855 argument to `True` to prevent the smaller batch from being produced. 1856 1857 Note: If your program requires data to have a statically known shape (e.g., 1858 when using XLA), you should use `drop_remainder=True`. Without 1859 `drop_remainder=True` the shape of the output dataset will have an unknown 1860 leading dimension due to the possibility of a smaller final batch. 1861 1862 Args: 1863 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 1864 consecutive elements of this dataset to combine in a single batch. 1865 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 1866 whether the last batch should be dropped in the case it has fewer than 1867 `batch_size` elements; the default behavior is not to drop the smaller 1868 batch. 
1869 num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`, 1870 representing the number of batches to compute asynchronously in 1871 parallel. 1872 If not specified, batches will be computed sequentially. If the value 1873 `tf.data.AUTOTUNE` is used, then the number of parallel 1874 calls is set dynamically based on available resources. 1875 deterministic: (Optional.) When `num_parallel_calls` is specified, if this 1876 boolean is specified (`True` or `False`), it controls the order in which 1877 the transformation produces elements. If set to `False`, the 1878 transformation is allowed to yield elements out of order to trade 1879 determinism for performance. If not specified, the 1880 `tf.data.Options.deterministic` option (`True` by default) controls the 1881 behavior. 1882 name: (Optional.) A name for the tf.data operation. 1883 1884 Returns: 1885 Dataset: A `Dataset`. 1886 """ 1887 if num_parallel_calls is None or DEBUG_MODE: 1888 if deterministic is not None and not DEBUG_MODE: 1889 warnings.warn("The `deterministic` argument has no effect unless the " 1890 "`num_parallel_calls` argument is specified.") 1891 return BatchDataset(self, batch_size, drop_remainder, name=name) 1892 else: 1893 return ParallelBatchDataset( 1894 self, 1895 batch_size, 1896 drop_remainder, 1897 num_parallel_calls, 1898 deterministic, 1899 name=name) 1900 1901 def padded_batch(self, 1902 batch_size, 1903 padded_shapes=None, 1904 padding_values=None, 1905 drop_remainder=False, 1906 name=None): 1907 """Combines consecutive elements of this dataset into padded batches. 1908 1909 This transformation combines multiple consecutive elements of the input 1910 dataset into a single element. 1911 1912 Like `tf.data.Dataset.batch`, the components of the resulting element will 1913 have an additional outer dimension, which will be `batch_size` (or 1914 `N % batch_size` for the last element if `batch_size` does not divide the 1915 number of input elements `N` evenly and `drop_remainder` is `False`). If 1916 your program depends on the batches having the same outer dimension, you 1917 should set the `drop_remainder` argument to `True` to prevent the smaller 1918 batch from being produced. 1919 1920 Unlike `tf.data.Dataset.batch`, the input elements to be batched may have 1921 different shapes, and this transformation will pad each component to the 1922 respective shape in `padded_shapes`. The `padded_shapes` argument 1923 determines the resulting shape for each dimension of each component in an 1924 output element: 1925 1926 * If the dimension is a constant, the component will be padded out to that 1927 length in that dimension. 1928 * If the dimension is unknown, the component will be padded out to the 1929 maximum length of all elements in that dimension. 1930 1931 >>> A = (tf.data.Dataset 1932 ... .range(1, 5, output_type=tf.int32) 1933 ... .map(lambda x: tf.fill([x], x))) 1934 >>> # Pad to the smallest per-batch size that fits all elements. 1935 >>> B = A.padded_batch(2) 1936 >>> for element in B.as_numpy_iterator(): 1937 ... print(element) 1938 [[1 0] 1939 [2 2]] 1940 [[3 3 3 0] 1941 [4 4 4 4]] 1942 >>> # Pad to a fixed size. 1943 >>> C = A.padded_batch(2, padded_shapes=5) 1944 >>> for element in C.as_numpy_iterator(): 1945 ... print(element) 1946 [[1 0 0 0 0] 1947 [2 2 0 0 0]] 1948 [[3 3 3 0 0] 1949 [4 4 4 4 0]] 1950 >>> # Pad with a custom value. 1951 >>> D = A.padded_batch(2, padded_shapes=5, padding_values=-1) 1952 >>> for element in D.as_numpy_iterator(): 1953 ... 
print(element) 1954 [[ 1 -1 -1 -1 -1] 1955 [ 2 2 -1 -1 -1]] 1956 [[ 3 3 3 -1 -1] 1957 [ 4 4 4 4 -1]] 1958 >>> # Components of nested elements can be padded independently. 1959 >>> elements = [([1, 2, 3], [10]), 1960 ... ([4, 5], [11, 12])] 1961 >>> dataset = tf.data.Dataset.from_generator( 1962 ... lambda: iter(elements), (tf.int32, tf.int32)) 1963 >>> # Pad the first component of the tuple to length 4, and the second 1964 >>> # component to the smallest size that fits. 1965 >>> dataset = dataset.padded_batch(2, 1966 ... padded_shapes=([4], [None]), 1967 ... padding_values=(-1, 100)) 1968 >>> list(dataset.as_numpy_iterator()) 1969 [(array([[ 1, 2, 3, -1], [ 4, 5, -1, -1]], dtype=int32), 1970 array([[ 10, 100], [ 11, 12]], dtype=int32))] 1971 >>> # Pad with a single value and multiple components. 1972 >>> E = tf.data.Dataset.zip((A, A)).padded_batch(2, padding_values=-1) 1973 >>> for element in E.as_numpy_iterator(): 1974 ... print(element) 1975 (array([[ 1, -1], 1976 [ 2, 2]], dtype=int32), array([[ 1, -1], 1977 [ 2, 2]], dtype=int32)) 1978 (array([[ 3, 3, 3, -1], 1979 [ 4, 4, 4, 4]], dtype=int32), array([[ 3, 3, 3, -1], 1980 [ 4, 4, 4, 4]], dtype=int32)) 1981 1982 See also `tf.data.experimental.dense_to_sparse_batch`, which combines 1983 elements that may have different shapes into a `tf.sparse.SparseTensor`. 1984 1985 Args: 1986 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 1987 consecutive elements of this dataset to combine in a single batch. 1988 padded_shapes: (Optional.) A (nested) structure of `tf.TensorShape` or 1989 `tf.int64` vector tensor-like objects representing the shape to which 1990 the respective component of each input element should be padded prior 1991 to batching. Any unknown dimensions will be padded to the maximum size 1992 of that dimension in each batch. If unset, all dimensions of all 1993 components are padded to the maximum size in the batch. `padded_shapes` 1994 must be set if any component has an unknown rank. 1995 padding_values: (Optional.) A (nested) structure of scalar-shaped 1996 `tf.Tensor`, representing the padding values to use for the respective 1997 components. None represents that the (nested) structure should be padded 1998 with default values. Defaults are `0` for numeric types and the empty 1999 string for string types. The `padding_values` should have the same 2000 (nested) structure as the input dataset. If `padding_values` is a single 2001 element and the input dataset has multiple components, then the same 2002 `padding_values` will be used to pad every component of the dataset. 2003 If `padding_values` is a scalar, then its value will be broadcasted 2004 to match the shape of each component. 2005 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 2006 whether the last batch should be dropped in the case it has fewer than 2007 `batch_size` elements; the default behavior is not to drop the smaller 2008 batch. 2009 name: (Optional.) A name for the tf.data operation. 2010 2011 Returns: 2012 Dataset: A `Dataset`. 2013 2014 Raises: 2015 ValueError: If a component has an unknown rank, and the `padded_shapes` 2016 argument is not set. 2017 TypeError: If a component is of an unsupported type. The list of supported 2018 types is documented in 2019 https://www.tensorflow.org/guide/data#dataset_structure. 
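The `padded_shapes` and `padding_values` arguments mirror the (nested)
structure of the elements. As a rough sketch with illustrative values only,
a dictionary-structured dataset can be padded per component:

```python
elements = [{"ids": [1, 2, 3], "label": [0]},
            {"ids": [4, 5], "label": [1]}]
ds = tf.data.Dataset.from_generator(
    lambda: iter(elements), {"ids": tf.int32, "label": tf.int32})
ds = ds.padded_batch(2,
                     padded_shapes={"ids": [4], "label": [1]},
                     padding_values={"ids": -1, "label": 0})
# The "ids" component of each batch is padded to length 4 with -1.
```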
2020 """ 2021 if padded_shapes is None: 2022 padded_shapes = get_legacy_output_shapes(self) 2023 for i, shape in enumerate(nest.flatten(padded_shapes)): 2024 # A `tf.TensorShape` is only false if its *rank* is unknown. 2025 if not shape: 2026 raise ValueError(f"You must provide `padded_shapes` argument because " 2027 f"component {i} has unknown rank.") 2028 return PaddedBatchDataset( 2029 self, 2030 batch_size, 2031 padded_shapes, 2032 padding_values, 2033 drop_remainder, 2034 name=name) 2035 2036 def map(self, 2037 map_func, 2038 num_parallel_calls=None, 2039 deterministic=None, 2040 name=None): 2041 """Maps `map_func` across the elements of this dataset. 2042 2043 This transformation applies `map_func` to each element of this dataset, and 2044 returns a new dataset containing the transformed elements, in the same 2045 order as they appeared in the input. `map_func` can be used to change both 2046 the values and the structure of a dataset's elements. Supported structure 2047 constructs are documented 2048 [here](https://www.tensorflow.org/guide/data#dataset_structure). 2049 2050 For example, `map` can be used for adding 1 to each element, or projecting a 2051 subset of element components. 2052 2053 >>> dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ] 2054 >>> dataset = dataset.map(lambda x: x + 1) 2055 >>> list(dataset.as_numpy_iterator()) 2056 [2, 3, 4, 5, 6] 2057 2058 The input signature of `map_func` is determined by the structure of each 2059 element in this dataset. 2060 2061 >>> dataset = Dataset.range(5) 2062 >>> # `map_func` takes a single argument of type `tf.Tensor` with the same 2063 >>> # shape and dtype. 2064 >>> result = dataset.map(lambda x: x + 1) 2065 2066 >>> # Each element is a tuple containing two `tf.Tensor` objects. 2067 >>> elements = [(1, "foo"), (2, "bar"), (3, "baz")] 2068 >>> dataset = tf.data.Dataset.from_generator( 2069 ... lambda: elements, (tf.int32, tf.string)) 2070 >>> # `map_func` takes two arguments of type `tf.Tensor`. This function 2071 >>> # projects out just the first component. 2072 >>> result = dataset.map(lambda x_int, y_str: x_int) 2073 >>> list(result.as_numpy_iterator()) 2074 [1, 2, 3] 2075 2076 >>> # Each element is a dictionary mapping strings to `tf.Tensor` objects. 2077 >>> elements = ([{"a": 1, "b": "foo"}, 2078 ... {"a": 2, "b": "bar"}, 2079 ... {"a": 3, "b": "baz"}]) 2080 >>> dataset = tf.data.Dataset.from_generator( 2081 ... lambda: elements, {"a": tf.int32, "b": tf.string}) 2082 >>> # `map_func` takes a single argument of type `dict` with the same keys 2083 >>> # as the elements. 2084 >>> result = dataset.map(lambda d: str(d["a"]) + d["b"]) 2085 2086 The value or values returned by `map_func` determine the structure of each 2087 element in the returned dataset. 2088 2089 >>> dataset = tf.data.Dataset.range(3) 2090 >>> # `map_func` returns two `tf.Tensor` objects. 2091 >>> def g(x): 2092 ... return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"]) 2093 >>> result = dataset.map(g) 2094 >>> result.element_spec 2095 (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(3,), \ 2096dtype=tf.string, name=None)) 2097 >>> # Python primitives, lists, and NumPy arrays are implicitly converted to 2098 >>> # `tf.Tensor`. 2099 >>> def h(x): 2100 ... 
return 37.0, ["Foo", "Bar"], np.array([1.0, 2.0], dtype=np.float64) 2101 >>> result = dataset.map(h) 2102 >>> result.element_spec 2103 (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(2,), \ 2104dtype=tf.string, name=None), TensorSpec(shape=(2,), dtype=tf.float64, \ 2105name=None)) 2106 >>> # `map_func` can return nested structures. 2107 >>> def i(x): 2108 ... return (37.0, [42, 16]), "foo" 2109 >>> result = dataset.map(i) 2110 >>> result.element_spec 2111 ((TensorSpec(shape=(), dtype=tf.float32, name=None), 2112 TensorSpec(shape=(2,), dtype=tf.int32, name=None)), 2113 TensorSpec(shape=(), dtype=tf.string, name=None)) 2114 2115 `map_func` can accept as arguments and return any type of dataset element. 2116 2117 Note that irrespective of the context in which `map_func` is defined (eager 2118 vs. graph), tf.data traces the function and executes it as a graph. To use 2119 Python code inside of the function you have a few options: 2120 2121 1) Rely on AutoGraph to convert Python code into an equivalent graph 2122 computation. The downside of this approach is that AutoGraph can convert 2123 some but not all Python code. 2124 2125 2) Use `tf.py_function`, which allows you to write arbitrary Python code but 2126 will generally result in worse performance than 1). For example: 2127 2128 >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world']) 2129 >>> # transform a string tensor to upper case string using a Python function 2130 >>> def upper_case_fn(t: tf.Tensor): 2131 ... return t.numpy().decode('utf-8').upper() 2132 >>> d = d.map(lambda x: tf.py_function(func=upper_case_fn, 2133 ... inp=[x], Tout=tf.string)) 2134 >>> list(d.as_numpy_iterator()) 2135 [b'HELLO', b'WORLD'] 2136 2137 3) Use `tf.numpy_function`, which also allows you to write arbitrary 2138 Python code. Note that `tf.py_function` accepts `tf.Tensor` whereas 2139 `tf.numpy_function` accepts numpy arrays and returns only numpy arrays. 2140 For example: 2141 2142 >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world']) 2143 >>> def upper_case_fn(t: np.ndarray): 2144 ... return t.decode('utf-8').upper() 2145 >>> d = d.map(lambda x: tf.numpy_function(func=upper_case_fn, 2146 ... inp=[x], Tout=tf.string)) 2147 >>> list(d.as_numpy_iterator()) 2148 [b'HELLO', b'WORLD'] 2149 2150 Note that the use of `tf.numpy_function` and `tf.py_function` 2151 in general precludes the possibility of executing user-defined 2152 transformations in parallel (because of Python GIL). 2153 2154 Performance can often be improved by setting `num_parallel_calls` so that 2155 `map` will use multiple threads to process elements. If deterministic order 2156 isn't required, it can also improve performance to set 2157 `deterministic=False`. 2158 2159 >>> dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ] 2160 >>> dataset = dataset.map(lambda x: x + 1, 2161 ... num_parallel_calls=tf.data.AUTOTUNE, 2162 ... deterministic=False) 2163 2164 The order of elements yielded by this transformation is deterministic if 2165 `deterministic=True`. If `map_func` contains stateful operations and 2166 `num_parallel_calls > 1`, the order in which that state is accessed is 2167 undefined, so the values of output elements may not be deterministic 2168 regardless of the `deterministic` flag value. 2169 2170 Args: 2171 map_func: A function mapping a dataset element to another dataset element. 2172 num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`, 2173 representing the number elements to process asynchronously in parallel. 
2174 If not specified, elements will be processed sequentially. If the value 2175 `tf.data.AUTOTUNE` is used, then the number of parallel 2176 calls is set dynamically based on available CPU. 2177 deterministic: (Optional.) When `num_parallel_calls` is specified, if this 2178 boolean is specified (`True` or `False`), it controls the order in which 2179 the transformation produces elements. If set to `False`, the 2180 transformation is allowed to yield elements out of order to trade 2181 determinism for performance. If not specified, the 2182 `tf.data.Options.deterministic` option (`True` by default) controls the 2183 behavior. 2184 name: (Optional.) A name for the tf.data operation. 2185 2186 Returns: 2187 Dataset: A `Dataset`. 2188 """ 2189 if num_parallel_calls is None or DEBUG_MODE: 2190 if deterministic is not None and not DEBUG_MODE: 2191 warnings.warn("The `deterministic` argument has no effect unless the " 2192 "`num_parallel_calls` argument is specified.") 2193 return MapDataset(self, map_func, preserve_cardinality=True, name=name) 2194 else: 2195 return ParallelMapDataset( 2196 self, 2197 map_func, 2198 num_parallel_calls, 2199 deterministic, 2200 preserve_cardinality=True, 2201 name=name) 2202 2203 def flat_map(self, map_func, name=None): 2204 """Maps `map_func` across this dataset and flattens the result. 2205 2206 The type signature is: 2207 2208 ``` 2209 def flat_map( 2210 self: Dataset[T], 2211 map_func: Callable[[T], Dataset[S]] 2212 ) -> Dataset[S] 2213 ``` 2214 2215 Use `flat_map` if you want to make sure that the order of your dataset 2216 stays the same. For example, to flatten a dataset of batches into a 2217 dataset of their elements: 2218 2219 >>> dataset = tf.data.Dataset.from_tensor_slices( 2220 ... [[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 2221 >>> dataset = dataset.flat_map(tf.data.Dataset.from_tensor_slices) 2222 >>> list(dataset.as_numpy_iterator()) 2223 [1, 2, 3, 4, 5, 6, 7, 8, 9] 2224 2225 `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since 2226 `flat_map` produces the same output as 2227 `tf.data.Dataset.interleave(cycle_length=1)` 2228 2229 Args: 2230 map_func: A function mapping a dataset element to a dataset. 2231 name: (Optional.) A name for the tf.data operation. 2232 2233 Returns: 2234 Dataset: A `Dataset`. 2235 """ 2236 return FlatMapDataset(self, map_func, name=name) 2237 2238 def interleave(self, 2239 map_func, 2240 cycle_length=None, 2241 block_length=None, 2242 num_parallel_calls=None, 2243 deterministic=None, 2244 name=None): 2245 """Maps `map_func` across this dataset, and interleaves the results. 2246 2247 The type signature is: 2248 2249 ``` 2250 def interleave( 2251 self: Dataset[T], 2252 map_func: Callable[[T], Dataset[S]] 2253 ) -> Dataset[S] 2254 ``` 2255 2256 For example, you can use `Dataset.interleave()` to process many input files 2257 concurrently: 2258 2259 >>> # Preprocess 4 files concurrently, and interleave blocks of 16 records 2260 >>> # from each file. 2261 >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt", 2262 ... "/var/data/file3.txt", "/var/data/file4.txt"] 2263 >>> dataset = tf.data.Dataset.from_tensor_slices(filenames) 2264 >>> def parse_fn(filename): 2265 ... return tf.data.Dataset.range(10) 2266 >>> dataset = dataset.interleave(lambda x: 2267 ... tf.data.TextLineDataset(x).map(parse_fn, num_parallel_calls=1), 2268 ... cycle_length=4, block_length=16) 2269 2270 The `cycle_length` and `block_length` arguments control the order in which 2271 elements are produced. 
`cycle_length` controls the number of input elements 2272 that are processed concurrently. If you set `cycle_length` to 1, this 2273 transformation will handle one input element at a time, and will produce 2274 identical results to `tf.data.Dataset.flat_map`. In general, 2275 this transformation will apply `map_func` to `cycle_length` input elements, 2276 open iterators on the returned `Dataset` objects, and cycle through them 2277 producing `block_length` consecutive elements from each iterator, and 2278 consuming the next input element each time it reaches the end of an 2279 iterator. 2280 2281 For example: 2282 2283 >>> dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ] 2284 >>> # NOTE: New lines indicate "block" boundaries. 2285 >>> dataset = dataset.interleave( 2286 ... lambda x: Dataset.from_tensors(x).repeat(6), 2287 ... cycle_length=2, block_length=4) 2288 >>> list(dataset.as_numpy_iterator()) 2289 [1, 1, 1, 1, 2290 2, 2, 2, 2, 2291 1, 1, 2292 2, 2, 2293 3, 3, 3, 3, 2294 4, 4, 4, 4, 2295 3, 3, 2296 4, 4, 2297 5, 5, 5, 5, 2298 5, 5] 2299 2300 Note: The order of elements yielded by this transformation is 2301 deterministic, as long as `map_func` is a pure function and 2302 `deterministic=True`. If `map_func` contains any stateful operations, the 2303 order in which that state is accessed is undefined. 2304 2305 Performance can often be improved by setting `num_parallel_calls` so that 2306 `interleave` will use multiple threads to fetch elements. If determinism 2307 isn't required, it can also improve performance to set 2308 `deterministic=False`. 2309 2310 >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt", 2311 ... "/var/data/file3.txt", "/var/data/file4.txt"] 2312 >>> dataset = tf.data.Dataset.from_tensor_slices(filenames) 2313 >>> dataset = dataset.interleave(lambda x: tf.data.TFRecordDataset(x), 2314 ... cycle_length=4, num_parallel_calls=tf.data.AUTOTUNE, 2315 ... deterministic=False) 2316 2317 Args: 2318 map_func: A function that takes a dataset element and returns a 2319 `tf.data.Dataset`. 2320 cycle_length: (Optional.) The number of input elements that will be 2321 processed concurrently. If not set, the tf.data runtime decides what it 2322 should be based on available CPU. If `num_parallel_calls` is set to 2323 `tf.data.AUTOTUNE`, the `cycle_length` argument identifies 2324 the maximum degree of parallelism. 2325 block_length: (Optional.) The number of consecutive elements to produce 2326 from each input element before cycling to another input element. If not 2327 set, defaults to 1. 2328 num_parallel_calls: (Optional.) If specified, the implementation creates a 2329 threadpool, which is used to fetch inputs from cycle elements 2330 asynchronously and in parallel. The default behavior is to fetch inputs 2331 from cycle elements synchronously with no parallelism. If the value 2332 `tf.data.AUTOTUNE` is used, then the number of parallel 2333 calls is set dynamically based on available CPU. 2334 deterministic: (Optional.) When `num_parallel_calls` is specified, if this 2335 boolean is specified (`True` or `False`), it controls the order in which 2336 the transformation produces elements. If set to `False`, the 2337 transformation is allowed to yield elements out of order to trade 2338 determinism for performance. If not specified, the 2339 `tf.data.Options.deterministic` option (`True` by default) controls the 2340 behavior. 2341 name: (Optional.) A name for the tf.data operation. 2342 2343 Returns: 2344 Dataset: A `Dataset`. 
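As noted above, setting `cycle_length=1` (with the default `block_length=1`)
makes `interleave` behave like `flat_map`. A small sketch of that equivalence,
using placeholder datasets:

```python
ds = tf.data.Dataset.range(3)
a = ds.interleave(lambda x: tf.data.Dataset.from_tensors(x).repeat(2),
                  cycle_length=1, block_length=1)
b = ds.flat_map(lambda x: tf.data.Dataset.from_tensors(x).repeat(2))
# Both `a` and `b` yield 0, 0, 1, 1, 2, 2 in the same order.
```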
2345 """ 2346 if block_length is None: 2347 block_length = 1 2348 2349 if cycle_length is None: 2350 cycle_length = AUTOTUNE 2351 2352 if num_parallel_calls is None or DEBUG_MODE: 2353 if deterministic is not None and not DEBUG_MODE: 2354 warnings.warn("The `deterministic` argument has no effect unless the " 2355 "`num_parallel_calls` argument is specified.") 2356 return InterleaveDataset( 2357 self, map_func, cycle_length, block_length, name=name) 2358 else: 2359 return ParallelInterleaveDataset( 2360 self, 2361 map_func, 2362 cycle_length, 2363 block_length, 2364 num_parallel_calls, 2365 deterministic=deterministic, 2366 name=name) 2367 2368 def filter(self, predicate, name=None): 2369 """Filters this dataset according to `predicate`. 2370 2371 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 2372 >>> dataset = dataset.filter(lambda x: x < 3) 2373 >>> list(dataset.as_numpy_iterator()) 2374 [1, 2] 2375 >>> # `tf.math.equal(x, y)` is required for equality comparison 2376 >>> def filter_fn(x): 2377 ... return tf.math.equal(x, 1) 2378 >>> dataset = dataset.filter(filter_fn) 2379 >>> list(dataset.as_numpy_iterator()) 2380 [1] 2381 2382 Args: 2383 predicate: A function mapping a dataset element to a boolean. 2384 name: (Optional.) A name for the tf.data operation. 2385 2386 Returns: 2387 Dataset: The `Dataset` containing the elements of this dataset for which 2388 `predicate` is `True`. 2389 """ 2390 return FilterDataset(self, predicate, name=name) 2391 2392 def apply(self, transformation_func): 2393 """Applies a transformation function to this dataset. 2394 2395 `apply` enables chaining of custom `Dataset` transformations, which are 2396 represented as functions that take one `Dataset` argument and return a 2397 transformed `Dataset`. 2398 2399 >>> dataset = tf.data.Dataset.range(100) 2400 >>> def dataset_fn(ds): 2401 ... return ds.filter(lambda x: x < 5) 2402 >>> dataset = dataset.apply(dataset_fn) 2403 >>> list(dataset.as_numpy_iterator()) 2404 [0, 1, 2, 3, 4] 2405 2406 Args: 2407 transformation_func: A function that takes one `Dataset` argument and 2408 returns a `Dataset`. 2409 2410 Returns: 2411 Dataset: The `Dataset` returned by applying `transformation_func` to this 2412 dataset. 2413 """ 2414 dataset = transformation_func(self) 2415 if not isinstance(dataset, DatasetV2): 2416 raise TypeError( 2417 f"`transformation_func` must return a `tf.data.Dataset` object. " 2418 f"Got {type(dataset)}.") 2419 dataset._input_datasets = [self] # pylint: disable=protected-access 2420 return dataset 2421 2422 def window(self, size, shift=None, stride=1, drop_remainder=False, name=None): 2423 """Returns a dataset of "windows". 2424 2425 Each "window" is a dataset that contains a subset of elements of the 2426 input dataset. These are finite datasets of size `size` (or possibly fewer 2427 if there are not enough input elements to fill the window and 2428 `drop_remainder` evaluates to `False`). 2429 2430 For example: 2431 2432 >>> dataset = tf.data.Dataset.range(7).window(3) 2433 >>> for window in dataset: 2434 ... print(window) 2435 <...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)> 2436 <...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)> 2437 <...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)> 2438 2439 Since windows are datasets, they can be iterated over: 2440 2441 >>> for window in dataset: 2442 ... 
print([item.numpy() for item in window]) 2443 [0, 1, 2] 2444 [3, 4, 5] 2445 [6] 2446 2447 #### Shift 2448 2449 The `shift` argument determines the number of input elements to shift 2450 between the start of each window. If windows and elements are both numbered 2451 starting at 0, the first element in window `k` will be element `k * shift` 2452 of the input dataset. In particular, the first element of the first window 2453 will always be the first element of the input dataset. 2454 2455 >>> dataset = tf.data.Dataset.range(7).window(3, shift=1, 2456 ... drop_remainder=True) 2457 >>> for window in dataset: 2458 ... print(list(window.as_numpy_iterator())) 2459 [0, 1, 2] 2460 [1, 2, 3] 2461 [2, 3, 4] 2462 [3, 4, 5] 2463 [4, 5, 6] 2464 2465 #### Stride 2466 2467 The `stride` argument determines the stride between input elements within a 2468 window. 2469 2470 >>> dataset = tf.data.Dataset.range(7).window(3, shift=1, stride=2, 2471 ... drop_remainder=True) 2472 >>> for window in dataset: 2473 ... print(list(window.as_numpy_iterator())) 2474 [0, 2, 4] 2475 [1, 3, 5] 2476 [2, 4, 6] 2477 2478 #### Nested elements 2479 2480 When the `window` transformation is applied to a dataset whos elements are 2481 nested structures, it produces a dataset where the elements have the same 2482 nested structure but each leaf is replaced by a window. In other words, 2483 the nesting is applied outside of the windows as opposed inside of them. 2484 2485 The type signature is: 2486 2487 ``` 2488 def window( 2489 self: Dataset[Nest[T]], ... 2490 ) -> Dataset[Nest[Dataset[T]]] 2491 ``` 2492 2493 Applying `window` to a `Dataset` of tuples gives a tuple of windows: 2494 2495 >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5], 2496 ... [6, 7, 8, 9, 10])) 2497 >>> dataset = dataset.window(2) 2498 >>> windows = next(iter(dataset)) 2499 >>> windows 2500 (<...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int32, name=None)>, 2501 <...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int32, name=None)>) 2502 2503 >>> def to_numpy(ds): 2504 ... return list(ds.as_numpy_iterator()) 2505 >>> 2506 >>> for windows in dataset: 2507 ... print(to_numpy(windows[0]), to_numpy(windows[1])) 2508 [1, 2] [6, 7] 2509 [3, 4] [8, 9] 2510 [5] [10] 2511 2512 Applying `window` to a `Dataset` of dictionaries gives a dictionary of 2513 `Datasets`: 2514 2515 >>> dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3], 2516 ... 'b': [4, 5, 6], 2517 ... 'c': [7, 8, 9]}) 2518 >>> dataset = dataset.window(2) 2519 >>> def to_numpy(ds): 2520 ... return list(ds.as_numpy_iterator()) 2521 >>> 2522 >>> for windows in dataset: 2523 ... print(tf.nest.map_structure(to_numpy, windows)) 2524 {'a': [1, 2], 'b': [4, 5], 'c': [7, 8]} 2525 {'a': [3], 'b': [6], 'c': [9]} 2526 2527 #### Flatten a dataset of windows 2528 2529 The `Dataset.flat_map` and `Dataset.interleave` methods can be used to 2530 flatten a dataset of windows into a single dataset. 2531 2532 The argument to `flat_map` is a function that takes an element from the 2533 dataset and returns a `Dataset`. `flat_map` chains together the resulting 2534 datasets sequentially. 2535 2536 For example, to turn each window into a dense tensor: 2537 2538 >>> size = 3 2539 >>> dataset = tf.data.Dataset.range(7).window(size, shift=1, 2540 ... drop_remainder=True) 2541 >>> batched = dataset.flat_map(lambda x:x.batch(3)) 2542 >>> for batch in batched: 2543 ... 
print(batch.numpy()) 2544 [0 1 2] 2545 [1 2 3] 2546 [2 3 4] 2547 [3 4 5] 2548 [4 5 6] 2549 2550 Args: 2551 size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements 2552 of the input dataset to combine into a window. Must be positive. 2553 shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 2554 number of input elements by which the window moves in each iteration. 2555 Defaults to `size`. Must be positive. 2556 stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 2557 stride of the input elements in the sliding window. Must be positive. 2558 The default value of 1 means "retain every input element". 2559 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 2560 whether the last windows should be dropped if their size is smaller than 2561 `size`. 2562 name: (Optional.) A name for the tf.data operation. 2563 2564 Returns: 2565 Dataset: A `Dataset` of (nests of) windows. Each window is a finite 2566 datasets of flat elements. 2567 """ 2568 if shift is None: 2569 shift = size 2570 return WindowDataset(self, size, shift, stride, drop_remainder, name=name) 2571 2572 def reduce(self, initial_state, reduce_func, name=None): 2573 """Reduces the input dataset to a single element. 2574 2575 The transformation calls `reduce_func` successively on every element of 2576 the input dataset until the dataset is exhausted, aggregating information in 2577 its internal state. The `initial_state` argument is used for the initial 2578 state and the final state is returned as the result. 2579 2580 >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1).numpy() 2581 5 2582 >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y).numpy() 2583 10 2584 2585 Args: 2586 initial_state: An element representing the initial state of the 2587 transformation. 2588 reduce_func: A function that maps `(old_state, input_element)` to 2589 `new_state`. It must take two arguments and return a new element 2590 The structure of `new_state` must match the structure of 2591 `initial_state`. 2592 name: (Optional.) A name for the tf.data operation. 2593 2594 Returns: 2595 A dataset element corresponding to the final state of the transformation. 2596 2597 """ 2598 2599 with ops.name_scope("initial_state"): 2600 initial_state = structure.normalize_element(initial_state) 2601 state_structure = structure.type_spec_from_value(initial_state) 2602 2603 # Iteratively rerun the reduce function until reaching a fixed point on 2604 # `state_structure`. 2605 need_to_rerun = True 2606 while need_to_rerun: 2607 2608 wrapped_func = structured_function.StructuredFunctionWrapper( 2609 reduce_func, 2610 "reduce()", 2611 input_structure=(state_structure, self.element_spec), 2612 add_to_graph=False) 2613 2614 # Extract and validate class information from the returned values. 2615 output_classes = wrapped_func.output_classes 2616 state_classes = nest.map_structure( 2617 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 2618 state_structure) 2619 for new_state_class, state_class in zip( 2620 nest.flatten(output_classes), nest.flatten(state_classes)): 2621 if not issubclass(new_state_class, state_class): 2622 raise TypeError( 2623 f"The element classes for the new state must match the initial " 2624 f"state. Expected {state_classes} but got " 2625 f"{wrapped_func.output_classes}.") 2626 2627 # Extract and validate type information from the returned values. 
2628 output_types = wrapped_func.output_types 2629 state_types = nest.map_structure( 2630 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 2631 state_structure) 2632 for new_state_type, state_type in zip( 2633 nest.flatten(output_types), nest.flatten(state_types)): 2634 if new_state_type != state_type: 2635 raise TypeError( 2636 f"The element types for the new state must match the initial " 2637 f"state. Expected {state_types} but got " 2638 f"{wrapped_func.output_types}.") 2639 2640 # Extract shape information from the returned values. 2641 output_shapes = wrapped_func.output_shapes 2642 state_shapes = nest.map_structure( 2643 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 2644 state_structure) 2645 flat_state_shapes = nest.flatten(state_shapes) 2646 flat_new_state_shapes = nest.flatten(output_shapes) 2647 weakened_state_shapes = [ 2648 original.most_specific_compatible_shape(new) 2649 for original, new in zip(flat_state_shapes, flat_new_state_shapes) 2650 ] 2651 2652 need_to_rerun = False 2653 for original_shape, weakened_shape in zip(flat_state_shapes, 2654 weakened_state_shapes): 2655 if original_shape.ndims is not None and ( 2656 weakened_shape.ndims is None or 2657 original_shape.as_list() != weakened_shape.as_list()): 2658 need_to_rerun = True 2659 break 2660 2661 if need_to_rerun: 2662 # TODO(b/110122868): Support a "most specific compatible structure" 2663 # method for combining structures, to avoid using legacy structures 2664 # here. 2665 state_structure = structure.convert_legacy_structure( 2666 state_types, 2667 nest.pack_sequence_as(state_shapes, weakened_state_shapes), 2668 state_classes) 2669 2670 reduce_func = wrapped_func.function 2671 reduce_func.add_to_graph(ops.get_default_graph()) 2672 2673 dataset = self._apply_debug_options() 2674 2675 # pylint: disable=protected-access 2676 metadata = dataset_metadata_pb2.Metadata() 2677 if name: 2678 metadata.name = _validate_and_encode(name) 2679 return structure.from_compatible_tensor_list( 2680 state_structure, 2681 gen_dataset_ops.reduce_dataset( 2682 dataset._variant_tensor, 2683 structure.to_tensor_list(state_structure, initial_state), 2684 reduce_func.captured_inputs, 2685 f=reduce_func, 2686 output_shapes=structure.get_flat_tensor_shapes(state_structure), 2687 output_types=structure.get_flat_tensor_types(state_structure), 2688 metadata=metadata.SerializeToString())) 2689 2690 def get_single_element(self, name=None): 2691 """Returns the single element of the `dataset`. 2692 2693 The function enables you to use a `tf.data.Dataset` in a stateless 2694 "tensor-in tensor-out" expression, without creating an iterator. 2695 This facilitates the ease of data transformation on tensors using the 2696 optimized `tf.data.Dataset` abstraction on top of them. 2697 2698 For example, lets consider a `preprocessing_fn` which would take as an 2699 input the raw features and returns the processed feature along with 2700 it's label. 2701 2702 ```python 2703 def preprocessing_fn(raw_feature): 2704 # ... the raw_feature is preprocessed as per the use-case 2705 return feature 2706 2707 raw_features = ... # input batch of BATCH_SIZE elements. 
2708 dataset = (tf.data.Dataset.from_tensor_slices(raw_features) 2709 .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) 2710 .batch(BATCH_SIZE)) 2711 2712 processed_features = dataset.get_single_element() 2713 ``` 2714 2715 In the above example, the `raw_features` tensor of length=BATCH_SIZE 2716 was converted to a `tf.data.Dataset`. Next, each of the `raw_feature` was 2717 mapped using the `preprocessing_fn` and the processed features were 2718 grouped into a single batch. The final `dataset` contains only one element 2719 which is a batch of all the processed features. 2720 2721 NOTE: The `dataset` should contain only one element. 2722 2723 Now, instead of creating an iterator for the `dataset` and retrieving the 2724 batch of features, the `tf.data.get_single_element()` function is used 2725 to skip the iterator creation process and directly output the batch of 2726 features. 2727 2728 This can be particularly useful when your tensor transformations are 2729 expressed as `tf.data.Dataset` operations, and you want to use those 2730 transformations while serving your model. 2731 2732 #### Keras 2733 2734 ```python 2735 2736 model = ... # A pre-built or custom model 2737 2738 class PreprocessingModel(tf.keras.Model): 2739 def __init__(self, model): 2740 super().__init__(self) 2741 self.model = model 2742 2743 @tf.function(input_signature=[...]) 2744 def serving_fn(self, data): 2745 ds = tf.data.Dataset.from_tensor_slices(data) 2746 ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) 2747 ds = ds.batch(batch_size=BATCH_SIZE) 2748 return tf.argmax(self.model(ds.get_single_element()), axis=-1) 2749 2750 preprocessing_model = PreprocessingModel(model) 2751 your_exported_model_dir = ... # save the model to this path. 2752 tf.saved_model.save(preprocessing_model, your_exported_model_dir, 2753 signatures={'serving_default': preprocessing_model.serving_fn} 2754 ) 2755 ``` 2756 2757 #### Estimator 2758 2759 In the case of estimators, you need to generally define a `serving_input_fn` 2760 which would require the features to be processed by the model while 2761 inferencing. 2762 2763 ```python 2764 def serving_input_fn(): 2765 2766 raw_feature_spec = ... # Spec for the raw_features 2767 input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn( 2768 raw_feature_spec, default_batch_size=None) 2769 ) 2770 serving_input_receiver = input_fn() 2771 raw_features = serving_input_receiver.features 2772 2773 def preprocessing_fn(raw_feature): 2774 # ... the raw_feature is preprocessed as per the use-case 2775 return feature 2776 2777 dataset = (tf.data.Dataset.from_tensor_slices(raw_features) 2778 .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) 2779 .batch(BATCH_SIZE)) 2780 2781 processed_features = dataset.get_single_element() 2782 2783 # Please note that the value of `BATCH_SIZE` should be equal to 2784 # the size of the leading dimension of `raw_features`. This ensures 2785 # that `dataset` has only element, which is a pre-requisite for 2786 # using `dataset.get_single_element()`. 2787 2788 return tf.estimator.export.ServingInputReceiver( 2789 processed_features, serving_input_receiver.receiver_tensors) 2790 2791 estimator = ... # A pre-built or custom estimator 2792 estimator.export_saved_model(your_exported_model_dir, serving_input_fn) 2793 ``` 2794 2795 Args: 2796 name: (Optional.) A name for the tf.data operation. 2797 2798 Returns: 2799 A nested structure of `tf.Tensor` objects, corresponding to the single 2800 element of `dataset`. 
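A minimal sketch of the stateless "tensor-in tensor-out" usage described
above, with illustrative values:

```python
ds = tf.data.Dataset.from_tensors([1, 2, 3])  # a dataset with exactly one element
value = ds.get_single_element()               # a tf.Tensor with value [1, 2, 3]
```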
2801 2802 Raises: 2803 InvalidArgumentError: (at runtime) if `dataset` does not contain exactly 2804 one element. 2805 """ 2806 2807 metadata = dataset_metadata_pb2.Metadata() 2808 if name: 2809 metadata.name = _validate_and_encode(name) 2810 return structure.from_compatible_tensor_list( 2811 self.element_spec, 2812 gen_dataset_ops.dataset_to_single_element( 2813 self._variant_tensor, 2814 metadata=metadata.SerializeToString(), 2815 **self._flat_structure)) # pylint: disable=protected-access 2816 2817 def unbatch(self, name=None): 2818 """Splits elements of a dataset into multiple elements. 2819 2820 For example, if elements of the dataset are shaped `[B, a0, a1, ...]`, 2821 where `B` may vary for each input element, then for each element in the 2822 dataset, the unbatched dataset will contain `B` consecutive elements 2823 of shape `[a0, a1, ...]`. 2824 2825 >>> elements = [ [1, 2, 3], [1, 2], [1, 2, 3, 4] ] 2826 >>> dataset = tf.data.Dataset.from_generator(lambda: elements, tf.int64) 2827 >>> dataset = dataset.unbatch() 2828 >>> list(dataset.as_numpy_iterator()) 2829 [1, 2, 3, 1, 2, 1, 2, 3, 4] 2830 2831 Note: `unbatch` requires a data copy to slice up the batched tensor into 2832 smaller, unbatched tensors. When optimizing performance, try to avoid 2833 unnecessary usage of `unbatch`. 2834 2835 Args: 2836 name: (Optional.) A name for the tf.data operation. 2837 2838 Returns: 2839 A `Dataset`. 2840 """ 2841 normalized_dataset = normalize_to_dense(self) 2842 return _UnbatchDataset(normalized_dataset, name=name) 2843 2844 def with_options(self, options, name=None): 2845 """Returns a new `tf.data.Dataset` with the given options set. 2846 2847 The options are "global" in the sense they apply to the entire dataset. 2848 If options are set multiple times, they are merged as long as different 2849 options do not use different non-default values. 2850 2851 >>> ds = tf.data.Dataset.range(5) 2852 >>> ds = ds.interleave(lambda x: tf.data.Dataset.range(5), 2853 ... cycle_length=3, 2854 ... num_parallel_calls=3) 2855 >>> options = tf.data.Options() 2856 >>> # This will make the interleave order non-deterministic. 2857 >>> options.deterministic = False 2858 >>> ds = ds.with_options(options) 2859 2860 Args: 2861 options: A `tf.data.Options` that identifies the options the use. 2862 name: (Optional.) A name for the tf.data operation. 2863 2864 Returns: 2865 Dataset: A `Dataset` with the given options. 2866 2867 Raises: 2868 ValueError: when an option is set more than once to a non-default value 2869 """ 2870 return _OptionsDataset(self, options, name=name) 2871 2872 def cardinality(self): 2873 """Returns the cardinality of the dataset, if known. 2874 2875 `cardinality` may return `tf.data.INFINITE_CARDINALITY` if the dataset 2876 contains an infinite number of elements or `tf.data.UNKNOWN_CARDINALITY` if 2877 the analysis fails to determine the number of elements in the dataset 2878 (e.g. when the dataset source is a file). 2879 2880 >>> dataset = tf.data.Dataset.range(42) 2881 >>> print(dataset.cardinality().numpy()) 2882 42 2883 >>> dataset = dataset.repeat() 2884 >>> cardinality = dataset.cardinality() 2885 >>> print((cardinality == tf.data.INFINITE_CARDINALITY).numpy()) 2886 True 2887 >>> dataset = dataset.filter(lambda x: True) 2888 >>> cardinality = dataset.cardinality() 2889 >>> print((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy()) 2890 True 2891 2892 Returns: 2893 A scalar `tf.int64` `Tensor` representing the cardinality of the dataset. 
2894 If the cardinality is infinite or unknown, `cardinality` returns the 2895 named constants `tf.data.INFINITE_CARDINALITY` and 2896 `tf.data.UNKNOWN_CARDINALITY` respectively. 2897 """ 2898 return gen_dataset_ops.dataset_cardinality(self._variant_tensor) 2899 2900 def group_by_window(self, 2901 key_func, 2902 reduce_func, 2903 window_size=None, 2904 window_size_func=None, 2905 name=None): 2906 """Groups windows of elements by key and reduces them. 2907 2908 This transformation maps each consecutive element in a dataset to a key 2909 using `key_func` and groups the elements by key. It then applies 2910 `reduce_func` to at most `window_size_func(key)` elements matching the same 2911 key. All except the final window for each key will contain 2912 `window_size_func(key)` elements; the final window may be smaller. 2913 2914 You may provide either a constant `window_size` or a window size determined 2915 by the key through `window_size_func`. 2916 2917 >>> dataset = tf.data.Dataset.range(10) 2918 >>> window_size = 5 2919 >>> key_func = lambda x: x%2 2920 >>> reduce_func = lambda key, dataset: dataset.batch(window_size) 2921 >>> dataset = dataset.group_by_window( 2922 ... key_func=key_func, 2923 ... reduce_func=reduce_func, 2924 ... window_size=window_size) 2925 >>> for elem in dataset.as_numpy_iterator(): 2926 ... print(elem) 2927 [0 2 4 6 8] 2928 [1 3 5 7 9] 2929 2930 Args: 2931 key_func: A function mapping a nested structure of tensors (having shapes 2932 and types defined by `self.output_shapes` and `self.output_types`) to a 2933 scalar `tf.int64` tensor. 2934 reduce_func: A function mapping a key and a dataset of up to `window_size` 2935 consecutive elements matching that key to another dataset. 2936 window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 2937 consecutive elements matching the same key to combine in a single batch, 2938 which will be passed to `reduce_func`. Mutually exclusive with 2939 `window_size_func`. 2940 window_size_func: A function mapping a key to a `tf.int64` scalar 2941 `tf.Tensor`, representing the number of consecutive elements matching 2942 the same key to combine in a single batch, which will be passed to 2943 `reduce_func`. Mutually exclusive with `window_size`. 2944 name: (Optional.) A name for the tf.data operation. 2945 2946 Returns: 2947 A `Dataset`. 2948 2949 Raises: 2950 ValueError: if neither or both of {`window_size`, `window_size_func`} are 2951 passed. 2952 """ 2953 if (window_size is not None and window_size_func or 2954 not (window_size is not None or window_size_func)): 2955 raise ValueError("Either the `window_size` argument or the " 2956 "`window_size_func` argument must be specified.") 2957 2958 if window_size is not None: 2959 2960 def constant_window_func(unused_key): 2961 return ops.convert_to_tensor(window_size, dtype=dtypes.int64) 2962 2963 window_size_func = constant_window_func 2964 2965 assert window_size_func is not None 2966 2967 return _GroupByWindowDataset( 2968 self, key_func, reduce_func, window_size_func, name=name) 2969 2970 def bucket_by_sequence_length(self, 2971 element_length_func, 2972 bucket_boundaries, 2973 bucket_batch_sizes, 2974 padded_shapes=None, 2975 padding_values=None, 2976 pad_to_bucket_boundary=False, 2977 no_padding=False, 2978 drop_remainder=False, 2979 name=None): 2980 """A transformation that buckets elements in a `Dataset` by length. 2981 2982 Elements of the `Dataset` are grouped together by length and then are padded 2983 and batched. 
2984 2985 This is useful for sequence tasks in which the elements have variable 2986 length. Grouping together elements that have similar lengths reduces the 2987 total fraction of padding in a batch which increases training step 2988 efficiency. 2989 2990 Below is an example to bucketize the input data to the 3 buckets 2991 "[0, 3), [3, 5), [5, inf)" based on sequence length, with batch size 2. 2992 2993 >>> elements = [ 2994 ... [0], [1, 2, 3, 4], [5, 6, 7], 2995 ... [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]] 2996 >>> dataset = tf.data.Dataset.from_generator( 2997 ... lambda: elements, tf.int64, output_shapes=[None]) 2998 >>> dataset = dataset.bucket_by_sequence_length( 2999 ... element_length_func=lambda elem: tf.shape(elem)[0], 3000 ... bucket_boundaries=[3, 5], 3001 ... bucket_batch_sizes=[2, 2, 2]) 3002 >>> for elem in dataset.as_numpy_iterator(): 3003 ... print(elem) 3004 [[1 2 3 4] 3005 [5 6 7 0]] 3006 [[ 7 8 9 10 11 0] 3007 [13 14 15 16 19 20]] 3008 [[ 0 0] 3009 [21 22]] 3010 3011 Args: 3012 element_length_func: function from element in `Dataset` to `tf.int32`, 3013 determines the length of the element, which will determine the bucket it 3014 goes into. 3015 bucket_boundaries: `list<int>`, upper length boundaries of the buckets. 3016 bucket_batch_sizes: `list<int>`, batch size per bucket. Length should be 3017 `len(bucket_boundaries) + 1`. 3018 padded_shapes: Nested structure of `tf.TensorShape` to pass to 3019 `tf.data.Dataset.padded_batch`. If not provided, will use 3020 `dataset.output_shapes`, which will result in variable length dimensions 3021 being padded out to the maximum length in each batch. 3022 padding_values: Values to pad with, passed to 3023 `tf.data.Dataset.padded_batch`. Defaults to padding with 0. 3024 pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown 3025 size to maximum length in batch. If `True`, will pad dimensions with 3026 unknown size to bucket boundary minus 1 (i.e., the maximum length in 3027 each bucket), and caller must ensure that the source `Dataset` does not 3028 contain any elements with length longer than `max(bucket_boundaries)`. 3029 no_padding: `bool`, indicates whether to pad the batch features (features 3030 need to be either of type `tf.sparse.SparseTensor` or of same shape). 3031 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 3032 whether the last batch should be dropped in the case it has fewer than 3033 `batch_size` elements; the default behavior is not to drop the smaller 3034 batch. 3035 name: (Optional.) A name for the tf.data operation. 3036 3037 Returns: 3038 A `Dataset`. 3039 3040 Raises: 3041 ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`. 
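
    For pipelines that need batch shapes determined entirely by the bucket
    (rather than by the longest element that happens to land in a batch), a
    sketch of one possible configuration, assuming every element is strictly
    shorter than `max(bucket_boundaries)`, is:

    ```python
    dataset = dataset.bucket_by_sequence_length(
        element_length_func=lambda elem: tf.shape(elem)[0],
        bucket_boundaries=[3, 5],
        bucket_batch_sizes=[2, 2, 2],
        # Pad to `boundary - 1` instead of the longest element in the batch.
        pad_to_bucket_boundary=True,
        # Drop the final smaller batch so every batch has exactly 2 elements.
        drop_remainder=True)
    ```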
3042 """
3043 if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1):
3044 raise ValueError(
3045 f"`len(bucket_batch_sizes)` must equal `len(bucket_boundaries) + 1` "
3046 f"but `len(bucket_batch_sizes)={len(bucket_batch_sizes)}` and "
3047 f"`len(bucket_boundaries)={len(bucket_boundaries)}`.")
3048
3049 batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64)
3050
3051 def element_to_bucket_id(*args):
3052 """Return int64 id of the length bucket for this element."""
3053 seq_length = element_length_func(*args)
3054
3055 boundaries = list(bucket_boundaries)
3056 buckets_min = [np.iinfo(np.int32).min] + boundaries
3057 buckets_max = boundaries + [np.iinfo(np.int32).max]
3058 conditions_c = math_ops.logical_and(
3059 math_ops.less_equal(buckets_min, seq_length),
3060 math_ops.less(seq_length, buckets_max))
3061 bucket_id = math_ops.reduce_min(array_ops.where(conditions_c))
3062
3063 return bucket_id
3064
3065 def window_size_fn(bucket_id):
3066 # The window size is set to the batch size for this bucket.
3067 window_size = batch_sizes[bucket_id]
3068 return window_size
3069
3070 def make_padded_shapes(shapes, none_filler=None):
3071 padded = []
3072 for shape in nest.flatten(shapes):
3073 shape = tensor_shape.TensorShape(shape)
3074 shape = [
3075 none_filler if tensor_shape.dimension_value(d) is None else d
3076 for d in shape
3077 ]
3078 padded.append(shape)
3079 return nest.pack_sequence_as(shapes, padded)
3080
3081 def batching_fn(bucket_id, grouped_dataset):
3082 """Batch elements in dataset."""
3083 batch_size = window_size_fn(bucket_id)
3084 if no_padding:
3085 return grouped_dataset.batch(
3086 batch_size, drop_remainder=drop_remainder, name=name)
3087 none_filler = None
3088 if pad_to_bucket_boundary:
3089 err_msg = ("When pad_to_bucket_boundary=True, elements must have "
3090 "length < max(bucket_boundaries).")
3091 check = check_ops.assert_less(
3092 bucket_id,
3093 constant_op.constant(
3094 len(bucket_batch_sizes) - 1, dtype=dtypes.int64),
3095 message=err_msg)
3096 with ops.control_dependencies([check]):
3097 boundaries = constant_op.constant(
3098 bucket_boundaries, dtype=dtypes.int64)
3099 bucket_boundary = boundaries[bucket_id]
3100 none_filler = bucket_boundary - 1
3101 input_shapes = get_legacy_output_shapes(grouped_dataset)
3102 shapes = make_padded_shapes(
3103 padded_shapes or input_shapes, none_filler=none_filler)
3104 return grouped_dataset.padded_batch(
3105 batch_size,
3106 shapes,
3107 padding_values,
3108 drop_remainder=drop_remainder,
3109 name=name)
3110
3111 return self.group_by_window(
3112 key_func=element_to_bucket_id,
3113 reduce_func=batching_fn,
3114 window_size_func=window_size_fn,
3115 name=name)
3116
3117 @staticmethod
3118 def random(seed=None, name=None):
3119 """Creates a `Dataset` of pseudorandom values.
3120
3121 The dataset generates a sequence of uniformly distributed integer values.
3122
3123 >>> ds1 = tf.data.Dataset.random(seed=4).take(10)
3124 >>> ds2 = tf.data.Dataset.random(seed=4).take(10)
3125 >>> print(list(ds1.as_numpy_iterator())==list(ds2.as_numpy_iterator()))
3126 True
3127
3128 Args:
3129 seed: (Optional.) If specified, the dataset produces a deterministic
3130 sequence of values.
3131 name: (Optional.) A name for the tf.data operation.
3132
3133 Returns:
3134 Dataset: A `Dataset`.
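
    The elements are scalar `tf.int64` values. As a sketch, one way to derive
    bounded pseudorandom values from the raw stream (assuming simple modulo
    reduction is acceptable for the use case) is:

    ```python
    # Pseudorandom integers in [0, 10).
    ds = tf.data.Dataset.random(seed=4).map(lambda x: x % 10)
    ```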
3135 """ 3136 return RandomDataset(seed=seed, name=name) 3137 3138 def snapshot(self, 3139 path, 3140 compression="AUTO", 3141 reader_func=None, 3142 shard_func=None, 3143 name=None): 3144 """API to persist the output of the input dataset. 3145 3146 The snapshot API allows users to transparently persist the output of their 3147 preprocessing pipeline to disk, and materialize the pre-processed data on a 3148 different training run. 3149 3150 This API enables repeated preprocessing steps to be consolidated, and allows 3151 re-use of already processed data, trading off disk storage and network 3152 bandwidth for freeing up more valuable CPU resources and accelerator compute 3153 time. 3154 3155 https://github.com/tensorflow/community/blob/master/rfcs/20200107-tf-data-snapshot.md 3156 has detailed design documentation of this feature. 3157 3158 Users can specify various options to control the behavior of snapshot, 3159 including how snapshots are read from and written to by passing in 3160 user-defined functions to the `reader_func` and `shard_func` parameters. 3161 3162 `shard_func` is a user specified function that maps input elements to 3163 snapshot shards. 3164 3165 Users may want to specify this function to control how snapshot files should 3166 be written to disk. Below is an example of how a potential `shard_func` 3167 could be written. 3168 3169 ```python 3170 dataset = ... 3171 dataset = dataset.enumerate() 3172 dataset = dataset.snapshot("/path/to/snapshot/dir", 3173 shard_func=lambda x, y: x % NUM_SHARDS, ...) 3174 dataset = dataset.map(lambda x, y: y) 3175 ``` 3176 3177 `reader_func` is a user specified function that accepts a single argument: 3178 (1) a Dataset of Datasets, each representing a "split" of elements of the 3179 original dataset. The cardinality of the input dataset matches the 3180 number of the shards specified in the `shard_func` (see above). The function 3181 should return a Dataset of elements of the original dataset. 3182 3183 Users may want specify this function to control how snapshot files should be 3184 read from disk, including the amount of shuffling and parallelism. 3185 3186 Here is an example of a standard reader function a user can define. This 3187 function enables both dataset shuffling and parallel reading of datasets: 3188 3189 ```python 3190 def user_reader_func(datasets): 3191 # shuffle the datasets splits 3192 datasets = datasets.shuffle(NUM_CORES) 3193 # read datasets in parallel and interleave their elements 3194 return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE) 3195 3196 dataset = dataset.snapshot("/path/to/snapshot/dir", 3197 reader_func=user_reader_func) 3198 ``` 3199 3200 By default, snapshot parallelizes reads by the number of cores available on 3201 the system, but will not attempt to shuffle the data. 3202 3203 Args: 3204 path: Required. A directory to use for storing / loading the snapshot to / 3205 from. 3206 compression: Optional. The type of compression to apply to the snapshot 3207 written to disk. Supported options are `GZIP`, `SNAPPY`, `AUTO` or None. 3208 Defaults to `AUTO`, which attempts to pick an appropriate compression 3209 algorithm for the dataset. 3210 reader_func: Optional. A function to control how to read data from 3211 snapshot shards. 3212 shard_func: Optional. A function to control how to shard data when writing 3213 a snapshot. 3214 name: (Optional.) A name for the tf.data operation. 3215 3216 Returns: 3217 A `Dataset`. 
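
    As a sketch of a typical placement (the file list and parse function
    below are placeholders), `snapshot` usually follows the expensive,
    deterministic preprocessing and precedes any random transformation, so
    that shuffling still varies between epochs while the parsed records are
    read back from disk:

    ```python
    dataset = tf.data.TFRecordDataset(files)
    dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.snapshot("/path/to/snapshot/dir")
    dataset = dataset.shuffle(10_000).batch(32).prefetch(tf.data.AUTOTUNE)
    ```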
3218 """ 3219 3220 project_func = None 3221 input_dataset = self 3222 if shard_func is None: 3223 input_dataset = input_dataset.enumerate(name=name) 3224 # This sets the amount of parallelism based on the number of CPU cores on 3225 # the machine where this Python code is executed, which may differ from 3226 # the number of CPU cores where the input pipeline graph is actually 3227 # executed (e.g. remote Cloud TPU workers). 3228 local_shard_func = lambda index, _: index % multiprocessing.cpu_count() 3229 project_func = lambda _, elem: elem 3230 else: 3231 local_shard_func = shard_func 3232 dataset = _SnapshotDataset( 3233 input_dataset=input_dataset, 3234 path=path, 3235 compression=compression, 3236 reader_func=reader_func, 3237 # This will not do the right thing where the graph is built on a 3238 # different machine than the executor (e.g. Cloud TPUs). 3239 shard_func=local_shard_func, 3240 name=name) 3241 if project_func is not None: 3242 dataset = dataset.map(project_func, name=name) 3243 return dataset 3244 3245 def scan(self, initial_state, scan_func, name=None): 3246 """A transformation that scans a function across an input dataset. 3247 3248 This transformation is a stateful relative of `tf.data.Dataset.map`. 3249 In addition to mapping `scan_func` across the elements of the input dataset, 3250 `scan()` accumulates one or more state tensors, whose initial values are 3251 `initial_state`. 3252 3253 >>> dataset = tf.data.Dataset.range(10) 3254 >>> initial_state = tf.constant(0, dtype=tf.int64) 3255 >>> scan_func = lambda state, i: (state + i, state + i) 3256 >>> dataset = dataset.scan(initial_state=initial_state, scan_func=scan_func) 3257 >>> list(dataset.as_numpy_iterator()) 3258 [0, 1, 3, 6, 10, 15, 21, 28, 36, 45] 3259 3260 Args: 3261 initial_state: A nested structure of tensors, representing the initial 3262 state of the accumulator. 3263 scan_func: A function that maps `(old_state, input_element)` to 3264 `(new_state, output_element)`. It must take two arguments and return a 3265 pair of nested structures of tensors. The `new_state` must match the 3266 structure of `initial_state`. 3267 name: (Optional.) A name for the tf.data operation. 3268 3269 Returns: 3270 A `Dataset`. 3271 """ 3272 3273 return _ScanDataset( 3274 self, initial_state=initial_state, scan_func=scan_func, name=name) 3275 3276 def take_while(self, predicate, name=None): 3277 """A transformation that stops dataset iteration based on a `predicate`. 3278 3279 >>> dataset = tf.data.Dataset.range(10) 3280 >>> dataset = dataset.take_while(lambda x: x < 5) 3281 >>> list(dataset.as_numpy_iterator()) 3282 [0, 1, 2, 3, 4] 3283 3284 Args: 3285 predicate: A function that maps a nested structure of tensors (having 3286 shapes and types defined by `self.output_shapes` and 3287 `self.output_types`) to a scalar `tf.bool` tensor. 3288 name: (Optional.) A name for the tf.data operation. 3289 3290 Returns: 3291 A `Dataset`. 3292 """ 3293 3294 return _TakeWhileDataset(self, predicate, name=name) 3295 3296 def unique(self, name=None): 3297 """A transformation that discards duplicate elements of a `Dataset`. 3298 3299 Use this transformation to produce a dataset that contains one instance of 3300 each unique element in the input. 
For example:
3301
3302 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])
3303 >>> dataset = dataset.unique()
3304 >>> sorted(list(dataset.as_numpy_iterator()))
3305 [1, 2, 37]
3306
3307 Note: This transformation only supports datasets which fit into memory
3308 and have elements of either `tf.int32`, `tf.int64` or `tf.string` type.
3309
3310 Args:
3311 name: (Optional.) A name for the tf.data operation.
3312
3313 Returns:
3314 A `Dataset`.
3315 """
3316
3317 return _UniqueDataset(self, name=name)
3318
3319 def rejection_resample(self,
3320 class_func,
3321 target_dist,
3322 initial_dist=None,
3323 seed=None,
3324 name=None):
3325 """A transformation that resamples a dataset to a target distribution.
3326
3327 Let's consider the following example where a dataset with an initial data
3328 distribution of `initial_dist` needs to be resampled into a dataset with a
3329 `target_dist` distribution.
3330
3331 >>> initial_dist = [0.6, 0.4]
3332 >>> num_classes = len(initial_dist)
3333 >>> num_samples = 1000
3334 >>> data_np = np.random.choice(num_classes, num_samples, p=initial_dist)
3335 >>> dataset = tf.data.Dataset.from_tensor_slices(data_np)
3336
3337 The class counts in `data_np` will be close to `{0: 600, 1: 400}`, as per
3338 the `initial_dist` distribution.
3339
3340 >>> target_dist = [0.5, 0.5]
3341 >>> resampled_dataset = dataset.rejection_resample(
3342 ... class_func=lambda x: x,
3343 ... target_dist=target_dist,
3344 ... initial_dist=initial_dist)
3345 >>> resampled_dataset = resampled_dataset.map(
3346 ... lambda class_func_result, data: data)
3347
3348
3349 The class distribution of `resampled_dataset` will now be close to the
3350 target distribution.
3351
3352 Args:
3353 class_func: A function mapping an element of the input dataset to a scalar
3354 `tf.int32` tensor. Values should be in `[0, num_classes)`.
3355 target_dist: A floating point type tensor, shaped `[num_classes]`.
3356 initial_dist: (Optional.) A floating point type tensor, shaped
3357 `[num_classes]`. If not provided, the true class distribution is
3358 estimated live in a streaming fashion.
3359 seed: (Optional.) Python integer seed for the resampler.
3360 name: (Optional.) A name for the tf.data operation.
3361
3362 Returns:
3363 A `Dataset`.
3364 """
3365
3366 target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
3367 target_dist_t = math_ops.cast(target_dist_t, dtypes.float32)
3368
3369 # Get initial distribution.
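    # Roughly, the resampling below works by rejection: each class `c` gets an
    # acceptance probability that grows with `target_dist[c] / initial_dist[c]`,
    # so elements of over-represented classes are dropped more often.
    # `_calculate_acceptance_probs_with_mixing` also returns the probability of
    # passing elements through from the original dataset unchanged, which is
    # used further below to mix the original and the filtered streams.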
3370 if initial_dist is not None: 3371 initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist") 3372 initial_dist_t = math_ops.cast(initial_dist_t, dtypes.float32) 3373 acceptance_dist, prob_of_original = ( 3374 _calculate_acceptance_probs_with_mixing(initial_dist_t, 3375 target_dist_t)) 3376 initial_dist_ds = DatasetV2.from_tensors( 3377 initial_dist_t, name=name).repeat(name=name) 3378 acceptance_dist_ds = DatasetV2.from_tensors( 3379 acceptance_dist, name=name).repeat(name=name) 3380 prob_of_original_ds = DatasetV2.from_tensors( 3381 prob_of_original, name=name).repeat(name=name) 3382 else: 3383 initial_dist_ds = _estimate_initial_dist_ds( 3384 target_dist_t, self.map(class_func, name=name), name=name) 3385 acceptance_and_original_prob_ds = initial_dist_ds.map( 3386 lambda initial: _calculate_acceptance_probs_with_mixing( # pylint: disable=g-long-lambda 3387 initial, target_dist_t), 3388 name=name) 3389 acceptance_dist_ds = acceptance_and_original_prob_ds.map( 3390 lambda accept_prob, _: accept_prob, name=name) 3391 prob_of_original_ds = acceptance_and_original_prob_ds.map( 3392 lambda _, prob_original: prob_original, name=name) 3393 filtered_ds = _filter_ds(self, acceptance_dist_ds, initial_dist_ds, 3394 class_func, seed) 3395 # Prefetch filtered dataset for speed. 3396 filtered_ds = filtered_ds.prefetch(3, name=name) 3397 3398 prob_original_static = _get_prob_original_static( 3399 initial_dist_t, target_dist_t) if initial_dist is not None else None 3400 3401 def add_class_value(*x): 3402 if len(x) == 1: 3403 return class_func(*x), x[0] 3404 else: 3405 return class_func(*x), x 3406 3407 if prob_original_static == 1: 3408 return self.map(add_class_value, name=name) 3409 elif prob_original_static == 0: 3410 return filtered_ds 3411 else: 3412 return Dataset.sample_from_datasets( 3413 [self.map(add_class_value), filtered_ds], 3414 weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]), 3415 seed=seed, 3416 stop_on_empty_dataset=True) 3417 3418 @staticmethod 3419 def sample_from_datasets(datasets, 3420 weights=None, 3421 seed=None, 3422 stop_on_empty_dataset=False): 3423 """Samples elements at random from the datasets in `datasets`. 3424 3425 Creates a dataset by interleaving elements of `datasets` with `weight[i]` 3426 probability of picking an element from dataset `i`. Sampling is done without 3427 replacement. For example, suppose we have 2 datasets: 3428 3429 ```python 3430 dataset1 = tf.data.Dataset.range(0, 3) 3431 dataset2 = tf.data.Dataset.range(100, 103) 3432 ``` 3433 3434 Suppose that we sample from these 2 datasets with the following weights: 3435 3436 ```python 3437 sample_dataset = tf.data.Dataset.sample_from_datasets( 3438 [dataset1, dataset2], weights=[0.5, 0.5]) 3439 ``` 3440 3441 One possible outcome of elements in sample_dataset is: 3442 3443 ``` 3444 print(list(sample_dataset.as_numpy_iterator())) 3445 # [100, 0, 1, 101, 2, 102] 3446 ``` 3447 3448 Args: 3449 datasets: A non-empty list of `tf.data.Dataset` objects with compatible 3450 structure. 3451 weights: (Optional.) A list or Tensor of `len(datasets)` floating-point 3452 values where `weights[i]` represents the probability to sample from 3453 `datasets[i]`, or a `tf.data.Dataset` object where each element is such 3454 a list. Defaults to a uniform distribution across `datasets`. 3455 seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random 3456 seed that will be used to create the distribution. See 3457 `tf.random.set_seed` for behavior. 
3458 stop_on_empty_dataset: If `True`, sampling stops if it encounters an empty 3459 dataset. If `False`, it continues sampling and skips any empty datasets. 3460 It is recommended to set it to `True`. Otherwise, the distribution of 3461 samples starts off as the user intends, but may change as input datasets 3462 become empty. This can be difficult to detect since the dataset starts 3463 off looking correct. Default to `False` for backward compatibility. 3464 3465 Returns: 3466 A dataset that interleaves elements from `datasets` at random, according 3467 to `weights` if provided, otherwise with uniform probability. 3468 3469 Raises: 3470 TypeError: If the `datasets` or `weights` arguments have the wrong type. 3471 ValueError: 3472 - If `datasets` is empty, or 3473 - If `weights` is specified and does not match the length of `datasets`. 3474 """ 3475 3476 def _skip_datasets_with_zero_weight(datasets, weights): 3477 datasets_and_weights = [(dataset, weight) 3478 for (dataset, weight) in zip(datasets, weights) 3479 if weight > 0] 3480 return (zip(*datasets_and_weights) if datasets_and_weights else 3481 ([datasets[0].take(0)], [1.])) 3482 3483 if not datasets: 3484 raise ValueError("Invalid `datasets`. `datasets` should not be empty.") 3485 3486 if not isinstance(weights, DatasetV2): 3487 if weights is None: 3488 # Select inputs with uniform probability. 3489 logits = [[1.0] * len(datasets)] 3490 3491 else: 3492 if isinstance(weights, ops.Tensor): 3493 if not weights.shape.is_compatible_with([len(datasets)]): 3494 raise ValueError(f"Invalid `weights`. The shape of `weights` " 3495 f"should be compatible with `[len(datasets)]` " 3496 f"but is {weights.shape}.") 3497 else: 3498 if len(datasets) != len(weights): 3499 raise ValueError(f"Invalid `weights`. `weights` should have the " 3500 f"same length as `datasets` but got " 3501 f"`len(weights)={len(weights)}` vs. " 3502 f"`len(datasets)={len(datasets)}`.") 3503 3504 # Use the given `weights` as the probability of choosing the respective 3505 # input. 3506 if not isinstance(weights, ops.Tensor): 3507 datasets, weights = _skip_datasets_with_zero_weight(datasets, weights) 3508 weights = ops.convert_to_tensor(weights, name="weights") 3509 if weights.dtype not in (dtypes.float32, dtypes.float64): 3510 raise TypeError(f"Invalid `weights`. `weights` type must be either " 3511 f"`tf.float32` or `tf.float64` but is " 3512 f"{weights.dtype}.") 3513 3514 # The `stateless_multinomial()` op expects log-probabilities, as opposed 3515 # to weights. 3516 logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0) 3517 3518 # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When 3519 # it is a `Dataset`, it is possible that evaluating it has a side effect 3520 # the user depends on. 3521 if len(datasets) == 1: 3522 return datasets[0] 3523 3524 def select_dataset_constant_logits(seed): 3525 return array_ops.squeeze( 3526 gen_stateless_random_ops.stateless_multinomial( 3527 logits, 1, seed=seed), 3528 axis=[0, 1]) 3529 3530 selector_input = MapDataset( 3531 RandomDataset(seed).batch(2), 3532 select_dataset_constant_logits, 3533 use_inter_op_parallelism=False) 3534 3535 else: 3536 # Use each element of the given `weights` dataset as the probability of 3537 # choosing the respective input. 3538 # 3539 # The `stateless_multinomial()` op expects log-probabilities, as opposed 3540 # to weights. 
3541 logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits")) 3542 3543 def select_dataset_varying_logits(logits, seed): 3544 return array_ops.squeeze( 3545 gen_stateless_random_ops.stateless_multinomial( 3546 logits, 1, seed=seed), 3547 axis=[0, 1]) 3548 3549 logits_and_seeds = Dataset.zip((logits_ds, RandomDataset(seed).batch(2))) 3550 selector_input = MapDataset( 3551 logits_and_seeds, 3552 select_dataset_varying_logits, 3553 use_inter_op_parallelism=False) 3554 3555 return _DirectedInterleaveDataset(selector_input, datasets, 3556 stop_on_empty_dataset) 3557 3558 @staticmethod 3559 def choose_from_datasets(datasets, 3560 choice_dataset, 3561 stop_on_empty_dataset=True): 3562 """Creates a dataset that deterministically chooses elements from `datasets`. 3563 3564 For example, given the following datasets: 3565 3566 ```python 3567 datasets = [tf.data.Dataset.from_tensors("foo").repeat(), 3568 tf.data.Dataset.from_tensors("bar").repeat(), 3569 tf.data.Dataset.from_tensors("baz").repeat()] 3570 3571 # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`. 3572 choice_dataset = tf.data.Dataset.range(3).repeat(3) 3573 3574 result = tf.data.Dataset.choose_from_datasets(datasets, choice_dataset) 3575 ``` 3576 3577 The elements of `result` will be: 3578 3579 ``` 3580 "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz" 3581 ``` 3582 3583 Args: 3584 datasets: A non-empty list of `tf.data.Dataset` objects with compatible 3585 structure. 3586 choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between 3587 `0` and `len(datasets) - 1`. 3588 stop_on_empty_dataset: If `True`, selection stops if it encounters an 3589 empty dataset. If `False`, it skips empty datasets. It is recommended to 3590 set it to `True`. Otherwise, the selected elements start off as the user 3591 intends, but may change as input datasets become empty. This can be 3592 difficult to detect since the dataset starts off looking correct. 3593 Defaults to `True`. 3594 3595 Returns: 3596 A dataset that interleaves elements from `datasets` according to the 3597 values of `choice_dataset`. 3598 3599 Raises: 3600 TypeError: If `datasets` or `choice_dataset` has the wrong type. 3601 ValueError: If `datasets` is empty. 3602 """ 3603 if not datasets: 3604 raise ValueError("Invalid `datasets`. `datasets` should not be empty.") 3605 if not isinstance(choice_dataset, DatasetV2): 3606 raise TypeError(f"Invalid `choice_dataset`. `choice_dataset` should be a " 3607 f"`tf.data.Dataset` but is {type(choice_dataset)}.") 3608 if not structure.are_compatible(choice_dataset.element_spec, 3609 tensor_spec.TensorSpec([], dtypes.int64)): 3610 raise TypeError(f"Invalid `choice_dataset`. Elements of `choice_dataset` " 3611 f"must be scalar `tf.int64` tensors but are " 3612 f"{choice_dataset.element_spec}.") 3613 # Replicate the `choice_dataset` component so that each split makes choices 3614 # independently. This avoids the need for prohibitively expensive 3615 # cross-split coordination. 3616 choice_dataset = _apply_rewrite(choice_dataset, "replicate_on_split") 3617 # pylint: disable=protected-access 3618 return _DirectedInterleaveDataset(choice_dataset, datasets, 3619 stop_on_empty_dataset) 3620 3621 3622@tf_export(v1=["data.Dataset"]) 3623class DatasetV1(DatasetV2): 3624 """Represents a potentially large set of elements. 3625 3626 A `Dataset` can be used to represent an input pipeline as a 3627 collection of elements and a "logical plan" of transformations that act on 3628 those elements. 
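
    In TF 1 graph mode, a `Dataset` is typically consumed through an iterator
    rather than by direct Python iteration. A minimal sketch:

    ```python
    dataset = tf.compat.v1.data.Dataset.from_tensor_slices([1, 2, 3])
    iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
    next_element = iterator.get_next()
    with tf.compat.v1.Session() as sess:
      print(sess.run(next_element))  # ==> 1
    ```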
3629 """ 3630 3631 def __init__(self): 3632 try: 3633 variant_tensor = self._as_variant_tensor() 3634 except AttributeError as e: 3635 if "_as_variant_tensor" in str(e): 3636 raise AttributeError("Please use `_variant_tensor` instead of " 3637 "`_as_variant_tensor()` to obtain the variant " 3638 "associated with a dataset.") 3639 raise AttributeError("{}: A likely cause of this error is that the super " 3640 "call for this dataset is not the last line of the " 3641 "`__init__` method. The base class invokes the " 3642 "`_as_variant_tensor()` method in its constructor " 3643 "and if that method uses attributes defined in the " 3644 "`__init__` method, those attributes need to be " 3645 "defined before the super call.".format(e)) 3646 super(DatasetV1, self).__init__(variant_tensor) 3647 3648 @abc.abstractmethod 3649 def _as_variant_tensor(self): 3650 """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset. 3651 3652 Returns: 3653 A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset. 3654 """ 3655 raise NotImplementedError(f"{type(self)}.as_variant_tensor()") 3656 3657 @deprecation.deprecated( 3658 None, "This is a deprecated API that should only be used in TF 1 graph " 3659 "mode and legacy TF 2 graph mode available through `tf.compat.v1`. In " 3660 "all other situations -- namely, eager mode and inside `tf.function` -- " 3661 "you can consume dataset elements using `for elem in dataset: ...` or " 3662 "by explicitly creating iterator via `iterator = iter(dataset)` and " 3663 "fetching its elements via `values = next(iterator)`. Furthermore, " 3664 "this API is not available in TF 2. During the transition from TF 1 " 3665 "to TF 2 you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)` " 3666 "to create a TF 1 graph mode style iterator for a dataset created " 3667 "through TF 2 APIs. Note that this should be a transient state of your " 3668 "code base as there are in general no guarantees about the " 3669 "interoperability of TF 1 and TF 2 code.") 3670 def make_one_shot_iterator(self): 3671 """Creates an iterator for elements of this dataset. 3672 3673 Note: The returned iterator will be initialized automatically. 3674 A "one-shot" iterator does not currently support re-initialization. For 3675 that see `make_initializable_iterator`. 3676 3677 Example: 3678 3679 ```python 3680 # Building graph ... 3681 dataset = ... 3682 next_value = dataset.make_one_shot_iterator().get_next() 3683 3684 # ... from within a session ... 3685 try: 3686 while True: 3687 value = sess.run(next_value) 3688 ... 3689 except tf.errors.OutOfRangeError: 3690 pass 3691 ``` 3692 3693 Returns: 3694 An `tf.data.Iterator` for elements of this dataset. 3695 """ 3696 return self._make_one_shot_iterator() 3697 3698 def _make_one_shot_iterator(self): # pylint: disable=missing-docstring 3699 if context.executing_eagerly(): 3700 with ops.colocate_with(self._variant_tensor): 3701 return iterator_ops.OwnedIterator(self) 3702 3703 _ensure_same_dataset_graph(self) 3704 # Some ops (e.g. dataset ops) are marked as stateful but are stil safe to 3705 # to capture by value. We must allowlist these ops so that the capturing 3706 # logic captures the ops instead of raising an exception. 3707 allowlisted_stateful_ops = traverse.obtain_capture_by_value_ops(self) 3708 graph_level_seed, op_level_seed = core_random_seed.get_seed(None) 3709 3710 # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is 3711 # a 0-argument function. 
3712 @function.Defun( 3713 capture_by_value=True, 3714 allowlisted_stateful_ops=allowlisted_stateful_ops) 3715 def _make_dataset(): 3716 """Factory function for a dataset.""" 3717 # NOTE(mrry): `Defun` does not capture the graph-level seed from the 3718 # enclosing graph, so if a graph-level seed is present we set the local 3719 # graph seed based on a combination of the graph- and op-level seeds. 3720 if graph_level_seed is not None: 3721 assert op_level_seed is not None 3722 core_random_seed.set_random_seed( 3723 (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1)) 3724 3725 dataset = self._apply_debug_options() 3726 return dataset._variant_tensor # pylint: disable=protected-access 3727 3728 try: 3729 _make_dataset.add_to_graph(ops.get_default_graph()) 3730 except ValueError as err: 3731 if "Cannot capture a stateful node" in str(err): 3732 raise ValueError( 3733 "{}: A likely cause of this error is that the dataset for which " 3734 "you are calling `make_one_shot_iterator()` captures a stateful " 3735 "object, such as a `tf.Variable` or `tf.lookup.StaticHashTable`, " 3736 "which is not supported. Use `make_initializable_iterator()` " 3737 "instead.".format(err)) from None 3738 else: 3739 raise 3740 3741 with ops.colocate_with(self._variant_tensor): 3742 # pylint: disable=protected-access 3743 return iterator_ops.Iterator( 3744 gen_dataset_ops.one_shot_iterator( 3745 dataset_factory=_make_dataset, **self._flat_structure), None, 3746 get_legacy_output_types(self), get_legacy_output_shapes(self), 3747 get_legacy_output_classes(self)) 3748 3749 @deprecation.deprecated( 3750 None, "This is a deprecated API that should only be used in TF 1 graph " 3751 "mode and legacy TF 2 graph mode available through `tf.compat.v1`. " 3752 "In all other situations -- namely, eager mode and inside `tf.function` " 3753 "-- you can consume dataset elements using `for elem in dataset: ...` " 3754 "or by explicitly creating iterator via `iterator = iter(dataset)` " 3755 "and fetching its elements via `values = next(iterator)`. " 3756 "Furthermore, this API is not available in TF 2. During the transition " 3757 "from TF 1 to TF 2 you can use " 3758 "`tf.compat.v1.data.make_initializable_iterator(dataset)` to create a TF " 3759 "1 graph mode style iterator for a dataset created through TF 2 APIs. " 3760 "Note that this should be a transient state of your code base as there " 3761 "are in general no guarantees about the interoperability of TF 1 and TF " 3762 "2 code.") 3763 def make_initializable_iterator(self, shared_name=None): 3764 """Creates an iterator for elements of this dataset. 3765 3766 Note: The returned iterator will be in an uninitialized state, 3767 and you must run the `iterator.initializer` operation before using it: 3768 3769 ```python 3770 # Building graph ... 3771 dataset = ... 3772 iterator = dataset.make_initializable_iterator() 3773 next_value = iterator.get_next() # This is a Tensor. 3774 3775 # ... from within a session ... 3776 sess.run(iterator.initializer) 3777 try: 3778 while True: 3779 value = sess.run(next_value) 3780 ... 3781 except tf.errors.OutOfRangeError: 3782 pass 3783 ``` 3784 3785 Args: 3786 shared_name: (Optional.) If non-empty, the returned iterator will be 3787 shared under the given name across multiple sessions that share the same 3788 devices (e.g. when using a remote server). 3789 3790 Returns: 3791 A `tf.data.Iterator` for elements of this dataset. 3792 3793 Raises: 3794 RuntimeError: If eager execution is enabled. 
3795 """ 3796 return self._make_initializable_iterator(shared_name) 3797 3798 def _make_initializable_iterator(self, shared_name=None): # pylint: disable=missing-docstring 3799 if context.executing_eagerly(): 3800 raise RuntimeError("`make_initializable_iterator()` is not supported in " 3801 "eager mode. Use Python-style iteration instead.") 3802 _ensure_same_dataset_graph(self) 3803 dataset = self._apply_debug_options() 3804 if shared_name is None: 3805 shared_name = "" 3806 3807 with ops.colocate_with(self._variant_tensor): 3808 iterator_resource = gen_dataset_ops.iterator_v2( 3809 container="", shared_name=shared_name, **self._flat_structure) 3810 3811 initializer = gen_dataset_ops.make_iterator( 3812 dataset._variant_tensor, # pylint: disable=protected-access 3813 iterator_resource) 3814 3815 # pylint: disable=protected-access 3816 return iterator_ops.Iterator(iterator_resource, initializer, 3817 get_legacy_output_types(dataset), 3818 get_legacy_output_shapes(dataset), 3819 get_legacy_output_classes(dataset)) 3820 3821 @property 3822 @deprecation.deprecated( 3823 None, "Use `tf.compat.v1.data.get_output_classes(dataset)`.") 3824 def output_classes(self): 3825 """Returns the class of each component of an element of this dataset. 3826 3827 Returns: 3828 A (nested) structure of Python `type` objects corresponding to each 3829 component of an element of this dataset. 3830 """ 3831 return nest.map_structure( 3832 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 3833 self.element_spec) 3834 3835 @property 3836 @deprecation.deprecated( 3837 None, "Use `tf.compat.v1.data.get_output_shapes(dataset)`.") 3838 def output_shapes(self): 3839 """Returns the shape of each component of an element of this dataset. 3840 3841 Returns: 3842 A (nested) structure of `tf.TensorShape` objects corresponding to each 3843 component of an element of this dataset. 3844 """ 3845 return nest.map_structure( 3846 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 3847 self.element_spec) 3848 3849 @property 3850 @deprecation.deprecated( 3851 None, "Use `tf.compat.v1.data.get_output_types(dataset)`.") 3852 def output_types(self): 3853 """Returns the type of each component of an element of this dataset. 3854 3855 Returns: 3856 A (nested) structure of `tf.DType` objects corresponding to each component 3857 of an element of this dataset. 3858 """ 3859 return nest.map_structure( 3860 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 3861 self.element_spec) 3862 3863 @property 3864 def element_spec(self): 3865 # TODO(b/110122868): Remove this override once all `Dataset` instances 3866 # implement `element_structure`. 3867 return structure.convert_legacy_structure( 3868 self.output_types, self.output_shapes, self.output_classes) 3869 3870 @staticmethod 3871 @functools.wraps(DatasetV2.from_tensors) 3872 def from_tensors(tensors, name=None): 3873 return DatasetV1Adapter(DatasetV2.from_tensors(tensors, name=name)) 3874 3875 @staticmethod 3876 @functools.wraps(DatasetV2.from_tensor_slices) 3877 def from_tensor_slices(tensors, name=None): 3878 return DatasetV1Adapter(DatasetV2.from_tensor_slices(tensors, name=name)) 3879 3880 @staticmethod 3881 @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.") 3882 def from_sparse_tensor_slices(sparse_tensor): 3883 """Splits each rank-N `tf.sparse.SparseTensor` in this dataset row-wise. 
3884 3885 Args: 3886 sparse_tensor: A `tf.sparse.SparseTensor`. 3887 3888 Returns: 3889 Dataset: A `Dataset` of rank-(N-1) sparse tensors. 3890 """ 3891 return DatasetV1Adapter(SparseTensorSliceDataset(sparse_tensor)) 3892 3893 @staticmethod 3894 @functools.wraps(DatasetV2.from_generator) 3895 @deprecation.deprecated_args(None, "Use output_signature instead", 3896 "output_types", "output_shapes") 3897 def from_generator(generator, 3898 output_types=None, 3899 output_shapes=None, 3900 args=None, 3901 output_signature=None, 3902 name=None): 3903 # Calling DatasetV2.from_generator with output_shapes or output_types is 3904 # deprecated, but this is already checked by the decorator on this function. 3905 with deprecation.silence(): 3906 return DatasetV1Adapter( 3907 DatasetV2.from_generator( 3908 generator, 3909 output_types, 3910 output_shapes, 3911 args, 3912 output_signature, 3913 name=name)) 3914 3915 @staticmethod 3916 @functools.wraps(DatasetV2.range) 3917 def range(*args, **kwargs): 3918 return DatasetV1Adapter(DatasetV2.range(*args, **kwargs)) 3919 3920 @staticmethod 3921 @functools.wraps(DatasetV2.zip) 3922 def zip(datasets, name=None): 3923 return DatasetV1Adapter(DatasetV2.zip(datasets, name=name)) 3924 3925 @functools.wraps(DatasetV2.concatenate) 3926 def concatenate(self, dataset, name=None): 3927 return DatasetV1Adapter( 3928 super(DatasetV1, self).concatenate(dataset, name=name)) 3929 3930 @functools.wraps(DatasetV2.prefetch) 3931 def prefetch(self, buffer_size, name=None): 3932 return DatasetV1Adapter( 3933 super(DatasetV1, self).prefetch(buffer_size, name=name)) 3934 3935 @staticmethod 3936 @functools.wraps(DatasetV2.list_files) 3937 def list_files(file_pattern, shuffle=None, seed=None, name=None): 3938 return DatasetV1Adapter( 3939 DatasetV2.list_files(file_pattern, shuffle, seed, name=name)) 3940 3941 @functools.wraps(DatasetV2.repeat) 3942 def repeat(self, count=None, name=None): 3943 return DatasetV1Adapter(super(DatasetV1, self).repeat(count, name=name)) 3944 3945 @functools.wraps(DatasetV2.shuffle) 3946 def shuffle(self, 3947 buffer_size, 3948 seed=None, 3949 reshuffle_each_iteration=None, 3950 name=None): 3951 return DatasetV1Adapter( 3952 super(DatasetV1, self).shuffle( 3953 buffer_size, seed, reshuffle_each_iteration, name=name)) 3954 3955 @functools.wraps(DatasetV2.cache) 3956 def cache(self, filename="", name=None): 3957 return DatasetV1Adapter(super(DatasetV1, self).cache(filename, name=name)) 3958 3959 @functools.wraps(DatasetV2.take) 3960 def take(self, count, name=None): 3961 return DatasetV1Adapter(super(DatasetV1, self).take(count, name=name)) 3962 3963 @functools.wraps(DatasetV2.skip) 3964 def skip(self, count, name=None): 3965 return DatasetV1Adapter(super(DatasetV1, self).skip(count, name=name)) 3966 3967 @functools.wraps(DatasetV2.shard) 3968 def shard(self, num_shards, index, name=None): 3969 return DatasetV1Adapter( 3970 super(DatasetV1, self).shard(num_shards, index, name=name)) 3971 3972 @functools.wraps(DatasetV2.batch) 3973 def batch(self, 3974 batch_size, 3975 drop_remainder=False, 3976 num_parallel_calls=None, 3977 deterministic=None, 3978 name=None): 3979 return DatasetV1Adapter( 3980 super(DatasetV1, self).batch( 3981 batch_size, 3982 drop_remainder, 3983 num_parallel_calls, 3984 deterministic, 3985 name=name)) 3986 3987 @functools.wraps(DatasetV2.padded_batch) 3988 def padded_batch(self, 3989 batch_size, 3990 padded_shapes=None, 3991 padding_values=None, 3992 drop_remainder=False, 3993 name=None): 3994 return DatasetV1Adapter( 3995 
super(DatasetV1, self).padded_batch( 3996 batch_size, 3997 padded_shapes, 3998 padding_values, 3999 drop_remainder, 4000 name=name)) 4001 4002 @functools.wraps(DatasetV2.map) 4003 def map(self, 4004 map_func, 4005 num_parallel_calls=None, 4006 deterministic=None, 4007 name=None): 4008 if num_parallel_calls is None or DEBUG_MODE: 4009 return DatasetV1Adapter( 4010 MapDataset(self, map_func, preserve_cardinality=False)) 4011 else: 4012 return DatasetV1Adapter( 4013 ParallelMapDataset( 4014 self, 4015 map_func, 4016 num_parallel_calls, 4017 deterministic, 4018 preserve_cardinality=False)) 4019 4020 @deprecation.deprecated(None, "Use `tf.data.Dataset.map()") 4021 def map_with_legacy_function(self, 4022 map_func, 4023 num_parallel_calls=None, 4024 deterministic=None): 4025 """Maps `map_func` across the elements of this dataset. 4026 4027 Note: This is an escape hatch for existing uses of `map` that do not work 4028 with V2 functions. New uses are strongly discouraged and existing uses 4029 should migrate to `map` as this method will be removed in V2. 4030 4031 Args: 4032 map_func: A function mapping a (nested) structure of tensors (having 4033 shapes and types defined by `self.output_shapes` and 4034 `self.output_types`) to another (nested) structure of tensors. 4035 num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, 4036 representing the number elements to process asynchronously in parallel. 4037 If not specified, elements will be processed sequentially. If the value 4038 `tf.data.AUTOTUNE` is used, then the number of parallel calls is set 4039 dynamically based on available CPU. 4040 deterministic: (Optional.) When `num_parallel_calls` is specified, this 4041 boolean controls the order in which the transformation produces 4042 elements. If set to `False`, the transformation is allowed to yield 4043 elements out of order to trade determinism for performance. If not 4044 specified, the `tf.data.Options.deterministic` option (`True` by 4045 default) controls the behavior. 4046 4047 Returns: 4048 Dataset: A `Dataset`. 4049 """ 4050 if num_parallel_calls is None: 4051 if deterministic is not None: 4052 warnings.warn("The `deterministic` argument has no effect unless the " 4053 "`num_parallel_calls` argument is specified.") 4054 return DatasetV1Adapter( 4055 MapDataset( 4056 self, 4057 map_func, 4058 preserve_cardinality=False, 4059 use_legacy_function=True)) 4060 else: 4061 return DatasetV1Adapter( 4062 ParallelMapDataset( 4063 self, 4064 map_func, 4065 num_parallel_calls, 4066 deterministic, 4067 preserve_cardinality=False, 4068 use_legacy_function=True)) 4069 4070 @functools.wraps(DatasetV2.flat_map) 4071 def flat_map(self, map_func, name=None): 4072 return DatasetV1Adapter( 4073 super(DatasetV1, self).flat_map(map_func, name=name)) 4074 4075 @functools.wraps(DatasetV2.interleave) 4076 def interleave(self, 4077 map_func, 4078 cycle_length=None, 4079 block_length=None, 4080 num_parallel_calls=None, 4081 deterministic=None, 4082 name=None): 4083 return DatasetV1Adapter( 4084 super(DatasetV1, self).interleave( 4085 map_func, 4086 cycle_length, 4087 block_length, 4088 num_parallel_calls, 4089 deterministic, 4090 name=name)) 4091 4092 @functools.wraps(DatasetV2.filter) 4093 def filter(self, predicate, name=None): 4094 return DatasetV1Adapter(super(DatasetV1, self).filter(predicate, name=name)) 4095 4096 @deprecation.deprecated(None, "Use `tf.data.Dataset.filter()") 4097 def filter_with_legacy_function(self, predicate): 4098 """Filters this dataset according to `predicate`. 
4099 4100 Note: This is an escape hatch for existing uses of `filter` that do not work 4101 with V2 functions. New uses are strongly discouraged and existing uses 4102 should migrate to `filter` as this method will be removed in V2. 4103 4104 Args: 4105 predicate: A function mapping a (nested) structure of tensors (having 4106 shapes and types defined by `self.output_shapes` and 4107 `self.output_types`) to a scalar `tf.bool` tensor. 4108 4109 Returns: 4110 Dataset: The `Dataset` containing the elements of this dataset for which 4111 `predicate` is `True`. 4112 """ 4113 return FilterDataset(self, predicate, use_legacy_function=True) 4114 4115 @functools.wraps(DatasetV2.apply) 4116 def apply(self, transformation_func): 4117 return DatasetV1Adapter(super(DatasetV1, self).apply(transformation_func)) 4118 4119 @functools.wraps(DatasetV2.window) 4120 def window(self, size, shift=None, stride=1, drop_remainder=False, name=None): 4121 return DatasetV1Adapter( 4122 super(DatasetV1, 4123 self).window(size, shift, stride, drop_remainder, name=name)) 4124 4125 @functools.wraps(DatasetV2.unbatch) 4126 def unbatch(self, name=None): 4127 return DatasetV1Adapter(super(DatasetV1, self).unbatch(name=name)) 4128 4129 @functools.wraps(DatasetV2.with_options) 4130 def with_options(self, options, name=None): 4131 return DatasetV1Adapter( 4132 super(DatasetV1, self).with_options(options, name=name)) 4133 4134 4135if tf2.enabled(): 4136 Dataset = DatasetV2 4137else: 4138 Dataset = DatasetV1 4139 4140 4141class DatasetV1Adapter(DatasetV1): 4142 """Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API.""" 4143 4144 def __init__(self, dataset): 4145 self._dataset = dataset 4146 super(DatasetV1Adapter, self).__init__() 4147 4148 def _as_variant_tensor(self): 4149 return self._dataset._variant_tensor # pylint: disable=protected-access 4150 4151 def _inputs(self): 4152 return self._dataset._inputs() # pylint: disable=protected-access 4153 4154 def _functions(self): 4155 return self._dataset._functions() # pylint: disable=protected-access 4156 4157 def options(self): 4158 return self._dataset.options() 4159 4160 @property 4161 def element_spec(self): 4162 return self._dataset.element_spec # pylint: disable=protected-access 4163 4164 def __iter__(self): 4165 return iter(self._dataset) 4166 4167 4168def _ensure_same_dataset_graph(dataset): 4169 """Walks the dataset graph to ensure all datasets come from the same graph.""" 4170 # pylint: disable=protected-access 4171 current_graph = ops.get_default_graph() 4172 bfs_q = queue.Queue() 4173 bfs_q.put(dataset) 4174 visited = [] 4175 while not bfs_q.empty(): 4176 ds = bfs_q.get() 4177 visited.append(ds) 4178 ds_graph = ds._graph 4179 if current_graph != ds_graph: 4180 raise ValueError( 4181 f"The graph {current_graph} of the iterator is different from the " 4182 f"graph {ds_graph} the dataset: {ds._variant_tensor} was created in. " 4183 f"If you are using the Estimator API, make sure that no part of the " 4184 f"dataset returned by the `input_fn` function is defined outside the " 4185 f"`input_fn` function. Otherwise, make sure that the dataset is " 4186 f"created in the same graph as the iterator.") 4187 for input_ds in ds._inputs(): 4188 if input_ds not in visited: 4189 bfs_q.put(input_ds) 4190 4191 4192@tf_export(v1=["data.make_one_shot_iterator"]) 4193def make_one_shot_iterator(dataset): 4194 """Creates an iterator for elements of `dataset`. 4195 4196 Note: The returned iterator will be initialized automatically. 
4197 A "one-shot" iterator does not support re-initialization. 4198 4199 Args: 4200 dataset: A `tf.data.Dataset`. 4201 4202 Returns: 4203 A `tf.data.Iterator` for elements of `dataset`. 4204 4205 @compatibility(TF2) 4206 This is a legacy API for consuming dataset elements and should only be used 4207 during transition from TF 1 to TF 2. Note that using this API should be 4208 a transient state of your code base as there are in general no guarantees 4209 about the interoperability of TF 1 and TF 2 code. 4210 4211 In TF 2 datasets are Python iterables which means you can consume their 4212 elements using `for elem in dataset: ...` or by explicitly creating iterator 4213 via `iterator = iter(dataset)` and fetching its elements via 4214 `values = next(iterator)`. 4215 @end_compatibility 4216 """ 4217 try: 4218 # Call the defined `_make_one_shot_iterator()` if there is one, because some 4219 # datasets (e.g. for prefetching) override its behavior. 4220 return dataset._make_one_shot_iterator() # pylint: disable=protected-access 4221 except AttributeError: 4222 return DatasetV1Adapter(dataset)._make_one_shot_iterator() # pylint: disable=protected-access 4223 4224 4225@tf_export(v1=["data.make_initializable_iterator"]) 4226def make_initializable_iterator(dataset, shared_name=None): 4227 """Creates an iterator for elements of `dataset`. 4228 4229 Note: The returned iterator will be in an uninitialized state, 4230 and you must run the `iterator.initializer` operation before using it: 4231 4232 ```python 4233 dataset = ... 4234 iterator = tf.compat.v1.data.make_initializable_iterator(dataset) 4235 # ... 4236 sess.run(iterator.initializer) 4237 ``` 4238 4239 Args: 4240 dataset: A `tf.data.Dataset`. 4241 shared_name: (Optional.) If non-empty, the returned iterator will be shared 4242 under the given name across multiple sessions that share the same devices 4243 (e.g. when using a remote server). 4244 4245 Returns: 4246 A `tf.data.Iterator` for elements of `dataset`. 4247 4248 Raises: 4249 RuntimeError: If eager execution is enabled. 4250 4251 @compatibility(TF2) 4252 This is a legacy API for consuming dataset elements and should only be used 4253 during transition from TF 1 to TF 2. Note that using this API should be 4254 a transient state of your code base as there are in general no guarantees 4255 about the interoperability of TF 1 and TF 2 code. 4256 4257 In TF 2 datasets are Python iterables which means you can consume their 4258 elements using `for elem in dataset: ...` or by explicitly creating iterator 4259 via `iterator = iter(dataset)` and fetching its elements via 4260 `values = next(iterator)`. 4261 @end_compatibility 4262 """ 4263 try: 4264 # Call the defined `_make_initializable_iterator()` if there is one, because 4265 # some datasets (e.g. for prefetching) override its behavior. 4266 return dataset._make_initializable_iterator(shared_name) # pylint: disable=protected-access 4267 except AttributeError: 4268 return DatasetV1Adapter(dataset)._make_initializable_iterator(shared_name) # pylint: disable=protected-access 4269 4270 4271@tf_export("data.experimental.get_structure") 4272def get_structure(dataset_or_iterator): 4273 """Returns the type signature for elements of the input dataset / iterator. 
4274 4275 For example, to get the structure of a `tf.data.Dataset`: 4276 4277 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 4278 >>> tf.data.experimental.get_structure(dataset) 4279 TensorSpec(shape=(), dtype=tf.int32, name=None) 4280 4281 >>> dataset = tf.data.experimental.from_list([(1, 'a'), (2, 'b'), (3, 'c')]) 4282 >>> tf.data.experimental.get_structure(dataset) 4283 (TensorSpec(shape=(), dtype=tf.int32, name=None), 4284 TensorSpec(shape=(), dtype=tf.string, name=None)) 4285 4286 To get the structure of an `tf.data.Iterator`: 4287 4288 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 4289 >>> tf.data.experimental.get_structure(iter(dataset)) 4290 TensorSpec(shape=(), dtype=tf.int32, name=None) 4291 4292 Args: 4293 dataset_or_iterator: A `tf.data.Dataset` or an `tf.data.Iterator`. 4294 4295 Returns: 4296 A (nested) structure of `tf.TypeSpec` objects matching the structure of an 4297 element of `dataset_or_iterator` and specifying the type of individual 4298 components. 4299 4300 Raises: 4301 TypeError: If input is not a `tf.data.Dataset` or an `tf.data.Iterator` 4302 object. 4303 """ 4304 try: 4305 return dataset_or_iterator.element_spec # pylint: disable=protected-access 4306 except AttributeError: 4307 raise TypeError(f"Invalid `dataset_or_iterator`. `dataset_or_iterator` " 4308 f"must be a `tf.data.Dataset` or tf.data.Iterator object, " 4309 f"but got {type(dataset_or_iterator)}.") 4310 4311 4312@tf_export(v1=["data.get_output_classes"]) 4313def get_legacy_output_classes(dataset_or_iterator): 4314 """Returns the output classes for elements of the input dataset / iterator. 4315 4316 Args: 4317 dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`. 4318 4319 Returns: 4320 A (nested) structure of Python `type` objects matching the structure of the 4321 dataset / iterator elements and specifying the class of the individual 4322 components. 4323 4324 @compatibility(TF2) 4325 This is a legacy API for inspecting the type signature of dataset elements. In 4326 TF 2, you should use the `tf.data.Dataset.element_spec` attribute instead. 4327 @end_compatibility 4328 """ 4329 return nest.map_structure( 4330 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 4331 get_structure(dataset_or_iterator)) 4332 4333 4334@tf_export(v1=["data.get_output_shapes"]) 4335def get_legacy_output_shapes(dataset_or_iterator): 4336 """Returns the output shapes for elements of the input dataset / iterator. 4337 4338 Args: 4339 dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`. 4340 4341 Returns: 4342 A (nested) structure of `tf.TensorShape` objects matching the structure of 4343 the dataset / iterator elements and specifying the shape of the individual 4344 components. 4345 4346 @compatibility(TF2) 4347 This is a legacy API for inspecting the type signature of dataset elements. In 4348 TF 2, you should use the `tf.data.Dataset.element_spec` attribute instead. 4349 @end_compatibility 4350 """ 4351 return nest.map_structure( 4352 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 4353 get_structure(dataset_or_iterator)) 4354 4355 4356@tf_export(v1=["data.get_output_types"]) 4357def get_legacy_output_types(dataset_or_iterator): 4358 """Returns the output shapes for elements of the input dataset / iterator. 4359 4360 Args: 4361 dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`. 
4362 4363 Returns: 4364 A (nested) structure of `tf.DType` objects matching the structure of 4365 dataset / iterator elements and specifying the shape of the individual 4366 components. 4367 4368 @compatibility(TF2) 4369 This is a legacy API for inspecting the type signature of dataset elements. In 4370 TF 2, you should use the `tf.data.Dataset.element_spec` attribute instead. 4371 @end_compatibility 4372 """ 4373 return nest.map_structure( 4374 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 4375 get_structure(dataset_or_iterator)) 4376 4377 4378class DatasetSource(DatasetV2): 4379 """Abstract class representing a dataset with no inputs.""" 4380 4381 def _inputs(self): 4382 return [] 4383 4384 4385class UnaryDataset(DatasetV2): 4386 """Abstract class representing a dataset with one input.""" 4387 4388 def __init__(self, input_dataset, variant_tensor): 4389 self._input_dataset = input_dataset 4390 super(UnaryDataset, self).__init__(variant_tensor) 4391 4392 def _inputs(self): 4393 return [self._input_dataset] 4394 4395 4396class UnaryUnchangedStructureDataset(UnaryDataset): 4397 """Represents a unary dataset with the same input and output structure.""" 4398 4399 def __init__(self, input_dataset, variant_tensor): 4400 self._input_dataset = input_dataset 4401 super(UnaryUnchangedStructureDataset, self).__init__( 4402 input_dataset, variant_tensor) 4403 4404 @property 4405 def element_spec(self): 4406 return self._input_dataset.element_spec 4407 4408 4409class _VariantDataset(DatasetV2): 4410 """A Dataset wrapper around a `tf.variant`-typed function argument.""" 4411 4412 def __init__(self, dataset_variant, element_spec): 4413 self._element_spec = element_spec 4414 super(_VariantDataset, self).__init__(dataset_variant) 4415 4416 def _inputs(self): 4417 return [] 4418 4419 @property 4420 def element_spec(self): 4421 return self._element_spec 4422 4423 4424class _NestedVariant(composite_tensor.CompositeTensor): 4425 4426 def __init__(self, variant_tensor, element_spec, dataset_shape): 4427 self._variant_tensor = variant_tensor 4428 self._element_spec = element_spec 4429 self._dataset_shape = dataset_shape 4430 4431 @property 4432 def _type_spec(self): 4433 return DatasetSpec(self._element_spec, self._dataset_shape) 4434 4435 4436@tf_export("data.experimental.from_variant") 4437def from_variant(variant, structure): 4438 """Constructs a dataset from the given variant and (nested) structure. 4439 4440 Args: 4441 variant: A scalar `tf.variant` tensor representing a dataset. 4442 structure: A (nested) structure of `tf.TypeSpec` objects representing the 4443 structure of each element in the dataset. 4444 4445 Returns: 4446 A `tf.data.Dataset` instance. 4447 """ 4448 return _VariantDataset(variant, structure) # pylint: disable=protected-access 4449 4450 4451@tf_export("data.experimental.to_variant") 4452def to_variant(dataset): 4453 """Returns a variant representing the given dataset. 4454 4455 Args: 4456 dataset: A `tf.data.Dataset`. 4457 4458 Returns: 4459 A scalar `tf.variant` tensor representing the given dataset. 4460 """ 4461 return dataset._variant_tensor # pylint: disable=protected-access 4462 4463 4464@tf_export( 4465 "data.DatasetSpec", 4466 v1=["data.DatasetSpec", "data.experimental.DatasetStructure"]) 4467class DatasetSpec(type_spec.BatchableTypeSpec): 4468 """Type specification for `tf.data.Dataset`. 4469 4470 See `tf.TypeSpec` for more information about TensorFlow type specifications. 
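  One use for a `DatasetSpec` (an illustrative sketch, not taken from the
  original documentation) is as the `input_signature` of a `tf.function` that
  accepts a dataset argument:

  >>> @tf.function(input_signature=[
  ...     tf.data.DatasetSpec(tf.TensorSpec(shape=(), dtype=tf.int64))])
  ... def total(ds):
  ...   return ds.reduce(tf.constant(0, dtype=tf.int64), lambda x, y: x + y)
  >>> total(tf.data.Dataset.range(5))
  <tf.Tensor: shape=(), dtype=int64, numpy=10>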
4471 4472 >>> dataset = tf.data.Dataset.range(3) 4473 >>> tf.data.DatasetSpec.from_value(dataset) 4474 DatasetSpec(TensorSpec(shape=(), dtype=tf.int64, name=None), TensorShape([])) 4475 """ 4476 4477 __slots__ = ["_element_spec", "_dataset_shape"] 4478 4479 def __init__(self, element_spec, dataset_shape=()): 4480 self._element_spec = element_spec 4481 self._dataset_shape = tensor_shape.as_shape(dataset_shape) 4482 4483 @property 4484 def value_type(self): 4485 return Dataset 4486 4487 @property 4488 def element_spec(self): 4489 """The inner element spec.""" 4490 return self._element_spec 4491 4492 def is_subtype_of(self, other): 4493 """See base class.""" 4494 if type(self) is not type(other): 4495 return False 4496 4497 # TODO(b/220385675): _element_spec should always be a TypeSpec. 4498 try: 4499 tf_nest.assert_same_structure(self.element_spec, other.element_spec) 4500 except (TypeError, ValueError): 4501 return False 4502 4503 self_elements = tf_nest.flatten(self.element_spec) 4504 other_elements = tf_nest.flatten(other.element_spec) 4505 4506 def is_subtype_or_equal(a, b): 4507 if isinstance(a, trace.TraceType): 4508 return a.is_subtype_of(b) 4509 else: 4510 return a == b 4511 4512 for self_element, other_element in zip(self_elements, other_elements): 4513 if not is_subtype_or_equal(self_element, other_element): 4514 return False 4515 4516 return self._dataset_shape.is_subtype_of(other._dataset_shape) # pylint: disable=protected-access 4517 4518 def most_specific_common_supertype(self, others): 4519 """See base class.""" 4520 if not all(type(self) is type(other) for other in others): 4521 return None 4522 4523 try: 4524 for other in others: 4525 tf_nest.assert_same_structure(self.element_spec, other.element_spec) 4526 except (TypeError, ValueError): 4527 return None 4528 4529 self_components = tf_nest.flatten(self.element_spec) 4530 others_components = [ 4531 tf_nest.flatten(other.element_spec) for other in others 4532 ] 4533 common_components = [None] * len(self_components) 4534 4535 def common_supertype_or_equal(a, bs): 4536 if isinstance(a, trace.TraceType): 4537 return a.most_specific_common_supertype(bs) 4538 else: 4539 return a if all(a == b for b in bs) else None 4540 4541 for i, self_component in enumerate(self_components): 4542 common_components[i] = common_supertype_or_equal( 4543 self_component, 4544 [other_components[i] for other_components in others_components]) 4545 if self_component is not None and common_components[i] is None: 4546 return None 4547 common_element_spec = tf_nest.pack_sequence_as(self._element_spec, 4548 common_components) 4549 4550 common_dataset_shape = self._dataset_shape.most_specific_common_supertype( 4551 [other._dataset_shape for other in others]) # pylint: disable=protected-access 4552 if common_dataset_shape is None: 4553 return None 4554 4555 return DatasetSpec(common_element_spec, common_dataset_shape) 4556 4557 # TODO(b/220385675): Once _element_spec is guaranteed to be TypeSpec, the 4558 # following functions do not need to be overloaded: is_subtype_of, 4559 # most_specific_common_supertype, __hash__ and __eq__ 4560 def _serialize(self): 4561 return (self._element_spec, self._dataset_shape) 4562 4563 @property 4564 def _component_specs(self): 4565 return tensor_spec.TensorSpec(self._dataset_shape, dtypes.variant) 4566 4567 def _to_components(self, value): 4568 return value._variant_tensor # pylint: disable=protected-access 4569 4570 def _from_components(self, components): 4571 # pylint: disable=protected-access 4572 if 
self._dataset_shape.ndims == 0: 4573 return _VariantDataset(components, self._element_spec) 4574 else: 4575 return _NestedVariant(components, self._element_spec, self._dataset_shape) 4576 4577 def _to_tensor_list(self, value): 4578 return [ 4579 ops.convert_to_tensor( 4580 tf_nest.map_structure(lambda x: x._variant_tensor, value)) # pylint: disable=protected-access 4581 ] 4582 4583 @staticmethod 4584 def from_value(value): 4585 """Creates a `DatasetSpec` for the given `tf.data.Dataset` value.""" 4586 return DatasetSpec(value.element_spec) # pylint: disable=protected-access 4587 4588 def _batch(self, batch_size): 4589 return DatasetSpec( 4590 self._element_spec, 4591 tensor_shape.TensorShape([batch_size]).concatenate(self._dataset_shape)) 4592 4593 def _unbatch(self): 4594 if self._dataset_shape.ndims == 0: 4595 raise ValueError("Slicing dataset elements is not supported for rank 0.") 4596 return DatasetSpec(self._element_spec, self._dataset_shape[1:]) 4597 4598 def _to_batched_tensor_list(self, value): 4599 if self._dataset_shape.ndims == 0: 4600 raise ValueError("Slicing dataset elements is not supported for rank 0.") 4601 return self._to_tensor_list(value) 4602 4603 def _to_legacy_output_types(self): 4604 return self 4605 4606 def _to_legacy_output_shapes(self): 4607 return self 4608 4609 def _to_legacy_output_classes(self): 4610 return self 4611 4612 def __hash__(self): 4613 # TODO(b/220385675): attributes can be dicts and hence unhashable. 4614 return hash(DatasetSpec) 4615 4616 def __eq__(self, other): 4617 return (isinstance(other, DatasetSpec) and 4618 self._element_spec == other._element_spec and 4619 self._dataset_shape == other._dataset_shape) 4620 4621 4622class _NumpyIterator: 4623 """Iterator over a dataset with elements converted to numpy.""" 4624 4625 __slots__ = ["_iterator"] 4626 4627 def __init__(self, dataset): 4628 self._iterator = iter(dataset) 4629 4630 def __iter__(self): 4631 return self 4632 4633 def __next__(self): 4634 4635 def to_numpy(x): 4636 numpy = x._numpy() # pylint: disable=protected-access 4637 if isinstance(numpy, np.ndarray): 4638 # `numpy` shares the same underlying buffer as the `x` Tensor. 4639 # Tensors are expected to be immutable, so we disable writes. 4640 numpy.setflags(write=False) 4641 return numpy 4642 4643 return nest.map_structure(to_numpy, next(self._iterator)) 4644 4645 def next(self): 4646 return self.__next__() 4647 4648 4649class _VariantTracker(resource_lib.CapturableResource): 4650 """Allows export of functions capturing a Dataset in SavedModels. 4651 4652 When saving a SavedModel, `tf.saved_model.save` traverses the object 4653 graph. Since Datasets reference _VariantTracker objects, that traversal will 4654 find a _VariantTracker for each Dataset and so know how to save and restore 4655 functions which reference the Dataset's variant Tensor. 4656 """ 4657 4658 def __init__(self, variant_tensor, resource_creator): 4659 """Record that `variant_tensor` is associated with `resource_creator`. 4660 4661 Args: 4662 variant_tensor: The variant-dtype Tensor associated with the Dataset. This 4663 Tensor will be a captured input to functions which use the Dataset, and 4664 is used by saving code to identify the corresponding _VariantTracker. 4665 resource_creator: A zero-argument function which creates a new 4666 variant-dtype Tensor. This function will be included in SavedModels and 4667 run to re-create the Dataset's variant Tensor on restore. 
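    A minimal sketch of the kind of export this tracking enables (the module
    name and the save path below are hypothetical, for illustration only):

    >>> class Exporter(tf.Module):
    ...   def __init__(self):
    ...     super().__init__()
    ...     self.dataset = tf.data.Dataset.range(4)
    ...   @tf.function(input_signature=[])
    ...   def total(self):
    ...     return self.dataset.reduce(tf.constant(0, dtype=tf.int64),
    ...                                lambda state, value: state + value)
    >>> module = Exporter()
    >>> module.total()
    <tf.Tensor: shape=(), dtype=int64, numpy=6>
    >>> tf.saved_model.save(module, "/tmp/exported_dataset")  # doctest: +SKIP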
4668 """ 4669 super(_VariantTracker, self).__init__(device="CPU") 4670 self._resource_handle = variant_tensor 4671 if not isinstance(resource_creator, def_function.Function): 4672 # Internal validation -- _VariantTracker assumes that resource creator is 4673 # already a tf.function. 4674 raise TypeError("Resource creator should already be a tf.function.") 4675 self._create_resource = resource_creator 4676 4677 def _trackable_children(self, 4678 save_type=tracking_base.SaveType.CHECKPOINT, 4679 **kwargs): 4680 if save_type != tracking_base.SaveType.SAVEDMODEL: 4681 return {} 4682 4683 children = super(_VariantTracker, 4684 self)._trackable_children(save_type, **kwargs) 4685 # Overwrite the _create_resource function, since `self._create_resource` 4686 # is already a tf.function. 4687 children["_create_resource"] = self._create_resource 4688 return children 4689 4690 4691class TensorDataset(DatasetSource): 4692 """A `Dataset` with a single element.""" 4693 4694 def __init__(self, element, name=None): 4695 """See `Dataset.from_tensors()` for details.""" 4696 element = structure.normalize_element(element) 4697 self._structure = structure.type_spec_from_value(element) 4698 self._tensors = structure.to_tensor_list(self._structure, element) 4699 self._name = name 4700 variant_tensor = gen_dataset_ops.tensor_dataset( 4701 self._tensors, 4702 output_shapes=structure.get_flat_tensor_shapes(self._structure), 4703 metadata=self._metadata.SerializeToString()) 4704 super(TensorDataset, self).__init__(variant_tensor) 4705 4706 @property 4707 def element_spec(self): 4708 return self._structure 4709 4710 4711class TensorSliceDataset(DatasetSource): 4712 """A `Dataset` of slices from a dataset element.""" 4713 4714 def __init__(self, element, is_files=False, name=None): 4715 """See `Dataset.from_tensor_slices()` for details.""" 4716 element = structure.normalize_element(element) 4717 batched_spec = structure.type_spec_from_value(element) 4718 self._tensors = structure.to_batched_tensor_list(batched_spec, element) 4719 if not self._tensors: 4720 raise ValueError("Invalid `element`. `element` should not be empty.") 4721 self._structure = nest.map_structure( 4722 lambda component_spec: component_spec._unbatch(), batched_spec) # pylint: disable=protected-access 4723 self._name = name 4724 4725 batch_dim = tensor_shape.Dimension( 4726 tensor_shape.dimension_value(self._tensors[0].get_shape()[0])) 4727 for t in self._tensors[1:]: 4728 batch_dim.assert_is_compatible_with( 4729 tensor_shape.Dimension( 4730 tensor_shape.dimension_value(t.get_shape()[0]))) 4731 4732 variant_tensor = gen_dataset_ops.tensor_slice_dataset( 4733 self._tensors, 4734 output_shapes=structure.get_flat_tensor_shapes(self._structure), 4735 is_files=is_files, 4736 metadata=self._metadata.SerializeToString()) 4737 super(TensorSliceDataset, self).__init__(variant_tensor) 4738 4739 @property 4740 def element_spec(self): 4741 return self._structure 4742 4743 4744class SparseTensorSliceDataset(DatasetSource): 4745 """A `Dataset` that splits a rank-N `tf.sparse.SparseTensor` into its rows.""" 4746 4747 def __init__(self, sparse_tensor): 4748 """See `Dataset.from_sparse_tensor_slices()` for details.""" 4749 if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor): 4750 raise TypeError(f"Invalid `sparse_tensor`. `sparse_tensor` must be a " 4751 f"`tf.sparse.SparseTensor`. 
Got {type(sparse_tensor)}.") 4752 self._sparse_tensor = sparse_tensor 4753 4754 indices_shape = self._sparse_tensor.indices.get_shape() 4755 shape_shape = self._sparse_tensor.dense_shape.get_shape() 4756 rank = (indices_shape.dims[1] - 1).merge_with(shape_shape.dims[0] - 1) 4757 self._structure = (tensor_spec.TensorSpec([None, rank], dtypes.int64), 4758 tensor_spec.TensorSpec([None], 4759 self._sparse_tensor.dtype), 4760 tensor_spec.TensorSpec([rank], dtypes.int64)) 4761 4762 variant_tensor = gen_dataset_ops.sparse_tensor_slice_dataset( 4763 self._sparse_tensor.indices, self._sparse_tensor.values, 4764 self._sparse_tensor.dense_shape) 4765 super(SparseTensorSliceDataset, self).__init__(variant_tensor) 4766 4767 @property 4768 def element_spec(self): 4769 return self._structure 4770 4771 4772class _GeneratorDataset(DatasetSource): 4773 """A `Dataset` that generates elements by invoking a function.""" 4774 4775 def __init__(self, 4776 init_args, 4777 init_func, 4778 next_func, 4779 finalize_func, 4780 output_signature, 4781 name=None): 4782 """Constructs a `_GeneratorDataset`. 4783 4784 Args: 4785 init_args: A (nested) structure representing the arguments to `init_func`. 4786 init_func: A TensorFlow function that will be called on `init_args` each 4787 time a C++ iterator over this dataset is constructed. Returns a (nested) 4788 structure representing the "state" of the dataset. 4789 next_func: A TensorFlow function that will be called on the result of 4790 `init_func` to produce each element, and that raises `OutOfRangeError` 4791 to terminate iteration. 4792 finalize_func: A TensorFlow function that will be called on the result of 4793 `init_func` immediately before a C++ iterator over this dataset is 4794 destroyed. The return value is ignored. 4795 output_signature: A (nested) structure of `tf.TypeSpec` objects describing 4796 the output of `next_func`. 4797 name: Optional. A name for the tf.data transformation. 
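    These functions back the public `Dataset.from_generator()` API. A minimal
    sketch of that API (illustrative only):

    >>> def gen():
    ...   for i in range(3):
    ...     yield i, [1] * i
    >>> dataset = tf.data.Dataset.from_generator(
    ...     gen,
    ...     output_signature=(
    ...         tf.TensorSpec(shape=(), dtype=tf.int32),
    ...         tf.TensorSpec(shape=(None,), dtype=tf.int32)))
    >>> dataset.element_spec
    (TensorSpec(shape=(), dtype=tf.int32, name=None),
     TensorSpec(shape=(None,), dtype=tf.int32, name=None))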
4798 """ 4799 self._init_args = init_args 4800 4801 self._init_structure = structure.type_spec_from_value(init_args) 4802 4803 self._init_func = structured_function.StructuredFunctionWrapper( 4804 init_func, 4805 self._transformation_name(), 4806 input_structure=self._init_structure) 4807 4808 self._next_func = structured_function.StructuredFunctionWrapper( 4809 next_func, 4810 self._transformation_name(), 4811 input_structure=self._init_func.output_structure) 4812 4813 self._finalize_func = structured_function.StructuredFunctionWrapper( 4814 finalize_func, 4815 self._transformation_name(), 4816 input_structure=self._init_func.output_structure) 4817 4818 self._output_signature = output_signature 4819 4820 self._name = name 4821 4822 variant_tensor = gen_dataset_ops.generator_dataset( 4823 structure.to_tensor_list(self._init_structure, self._init_args) + 4824 self._init_func.function.captured_inputs, 4825 self._next_func.function.captured_inputs, 4826 self._finalize_func.function.captured_inputs, 4827 init_func=self._init_func.function, 4828 next_func=self._next_func.function, 4829 finalize_func=self._finalize_func.function, 4830 **self._common_args) 4831 super(_GeneratorDataset, self).__init__(variant_tensor) 4832 4833 @property 4834 def element_spec(self): 4835 return self._output_signature 4836 4837 def _transformation_name(self): 4838 return "Dataset.from_generator()" 4839 4840 4841class ZipDataset(DatasetV2): 4842 """A `Dataset` that zips its inputs together.""" 4843 4844 def __init__(self, datasets, name=None): 4845 """See `Dataset.zip()` for details.""" 4846 for ds in nest.flatten(datasets): 4847 if not isinstance(ds, DatasetV2): 4848 if isinstance(ds, list): 4849 raise TypeError("Invalid `datasets`. `datasets` is expected to be a " 4850 "(nested) structure of `tf.data.Dataset` objects. " 4851 "Python `list` is not supported and you should use " 4852 "`tuple` instead.") 4853 else: 4854 raise TypeError(f"Invalid `datasets`. `datasets` is expected to be a " 4855 f"(nested) structure of `tf.data.Dataset` objects " 4856 f"but encountered object of type {type(ds)}.") 4857 self._datasets = datasets 4858 self._structure = nest.pack_sequence_as( 4859 self._datasets, 4860 [ds.element_spec for ds in nest.flatten(self._datasets)]) 4861 self._name = name 4862 variant_tensor = gen_dataset_ops.zip_dataset( 4863 [ds._variant_tensor for ds in nest.flatten(self._datasets)], 4864 **self._common_args) 4865 super(ZipDataset, self).__init__(variant_tensor) 4866 4867 def _inputs(self): 4868 return nest.flatten(self._datasets) 4869 4870 @property 4871 def element_spec(self): 4872 return self._structure 4873 4874 4875class ConcatenateDataset(DatasetV2): 4876 """A `Dataset` that concatenates its input with given dataset.""" 4877 4878 def __init__(self, input_dataset, dataset_to_concatenate, name=None): 4879 """See `Dataset.concatenate()` for details.""" 4880 self._input_dataset = input_dataset 4881 self._dataset_to_concatenate = dataset_to_concatenate 4882 4883 def common_supertype(a, b): 4884 result = a.most_specific_common_supertype([b]) 4885 if result is None: 4886 raise TypeError(f"No common supertype of {a} and {b}.") 4887 return result 4888 4889 try: 4890 self._structure = tf_nest.map_structure( 4891 common_supertype, input_dataset.element_spec, 4892 dataset_to_concatenate.element_spec) 4893 except (TypeError, ValueError) as e: 4894 raise TypeError( 4895 f"Incompatible dataset elements:\n" 4896 f" {input_dataset.element_spec} vs. 
" 4897 f" {dataset_to_concatenate.element_spec}") from e 4898 4899 self._input_datasets = [input_dataset, dataset_to_concatenate] 4900 self._name = name 4901 # pylint: disable=protected-access 4902 variant_tensor = gen_dataset_ops.concatenate_dataset( 4903 input_dataset._variant_tensor, dataset_to_concatenate._variant_tensor, 4904 **self._common_args) 4905 # pylint: enable=protected-access 4906 super(ConcatenateDataset, self).__init__(variant_tensor) 4907 4908 def _inputs(self): 4909 return self._input_datasets 4910 4911 @property 4912 def element_spec(self): 4913 return self._structure 4914 4915 4916class RepeatDataset(UnaryUnchangedStructureDataset): 4917 """A `Dataset` that repeats its input several times.""" 4918 4919 def __init__(self, input_dataset, count, name=None): 4920 """See `Dataset.repeat()` for details.""" 4921 self._input_dataset = input_dataset 4922 if count is None: 4923 self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count") 4924 else: 4925 self._count = ops.convert_to_tensor( 4926 count, dtype=dtypes.int64, name="count") 4927 self._name = name 4928 variant_tensor = gen_dataset_ops.repeat_dataset( 4929 input_dataset._variant_tensor, # pylint: disable=protected-access 4930 count=self._count, 4931 **self._common_args) 4932 super(RepeatDataset, self).__init__(input_dataset, variant_tensor) 4933 4934 4935class RangeDataset(DatasetSource): 4936 """A `Dataset` of a step separated range of values.""" 4937 4938 def __init__(self, *args, **kwargs): 4939 """See `Dataset.range()` for details.""" 4940 self._parse_args(*args, **kwargs) 4941 self._structure = tensor_spec.TensorSpec([], self._output_type) 4942 variant_tensor = gen_dataset_ops.range_dataset( 4943 start=self._start, 4944 stop=self._stop, 4945 step=self._step, 4946 **self._common_args) 4947 super(RangeDataset, self).__init__(variant_tensor) 4948 4949 def _parse_args(self, *args, **kwargs): 4950 """Parse arguments according to the same rules as the `range()` builtin.""" 4951 if len(args) == 1: 4952 self._start = self._build_tensor(0, "start") 4953 self._stop = self._build_tensor(args[0], "stop") 4954 self._step = self._build_tensor(1, "step") 4955 elif len(args) == 2: 4956 self._start = self._build_tensor(args[0], "start") 4957 self._stop = self._build_tensor(args[1], "stop") 4958 self._step = self._build_tensor(1, "step") 4959 elif len(args) == 3: 4960 self._start = self._build_tensor(args[0], "start") 4961 self._stop = self._build_tensor(args[1], "stop") 4962 self._step = self._build_tensor(args[2], "step") 4963 else: 4964 raise ValueError(f"Invalid `args`. 
The length of `args` should be "
                       f"between 1 and 3 but was {len(args)}.")
    if "output_type" in kwargs:
      self._output_type = kwargs["output_type"]
    else:
      self._output_type = dtypes.int64
    self._name = kwargs["name"] if "name" in kwargs else None

  def _build_tensor(self, int64_value, name):
    return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)

  @property
  def element_spec(self):
    return self._structure


class CacheDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that caches elements of its input."""

  def __init__(self, input_dataset, filename, name=None):
    """See `Dataset.cache()` for details."""
    self._input_dataset = input_dataset
    self._filename = ops.convert_to_tensor(
        filename, dtype=dtypes.string, name="filename")
    self._name = name
    if tf2.enabled() and (context.executing_eagerly() or ops.inside_function()):
      variant_tensor = gen_dataset_ops.cache_dataset_v2(
          input_dataset._variant_tensor,  # pylint: disable=protected-access
          filename=self._filename,
          cache=gen_dataset_ops.dummy_memory_cache(),
          **self._common_args)
    else:
      variant_tensor = gen_dataset_ops.cache_dataset(
          input_dataset._variant_tensor,  # pylint: disable=protected-access
          filename=self._filename,
          **self._common_args)
    super(CacheDataset, self).__init__(input_dataset, variant_tensor)


class ShuffleDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that randomly shuffles the elements of its input."""

  def __init__(self,
               input_dataset,
               buffer_size,
               seed=None,
               reshuffle_each_iteration=None,
               name=None):
    """See `Dataset.shuffle()` for details."""
    self._input_dataset = input_dataset
    self._buffer_size = ops.convert_to_tensor(
        buffer_size, dtype=dtypes.int64, name="buffer_size")
    self._seed, self._seed2 = random_seed.get_seed(seed)
    if reshuffle_each_iteration is None:
      reshuffle_each_iteration = True
    self._reshuffle_each_iteration = reshuffle_each_iteration
    self._name = name

    if (tf2.enabled() and
        (context.executing_eagerly() or ops.inside_function())):
      variant_tensor = gen_dataset_ops.shuffle_dataset_v3(
          input_dataset._variant_tensor,  # pylint: disable=protected-access
          buffer_size=self._buffer_size,
          seed=self._seed,
          seed2=self._seed2,
          seed_generator=gen_dataset_ops.dummy_seed_generator(),
          reshuffle_each_iteration=self._reshuffle_each_iteration,
          **self._common_args)
    else:
      variant_tensor = gen_dataset_ops.shuffle_dataset(
          input_dataset._variant_tensor,  # pylint: disable=protected-access
          buffer_size=self._buffer_size,
          seed=self._seed,
          seed2=self._seed2,
          reshuffle_each_iteration=self._reshuffle_each_iteration,
          **self._common_args)
    super(ShuffleDataset, self).__init__(input_dataset, variant_tensor)


class TakeDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` containing the first `count` elements from its input."""

  def __init__(self, input_dataset, count, name=None):
    """See `Dataset.take()` for details."""
    self._input_dataset = input_dataset
    self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
    self._name = name
    variant_tensor = gen_dataset_ops.take_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        count=self._count,
        **self._common_args)
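    # Note (added, illustrative): like `CacheDataset` and `ShuffleDataset`
    # above, `TakeDataset` is normally constructed through the public API,
    # e.g. `tf.data.Dataset.range(10).cache().shuffle(10).take(3)`.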
5055 super(TakeDataset, self).__init__(input_dataset, variant_tensor) 5056 5057 5058class SkipDataset(UnaryUnchangedStructureDataset): 5059 """A `Dataset` skipping the first `count` elements from its input.""" 5060 5061 def __init__(self, input_dataset, count, name=None): 5062 """See `Dataset.skip()` for details.""" 5063 self._input_dataset = input_dataset 5064 self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count") 5065 self._name = name 5066 variant_tensor = gen_dataset_ops.skip_dataset( 5067 input_dataset._variant_tensor, # pylint: disable=protected-access 5068 count=self._count, 5069 **self._common_args) 5070 super(SkipDataset, self).__init__(input_dataset, variant_tensor) 5071 5072 5073class ShardDataset(UnaryUnchangedStructureDataset): 5074 """A `Dataset` for sharding its input.""" 5075 5076 def __init__(self, input_dataset, num_shards, index, name=None): 5077 """See `Dataset.shard()` for details.""" 5078 self._input_dataset = input_dataset 5079 self._num_shards = ops.convert_to_tensor( 5080 num_shards, dtype=dtypes.int64, name="num_shards") 5081 self._index = ops.convert_to_tensor(index, dtype=dtypes.int64, name="index") 5082 self._name = name 5083 variant_tensor = gen_dataset_ops.shard_dataset( 5084 input_dataset._variant_tensor, # pylint: disable=protected-access 5085 num_shards=self._num_shards, 5086 index=self._index, 5087 **self._common_args) 5088 super(ShardDataset, self).__init__(input_dataset, variant_tensor) 5089 5090 5091class BatchDataset(UnaryDataset): 5092 """A `Dataset` that batches contiguous elements from its input.""" 5093 5094 def __init__(self, input_dataset, batch_size, drop_remainder, name=None): 5095 """See `Dataset.batch()` for details.""" 5096 self._input_dataset = input_dataset 5097 self._batch_size = ops.convert_to_tensor( 5098 batch_size, dtype=dtypes.int64, name="batch_size") 5099 self._drop_remainder = ops.convert_to_tensor( 5100 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 5101 5102 constant_drop_remainder = tensor_util.constant_value(self._drop_remainder) 5103 # pylint: disable=protected-access 5104 if constant_drop_remainder: 5105 # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically) 5106 # or `False` (explicitly retaining the remainder). 
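      # Illustrative example (added): with `batch_size=4` over 10 elements,
      # `drop_remainder=True` gives each component a static leading dimension
      # of 4, while `drop_remainder=False` leaves the leading dimension
      # unknown (None), since the final batch may be smaller.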
5107 # pylint: disable=g-long-lambda 5108 constant_batch_size = tensor_util.constant_value(self._batch_size) 5109 self._structure = nest.map_structure( 5110 lambda component_spec: component_spec._batch(constant_batch_size), 5111 input_dataset.element_spec) 5112 else: 5113 self._structure = nest.map_structure( 5114 lambda component_spec: component_spec._batch(None), 5115 input_dataset.element_spec) 5116 5117 self._name = name 5118 variant_tensor = gen_dataset_ops.batch_dataset_v2( 5119 input_dataset._variant_tensor, 5120 batch_size=self._batch_size, 5121 drop_remainder=self._drop_remainder, 5122 **self._common_args) 5123 super(BatchDataset, self).__init__(input_dataset, variant_tensor) 5124 5125 @property 5126 def element_spec(self): 5127 return self._structure 5128 5129 5130class ParallelBatchDataset(UnaryDataset): 5131 """A `Dataset` that batches contiguous elements from its input in parallel.""" 5132 5133 def __init__(self, 5134 input_dataset, 5135 batch_size, 5136 drop_remainder, 5137 num_parallel_calls, 5138 deterministic, 5139 name=None): 5140 """See `Dataset.batch()` for details.""" 5141 self._input_dataset = input_dataset 5142 self._batch_size = ops.convert_to_tensor( 5143 batch_size, dtype=dtypes.int64, name="batch_size") 5144 self._drop_remainder = ops.convert_to_tensor( 5145 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 5146 self._num_parallel_calls = ops.convert_to_tensor( 5147 num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") 5148 if deterministic is None: 5149 self._deterministic = "default" 5150 elif deterministic: 5151 self._deterministic = "true" 5152 else: 5153 self._deterministic = "false" 5154 5155 constant_drop_remainder = tensor_util.constant_value(self._drop_remainder) 5156 # pylint: disable=protected-access 5157 if constant_drop_remainder: 5158 # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically) 5159 # or `False` (explicitly retaining the remainder). 5160 # pylint: disable=g-long-lambda 5161 constant_batch_size = tensor_util.constant_value(self._batch_size) 5162 self._structure = nest.map_structure( 5163 lambda component_spec: component_spec._batch(constant_batch_size), 5164 input_dataset.element_spec) 5165 else: 5166 self._structure = nest.map_structure( 5167 lambda component_spec: component_spec._batch(None), 5168 input_dataset.element_spec) 5169 5170 self._name = name 5171 variant_tensor = gen_dataset_ops.parallel_batch_dataset( 5172 input_dataset._variant_tensor, 5173 batch_size=self._batch_size, 5174 num_parallel_calls=self._num_parallel_calls, 5175 drop_remainder=self._drop_remainder, 5176 deterministic=self._deterministic, 5177 **self._common_args) 5178 5179 super(ParallelBatchDataset, self).__init__(input_dataset, variant_tensor) 5180 5181 @property 5182 def element_spec(self): 5183 return self._structure 5184 5185 5186def _is_padded_shape_compatible_with(padded_shape, input_component_shape): 5187 """Returns `True` if `input_component_shape` can be padded to `padded_shape`. 5188 5189 Args: 5190 padded_shape: A `tf.TensorShape`. 5191 input_component_shape: A `tf.TensorShape`. 5192 5193 Returns: 5194 `True` if `input_component_shape` can be padded to `padded_shape`, otherwise 5195 `False`. 
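  For example (illustrative sketch):

  >>> _is_padded_shape_compatible_with(tf.TensorShape([None, 5]),
  ...                                  tf.TensorShape([3, 4]))
  True
  >>> _is_padded_shape_compatible_with(tf.TensorShape([2, 4]),
  ...                                  tf.TensorShape([3, 4]))
  False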
5196 """ 5197 5198 if padded_shape.dims is None or input_component_shape.dims is None: 5199 return True 5200 if len(padded_shape.dims) != len(input_component_shape.dims): 5201 return False 5202 for padded_dim, input_dim in zip( 5203 padded_shape.dims, input_component_shape.dims): 5204 if (padded_dim.value is not None and input_dim.value is not None 5205 and padded_dim.value < input_dim.value): 5206 return False 5207 return True 5208 5209 5210def _padded_shape_to_tensor(padded_shape, input_component_shape): 5211 """Converts `padded_shape` to a `tf.Tensor` representing that shape. 5212 5213 Args: 5214 padded_shape: A shape-like object, which may be a `tf.TensorShape`, a Python 5215 sequence, or a 1-D `tf.Tensor` of `tf.int64` elements. 5216 input_component_shape: A `tf.TensorShape`, with which `padded_shape` must 5217 be compatible. 5218 5219 Returns: 5220 A 1-D `tf.Tensor` of `tf.int64` elements, representing `padded_shape`. 5221 5222 Raises: 5223 ValueError: If `padded_shape` is not a shape or not compatible with 5224 `input_component_shape`. 5225 TypeError: If `padded_shape` is not convertible to a `tf.int64` tensor. 5226 """ 5227 try: 5228 # Try to convert the `padded_shape` to a `tf.TensorShape` 5229 padded_shape_as_shape = tensor_shape.as_shape(padded_shape) 5230 # We will return the "canonical" tensor representation, which uses 5231 # `-1` in place of `None`. 5232 ret = ops.convert_to_tensor( 5233 [dim if dim is not None else -1 5234 for dim in padded_shape_as_shape.as_list()], dtype=dtypes.int64) 5235 except (TypeError, ValueError) as e: 5236 # The argument was not trivially convertible to a 5237 # `tf.TensorShape`, so fall back on the conversion to tensor 5238 # machinery. 5239 ret = ops.convert_to_tensor(padded_shape, preferred_dtype=dtypes.int64) 5240 if ret.shape.dims is not None and len(ret.shape.dims) != 1: 5241 raise ValueError( 5242 f"Padded shape {padded_shape} must be a `tf.int64` vector tensor, " 5243 f"but its shape was {ret.shape}.") from e 5244 if ret.dtype != dtypes.int64: 5245 raise TypeError( 5246 f"Padded shape {padded_shape} must be a `tf.int64` vector " 5247 f"tensor, but its element type was {ret.dtype.name}.") from e 5248 padded_shape_as_shape = tensor_util.constant_value_as_shape(ret) 5249 5250 if not _is_padded_shape_compatible_with(padded_shape_as_shape, 5251 input_component_shape): 5252 raise ValueError(f"The padded shape {padded_shape_as_shape} is not " 5253 f"compatible with the shape {input_component_shape} of " 5254 f"the corresponding input component.") 5255 5256 return ret 5257 5258 5259def _padding_value_to_tensor(value, output_type): 5260 """Converts the padding value to a tensor. 5261 5262 Args: 5263 value: The padding value. 5264 output_type: Its expected dtype. 5265 5266 Returns: 5267 A scalar `Tensor`. 5268 5269 Raises: 5270 ValueError: if the padding value is not a scalar. 5271 TypeError: if the padding value's type does not match `output_type`. 5272 """ 5273 value = ops.convert_to_tensor(value, name="padding_value") 5274 if not value.shape.is_compatible_with(tensor_shape.TensorShape([])): 5275 raise ValueError(f"Invalid `padding_values`. `padding_values` values " 5276 f"should be scalars, but got {value.shape}.") 5277 if value.dtype != output_type: 5278 raise TypeError(f"Invalid `padding_values`. 
`padding_values` values " 5279 f"type {value.dtype} does not match type {output_type} " 5280 f"of the corresponding input component.") 5281 return value 5282 5283 5284def _padding_values_or_default(padding_values, input_dataset): 5285 """Returns padding values with None elements replaced with default values.""" 5286 5287 def make_zero(t): 5288 if t.base_dtype == dtypes.string: 5289 return "" 5290 elif t.base_dtype == dtypes.variant: 5291 raise TypeError("Unable to create default padding value for a component " 5292 "of type 'variant'.") 5293 elif t.base_dtype == dtypes.bfloat16: 5294 # Special case `bfloat16` because it is not supported by NumPy. 5295 return constant_op.constant(0, dtype=dtypes.bfloat16) 5296 else: 5297 return np.zeros_like(t.as_numpy_dtype()) 5298 5299 def value_or_default(value, default): 5300 return default if value is None else value 5301 5302 default_padding = nest.map_structure( 5303 make_zero, 5304 get_legacy_output_types(input_dataset)) 5305 return nest.map_structure_up_to(padding_values, value_or_default, 5306 padding_values, default_padding) 5307 5308 5309class PaddedBatchDataset(UnaryDataset): 5310 """A `Dataset` that batches and pads contiguous elements from its input.""" 5311 5312 def __init__(self, 5313 input_dataset, 5314 batch_size, 5315 padded_shapes, 5316 padding_values, 5317 drop_remainder, 5318 name=None): 5319 """See `Dataset.batch()` for details.""" 5320 self._input_dataset = input_dataset 5321 5322 def check_types(component_spec): 5323 if not isinstance(component_spec, tensor_spec.TensorSpec): 5324 raise TypeError(f"`padded_batch` is only supported for datasets that " 5325 f"produce tensor elements but the input dataset " 5326 f"produces elements of unsupported type " 5327 f"{component_spec.value_type()}.") 5328 5329 nest.map_structure(check_types, input_dataset.element_spec) 5330 self._input_dataset = input_dataset 5331 self._batch_size = ops.convert_to_tensor( 5332 batch_size, dtype=dtypes.int64, name="batch_size") 5333 padding_values = _padding_values_or_default(padding_values, input_dataset) 5334 5335 input_shapes = get_legacy_output_shapes(input_dataset) 5336 flat_padded_shapes = nest.flatten_up_to(input_shapes, padded_shapes) 5337 5338 flat_padded_shapes_as_tensors = [] 5339 5340 for input_component_shape, padded_shape in zip( 5341 nest.flatten(input_shapes), flat_padded_shapes): 5342 flat_padded_shapes_as_tensors.append( 5343 _padded_shape_to_tensor(padded_shape, input_component_shape)) 5344 5345 self._padded_shapes = nest.pack_sequence_as(input_shapes, 5346 flat_padded_shapes_as_tensors) 5347 5348 # If padding_values is a single element and input_shapes is a structure, 5349 # "broadcast" padding_values to the same structure as input_shapes. 
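    # Illustrative example (added): for a dataset of (sequence, label) pairs,
    # passing `padding_values=0` is treated the same as passing
    # `padding_values=(0, 0)`.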
5350 if nest.is_nested(input_shapes) and not nest.is_nested(padding_values): 5351 padding_values = nest.map_structure(lambda _: padding_values, 5352 input_shapes) 5353 5354 self._padding_values = nest.map_structure_up_to( 5355 input_shapes, _padding_value_to_tensor, padding_values, 5356 get_legacy_output_types(input_dataset)) 5357 self._drop_remainder = ops.convert_to_tensor( 5358 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 5359 5360 def _padded_shape_to_batch_shape(s): 5361 return tensor_shape.TensorShape([ 5362 tensor_util.constant_value(self._batch_size) 5363 if smart_cond.smart_constant_value(self._drop_remainder) else None 5364 ]).concatenate(tensor_util.constant_value_as_shape(s)) 5365 5366 output_shapes = nest.map_structure( 5367 _padded_shape_to_batch_shape, self._padded_shapes) 5368 self._structure = structure.convert_legacy_structure( 5369 get_legacy_output_types(self._input_dataset), output_shapes, 5370 get_legacy_output_classes(self._input_dataset)) 5371 5372 self._name = name 5373 # pylint: disable=protected-access 5374 variant_tensor = gen_dataset_ops.padded_batch_dataset_v2( 5375 input_dataset._variant_tensor, # pylint: disable=protected-access 5376 batch_size=self._batch_size, 5377 padded_shapes=[ 5378 ops.convert_to_tensor(s, dtype=dtypes.int64) 5379 for s in nest.flatten(self._padded_shapes) 5380 ], 5381 padding_values=nest.flatten(self._padding_values), 5382 drop_remainder=self._drop_remainder, 5383 output_shapes=structure.get_flat_tensor_shapes(self._structure), 5384 metadata=self._metadata.SerializeToString()) 5385 super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor) 5386 5387 @property 5388 def element_spec(self): 5389 return self._structure 5390 5391 5392class MapDataset(UnaryDataset): 5393 """A `Dataset` that maps a function over elements in its input.""" 5394 5395 def __init__(self, 5396 input_dataset, 5397 map_func, 5398 use_inter_op_parallelism=True, 5399 preserve_cardinality=False, 5400 use_legacy_function=False, 5401 name=None): 5402 """See `Dataset.map()` for details.""" 5403 self._input_dataset = input_dataset 5404 self._use_inter_op_parallelism = use_inter_op_parallelism 5405 self._preserve_cardinality = preserve_cardinality 5406 self._map_func = structured_function.StructuredFunctionWrapper( 5407 map_func, 5408 self._transformation_name(), 5409 dataset=input_dataset, 5410 use_legacy_function=use_legacy_function) 5411 self._name = name 5412 variant_tensor = gen_dataset_ops.map_dataset( 5413 input_dataset._variant_tensor, # pylint: disable=protected-access 5414 self._map_func.function.captured_inputs, 5415 f=self._map_func.function, 5416 use_inter_op_parallelism=self._use_inter_op_parallelism, 5417 preserve_cardinality=self._preserve_cardinality, 5418 **self._common_args) 5419 super(MapDataset, self).__init__(input_dataset, variant_tensor) 5420 5421 def _functions(self): 5422 return [self._map_func] 5423 5424 @property 5425 def element_spec(self): 5426 return self._map_func.output_structure 5427 5428 def _transformation_name(self): 5429 return "Dataset.map()" 5430 5431 5432class ParallelMapDataset(UnaryDataset): 5433 """A `Dataset` that maps a function over elements in its input in parallel.""" 5434 5435 def __init__(self, 5436 input_dataset, 5437 map_func, 5438 num_parallel_calls, 5439 deterministic, 5440 use_inter_op_parallelism=True, 5441 preserve_cardinality=False, 5442 use_legacy_function=False, 5443 name=None): 5444 """See `Dataset.map()` for details.""" 5445 self._input_dataset = input_dataset 5446 
self._use_inter_op_parallelism = use_inter_op_parallelism 5447 self._map_func = structured_function.StructuredFunctionWrapper( 5448 map_func, 5449 self._transformation_name(), 5450 dataset=input_dataset, 5451 use_legacy_function=use_legacy_function) 5452 if deterministic is None: 5453 self._deterministic = "default" 5454 elif deterministic: 5455 self._deterministic = "true" 5456 else: 5457 self._deterministic = "false" 5458 self._preserve_cardinality = preserve_cardinality 5459 self._num_parallel_calls = ops.convert_to_tensor( 5460 num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") 5461 self._name = name 5462 variant_tensor = gen_dataset_ops.parallel_map_dataset_v2( 5463 input_dataset._variant_tensor, # pylint: disable=protected-access 5464 self._map_func.function.captured_inputs, 5465 f=self._map_func.function, 5466 num_parallel_calls=self._num_parallel_calls, 5467 deterministic=self._deterministic, 5468 use_inter_op_parallelism=self._use_inter_op_parallelism, 5469 preserve_cardinality=self._preserve_cardinality, 5470 **self._common_args) 5471 super(ParallelMapDataset, self).__init__(input_dataset, variant_tensor) 5472 5473 def _functions(self): 5474 return [self._map_func] 5475 5476 @property 5477 def element_spec(self): 5478 return self._map_func.output_structure 5479 5480 def _transformation_name(self): 5481 return "Dataset.map()" 5482 5483 5484class FlatMapDataset(UnaryDataset): 5485 """A `Dataset` that maps a function over its input and flattens the result.""" 5486 5487 def __init__(self, input_dataset, map_func, name=None): 5488 """See `Dataset.flat_map()` for details.""" 5489 self._input_dataset = input_dataset 5490 self._map_func = structured_function.StructuredFunctionWrapper( 5491 map_func, self._transformation_name(), dataset=input_dataset) 5492 if not isinstance(self._map_func.output_structure, DatasetSpec): 5493 raise TypeError( 5494 "The `map_func` argument must return a `Dataset` object. Got " 5495 f"{_get_type(self._map_func.output_structure)!r}.") 5496 self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access 5497 self._name = name 5498 variant_tensor = gen_dataset_ops.flat_map_dataset( 5499 input_dataset._variant_tensor, # pylint: disable=protected-access 5500 self._map_func.function.captured_inputs, 5501 f=self._map_func.function, 5502 **self._common_args) 5503 super(FlatMapDataset, self).__init__(input_dataset, variant_tensor) 5504 5505 def _functions(self): 5506 return [self._map_func] 5507 5508 @property 5509 def element_spec(self): 5510 return self._structure 5511 5512 def _transformation_name(self): 5513 return "Dataset.flat_map()" 5514 5515 5516class InterleaveDataset(UnaryDataset): 5517 """A `Dataset` that interleaves the result of transformed inputs.""" 5518 5519 def __init__(self, 5520 input_dataset, 5521 map_func, 5522 cycle_length, 5523 block_length, 5524 name=None): 5525 """See `Dataset.interleave()` for details.""" 5526 5527 self._input_dataset = input_dataset 5528 self._map_func = structured_function.StructuredFunctionWrapper( 5529 map_func, self._transformation_name(), dataset=input_dataset) 5530 if not isinstance(self._map_func.output_structure, DatasetSpec): 5531 raise TypeError( 5532 "The `map_func` argument must return a `Dataset` object. 
Got " 5533 f"{_get_type(self._map_func.output_structure)!r}.") 5534 self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access 5535 self._cycle_length = ops.convert_to_tensor( 5536 cycle_length, dtype=dtypes.int64, name="cycle_length") 5537 self._block_length = ops.convert_to_tensor( 5538 block_length, dtype=dtypes.int64, name="block_length") 5539 self._name = name 5540 variant_tensor = gen_dataset_ops.interleave_dataset( 5541 input_dataset._variant_tensor, # pylint: disable=protected-access 5542 self._map_func.function.captured_inputs, # pylint: disable=protected-access 5543 self._cycle_length, 5544 self._block_length, 5545 f=self._map_func.function, 5546 **self._common_args) 5547 super(InterleaveDataset, self).__init__(input_dataset, variant_tensor) 5548 5549 def _functions(self): 5550 return [self._map_func] 5551 5552 @property 5553 def element_spec(self): 5554 return self._structure 5555 5556 def _transformation_name(self): 5557 return "Dataset.interleave()" 5558 5559 5560class ParallelInterleaveDataset(UnaryDataset): 5561 """A `Dataset` that maps a function over its input and interleaves the result.""" 5562 5563 def __init__(self, 5564 input_dataset, 5565 map_func, 5566 cycle_length, 5567 block_length, 5568 num_parallel_calls, 5569 buffer_output_elements=AUTOTUNE, 5570 prefetch_input_elements=AUTOTUNE, 5571 deterministic=None, 5572 name=None): 5573 """See `Dataset.interleave()` for details.""" 5574 self._input_dataset = input_dataset 5575 self._map_func = structured_function.StructuredFunctionWrapper( 5576 map_func, self._transformation_name(), dataset=input_dataset) 5577 if not isinstance(self._map_func.output_structure, DatasetSpec): 5578 raise TypeError( 5579 "The `map_func` argument must return a `Dataset` object. 
Got " 5580 f"{_get_type(self._map_func.output_structure)!r}.") 5581 self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access 5582 self._cycle_length = ops.convert_to_tensor( 5583 cycle_length, dtype=dtypes.int64, name="cycle_length") 5584 self._block_length = ops.convert_to_tensor( 5585 block_length, dtype=dtypes.int64, name="block_length") 5586 self._buffer_output_elements = ops.convert_to_tensor( 5587 buffer_output_elements, 5588 dtype=dtypes.int64, 5589 name="buffer_output_elements") 5590 self._prefetch_input_elements = ops.convert_to_tensor( 5591 prefetch_input_elements, 5592 dtype=dtypes.int64, 5593 name="prefetch_input_elements") 5594 5595 self._num_parallel_calls = ops.convert_to_tensor( 5596 num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") 5597 if deterministic is None: 5598 deterministic_string = "default" 5599 elif deterministic: 5600 deterministic_string = "true" 5601 else: 5602 deterministic_string = "false" 5603 5604 self._name = name 5605 variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v4( 5606 input_dataset._variant_tensor, # pylint: disable=protected-access 5607 self._map_func.function.captured_inputs, # pylint: disable=protected-access 5608 self._cycle_length, 5609 self._block_length, 5610 self._buffer_output_elements, 5611 self._prefetch_input_elements, 5612 self._num_parallel_calls, 5613 f=self._map_func.function, 5614 deterministic=deterministic_string, 5615 **self._common_args) 5616 super(ParallelInterleaveDataset, self).__init__(input_dataset, 5617 variant_tensor) 5618 5619 def _functions(self): 5620 return [self._map_func] 5621 5622 @property 5623 def element_spec(self): 5624 return self._structure 5625 5626 def _transformation_name(self): 5627 return "Dataset.interleave()" 5628 5629 5630class FilterDataset(UnaryUnchangedStructureDataset): 5631 """A `Dataset` that filters its input according to a predicate function.""" 5632 5633 def __init__(self, 5634 input_dataset, 5635 predicate, 5636 use_legacy_function=False, 5637 name=None): 5638 """See `Dataset.filter()` for details.""" 5639 self._input_dataset = input_dataset 5640 wrapped_func = structured_function.StructuredFunctionWrapper( 5641 predicate, 5642 self._transformation_name(), 5643 dataset=input_dataset, 5644 use_legacy_function=use_legacy_function) 5645 if not wrapped_func.output_structure.is_compatible_with( 5646 tensor_spec.TensorSpec([], dtypes.bool)): 5647 raise ValueError(f"Invalid `predicate`. 
`predicate` must return a " 5648 f"`tf.bool` scalar tensor, but its return type is " 5649 f"{wrapped_func.output_structure}.") 5650 self._predicate = wrapped_func 5651 self._name = name 5652 variant_tensor = gen_dataset_ops.filter_dataset( 5653 input_dataset._variant_tensor, # pylint: disable=protected-access 5654 other_arguments=self._predicate.function.captured_inputs, 5655 predicate=self._predicate.function, 5656 **self._common_args) 5657 super(FilterDataset, self).__init__(input_dataset, variant_tensor) 5658 5659 def _functions(self): 5660 return [self._predicate] 5661 5662 def _transformation_name(self): 5663 return "Dataset.filter()" 5664 5665 5666class PrefetchDataset(UnaryUnchangedStructureDataset): 5667 """A `Dataset` that asynchronously prefetches its input.""" 5668 5669 def __init__(self, input_dataset, buffer_size, slack_period=None, name=None): 5670 """See `Dataset.prefetch()` for details.""" 5671 self._input_dataset = input_dataset 5672 if buffer_size is None: 5673 buffer_size = AUTOTUNE 5674 self._buffer_size = ops.convert_to_tensor( 5675 buffer_size, dtype=dtypes.int64, name="buffer_size") 5676 self._name = name 5677 # pylint: disable=protected-access 5678 # We colocate the prefetch dataset with its input as this collocation only 5679 # happens automatically in graph mode. 5680 with ops.colocate_with(input_dataset._variant_tensor): 5681 variant_tensor = gen_dataset_ops.prefetch_dataset( 5682 input_dataset._variant_tensor, 5683 buffer_size=self._buffer_size, 5684 slack_period=slack_period, 5685 **self._common_args) 5686 super(PrefetchDataset, self).__init__(input_dataset, variant_tensor) 5687 5688 5689class WindowDataset(UnaryDataset): 5690 """A dataset that creates window datasets from the input elements.""" 5691 5692 def __init__(self, 5693 input_dataset, 5694 size, 5695 shift, 5696 stride, 5697 drop_remainder, 5698 name=None): 5699 """See `window()` for more details.""" 5700 self._input_dataset = input_dataset 5701 self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size") 5702 self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift") 5703 self._stride = ops.convert_to_tensor( 5704 stride, dtype=dtypes.int64, name="stride") 5705 self._drop_remainder = ops.convert_to_tensor( 5706 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 5707 self._structure = nest.pack_sequence_as( 5708 get_legacy_output_classes(input_dataset), [ 5709 DatasetSpec( # pylint: disable=g-complex-comprehension 5710 structure.convert_legacy_structure( 5711 output_type, output_shape, output_class)) 5712 for output_class, output_shape, output_type in zip( 5713 nest.flatten(get_legacy_output_classes(input_dataset)), 5714 nest.flatten(get_legacy_output_shapes(input_dataset)), 5715 nest.flatten(get_legacy_output_types(input_dataset))) 5716 ]) 5717 self._name = name 5718 variant_tensor = gen_dataset_ops.window_dataset( 5719 input_dataset._variant_tensor, # pylint: disable=protected-access 5720 size=self._size, 5721 shift=self._shift, 5722 stride=self._stride, 5723 drop_remainder=self._drop_remainder, 5724 **self._common_args) 5725 super(WindowDataset, self).__init__(input_dataset, variant_tensor) 5726 5727 @property 5728 def element_spec(self): 5729 return self._structure 5730 5731 5732class _OptionsDataset(UnaryUnchangedStructureDataset): 5733 """An identity `Dataset` that stores options.""" 5734 5735 def __init__(self, input_dataset, options, name=None): 5736 # pylint: disable=protected-access 5737 self._input_dataset = input_dataset 5738 options_pb = 
dataset_options_pb2.Options() 5739 options_pb.CopyFrom(options._to_proto()) 5740 self._name = name 5741 with ops.colocate_with(input_dataset._variant_tensor): 5742 variant_tensor = gen_dataset_ops.options_dataset( 5743 input_dataset._variant_tensor, options_pb.SerializeToString(), 5744 **self._common_args) 5745 super(_OptionsDataset, self).__init__(input_dataset, variant_tensor) 5746 5747 if self._options_attr: 5748 self._options_attr._set_mutable(True) 5749 self._options_attr = self._options_attr.merge(options) 5750 else: 5751 self._options_attr = options 5752 self._options_attr._set_mutable(False) 5753 5754 5755def normalize_to_dense(dataset): 5756 """Normalizes non-tensor components in a dataset to dense representations. 5757 5758 This is necessary for dataset transformations that slice along the batch 5759 dimension and are oblivious to non-tensors, e.g. `unbatch`, `rebatch`. 5760 5761 Args: 5762 dataset: Dataset to normalize. 5763 5764 Returns: 5765 A dataset whose sparse and ragged tensors have been normalized to their 5766 dense representations. 5767 """ 5768 5769 # NOTE(mrry): This leads to a somewhat inefficient re-encoding step for all 5770 # non-tensor components. 5771 # 5772 # TODO(mrry): Consider optimizing this if it turns out to be a bottleneck. 5773 if structured_function._should_unpack(dataset.element_spec): # pylint: disable=protected-access 5774 5775 def normalize(*args): 5776 return structure.to_batched_tensor_list(dataset.element_spec, tuple(args)) 5777 else: 5778 def normalize(arg): 5779 return structure.to_batched_tensor_list(dataset.element_spec, arg) 5780 5781 normalized_dataset = dataset.map(normalize) 5782 5783 # NOTE(mrry): Our `map()` has lost information about the structure of 5784 # non-tensor components, so re-apply the structure of the original dataset. 5785 return _RestructuredDataset(normalized_dataset, dataset.element_spec) 5786 5787 5788class _RestructuredDataset(UnaryDataset): 5789 """An internal helper for changing the element spec of a dataset.""" 5790 5791 def __init__(self, dataset, element_spec): 5792 self._input_dataset = dataset 5793 self._element_spec = element_spec 5794 5795 variant_tensor = self._input_dataset._variant_tensor # pylint: disable=protected-access 5796 super(_RestructuredDataset, self).__init__(dataset, variant_tensor) 5797 5798 @property 5799 def element_spec(self): 5800 return self._element_spec 5801 5802 5803class _UnbatchDataset(UnaryDataset): 5804 """A dataset that splits the elements of its input into multiple elements.""" 5805 5806 def __init__(self, input_dataset, name=None): 5807 """See `unbatch()` for more details.""" 5808 flat_shapes = input_dataset._flat_shapes # pylint: disable=protected-access 5809 if any(s.ndims == 0 for s in flat_shapes): 5810 raise ValueError("Cannot unbatch an input with scalar components.") 5811 known_batch_dim = tensor_shape.Dimension(None) 5812 for s in flat_shapes: 5813 try: 5814 known_batch_dim = known_batch_dim.merge_with(s[0]) 5815 except ValueError: 5816 raise ValueError(f"`unbatch()` is only supported for datasets of " 5817 f"elements whose components have a matching leading " 5818 f"dimension. 
Encountered both {known_batch_dim} and " 5819 f"{s[0]}.") 5820 self._input_dataset = input_dataset 5821 self._structure = nest.map_structure( 5822 lambda component_spec: component_spec._unbatch(), # pylint: disable=protected-access 5823 get_structure(input_dataset)) 5824 self._name = name 5825 variant_tensor = ged_ops.unbatch_dataset( 5826 self._input_dataset._variant_tensor, # pylint: disable=protected-access 5827 **self._common_args) 5828 super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor) 5829 5830 @property 5831 def element_spec(self): 5832 return self._structure 5833 5834 5835class _GroupByWindowDataset(UnaryDataset): 5836 """A `Dataset` that groups its input and performs a windowed reduction.""" 5837 5838 def __init__(self, 5839 input_dataset, 5840 key_func, 5841 reduce_func, 5842 window_size_func, 5843 name=None): 5844 """See `group_by_window()` for details.""" 5845 self._input_dataset = input_dataset 5846 self._make_key_func(key_func, input_dataset) 5847 self._make_reduce_func(reduce_func, input_dataset) 5848 self._make_window_size_func(window_size_func) 5849 self._name = name 5850 variant_tensor = ged_ops.group_by_window_dataset( 5851 self._input_dataset._variant_tensor, # pylint: disable=protected-access 5852 self._key_func.function.captured_inputs, 5853 self._reduce_func.function.captured_inputs, 5854 self._window_size_func.function.captured_inputs, 5855 key_func=self._key_func.function, 5856 reduce_func=self._reduce_func.function, 5857 window_size_func=self._window_size_func.function, 5858 **self._common_args) 5859 super(_GroupByWindowDataset, self).__init__(input_dataset, variant_tensor) 5860 5861 def _make_window_size_func(self, window_size_func): 5862 """Make wrapping defun for window_size_func.""" 5863 5864 def window_size_func_wrapper(key): 5865 return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64) 5866 5867 self._window_size_func = structured_function.StructuredFunctionWrapper( 5868 window_size_func_wrapper, 5869 self._transformation_name(), 5870 input_structure=tensor_spec.TensorSpec([], dtypes.int64)) 5871 if not self._window_size_func.output_structure.is_compatible_with( 5872 tensor_spec.TensorSpec([], dtypes.int64)): 5873 raise ValueError(f"Invalid `window_size_func`. `window_size_func` must " 5874 f"return a single `tf.int64` scalar tensor but its " 5875 f"return type is " 5876 f"{self._window_size_func.output_structure}.") 5877 5878 def _make_key_func(self, key_func, input_dataset): 5879 """Make wrapping defun for key_func.""" 5880 5881 def key_func_wrapper(*args): 5882 return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64) 5883 5884 self._key_func = structured_function.StructuredFunctionWrapper( 5885 key_func_wrapper, self._transformation_name(), dataset=input_dataset) 5886 if not self._key_func.output_structure.is_compatible_with( 5887 tensor_spec.TensorSpec([], dtypes.int64)): 5888 raise ValueError(f"Invalid `key_func`. 
`key_func` must return a single " 5889 f"`tf.int64` scalar tensor but its return type is " 5890 f"{self._key_func.output_structure}.") 5891 5892 def _make_reduce_func(self, reduce_func, input_dataset): 5893 """Make wrapping defun for reduce_func.""" 5894 nested_dataset = DatasetSpec(input_dataset.element_spec) 5895 input_structure = (tensor_spec.TensorSpec([], dtypes.int64), nested_dataset) 5896 self._reduce_func = structured_function.StructuredFunctionWrapper( 5897 reduce_func, 5898 self._transformation_name(), 5899 input_structure=input_structure) 5900 if not isinstance(self._reduce_func.output_structure, DatasetSpec): 5901 raise TypeError(f"Invalid `reduce_func`. `reduce_func` must return a " 5902 f"single `tf.data.Dataset` object but its return type " 5903 f"is {self._reduce_func.output_structure}.") 5904 # pylint: disable=protected-access 5905 self._element_spec = (self._reduce_func.output_structure._element_spec) 5906 5907 @property 5908 def element_spec(self): 5909 return self._element_spec 5910 5911 def _functions(self): 5912 return [self._key_func, self._reduce_func, self._window_size_func] 5913 5914 def _transformation_name(self): 5915 return "Dataset.group_by_window()" 5916 5917 5918class RandomDataset(DatasetSource): 5919 """A `Dataset` of pseudorandom values.""" 5920 5921 def __init__(self, seed=None, name=None): 5922 """A `Dataset` of pseudorandom values.""" 5923 self._seed, self._seed2 = random_seed.get_seed(seed) 5924 self._name = name 5925 variant_tensor = ged_ops.random_dataset( 5926 seed=self._seed, seed2=self._seed2, **self._common_args) 5927 super(RandomDataset, self).__init__(variant_tensor) 5928 5929 @property 5930 def element_spec(self): 5931 return tensor_spec.TensorSpec([], dtypes.int64) 5932 5933 5934def _get_prob_original_static(initial_dist_t, target_dist_t): 5935 """Returns the static probability of sampling from the original. 5936 5937 `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters 5938 an Op that it isn't defined for. We have some custom logic to avoid this. 5939 5940 Args: 5941 initial_dist_t: A tensor of the initial distribution. 5942 target_dist_t: A tensor of the target distribution. 5943 5944 Returns: 5945 The probability of sampling from the original distribution as a constant, 5946 if it is a constant, or `None`. 5947 """ 5948 init_static = tensor_util.constant_value(initial_dist_t) 5949 target_static = tensor_util.constant_value(target_dist_t) 5950 5951 if init_static is None or target_static is None: 5952 return None 5953 else: 5954 return np.min(target_static / init_static) 5955 5956 5957def _filter_ds(dataset, 5958 acceptance_dist_ds, 5959 initial_dist_ds, 5960 class_func, 5961 seed, 5962 name=None): 5963 """Filters a dataset based on per-class acceptance probabilities. 5964 5965 Args: 5966 dataset: The dataset to be filtered. 5967 acceptance_dist_ds: A dataset of acceptance probabilities. 5968 initial_dist_ds: A dataset of the initial probability distribution, given or 5969 estimated. 5970 class_func: A function mapping an element of the input dataset to a scalar 5971 `tf.int32` tensor. Values should be in `[0, num_classes)`. 5972 seed: (Optional.) Python integer seed for the resampler. 5973 name: (Optional.) A name for the tf.data operation. 5974 5975 Returns: 5976 A dataset of (class value, data) after filtering. 
  """

  def maybe_warn_on_large_rejection(accept_dist, initial_dist):
    proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist)
    return control_flow_ops.cond(
        math_ops.less(proportion_rejected, .5),
        lambda: accept_dist,
        lambda: logging_ops.Print(  # pylint: disable=g-long-lambda
            accept_dist, [proportion_rejected, initial_dist, accept_dist],
            message="Proportion of examples rejected by sampler is high: ",
            summarize=100,
            first_n=10))

  acceptance_dist_ds = (
      DatasetV2.zip((acceptance_dist_ds, initial_dist_ds),
                    name=name).map(maybe_warn_on_large_rejection, name=name))

  def _gather_and_copy(acceptance_prob, data):
    if isinstance(data, tuple):
      class_val = class_func(*data)
    else:
      class_val = class_func(data)
    return class_val, array_ops.gather(acceptance_prob, class_val), data

  current_probabilities_and_class_and_data_ds = DatasetV2.zip(
      (acceptance_dist_ds, dataset), name=name).map(
          _gather_and_copy, name=name)

  def _reject(unused_class_val, p, unused_data):
    return random_ops.random_uniform([], seed=seed, dtype=p.dtype) < p

  filtered_ds = current_probabilities_and_class_and_data_ds.filter(
      _reject, name=name)
  return filtered_ds.map(
      lambda class_value, _, data: (class_value, data), name=name)


# pylint: disable=missing-function-docstring
def _estimate_initial_dist_ds(target_dist_t,
                              class_values_ds,
                              dist_estimation_batch_size=32,
                              smoothing_constant=10,
                              name=None):
  num_classes = (target_dist_t.shape[0] or array_ops.shape(target_dist_t)[0])
  initial_examples_per_class_seen = array_ops.fill([num_classes],
                                                   np.int64(smoothing_constant))

  def update_estimate_and_tile(num_examples_per_class_seen, c):
    updated_examples_per_class_seen, dist = _estimate_data_distribution(
        c, num_examples_per_class_seen)
    tiled_dist = array_ops.tile(
        array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1])
    return updated_examples_per_class_seen, tiled_dist

  initial_dist_ds = (
      class_values_ds.batch(dist_estimation_batch_size, name=name).scan(
          initial_examples_per_class_seen, update_estimate_and_tile,
          name=name).unbatch(name=name))

  return initial_dist_ds


def _get_target_to_initial_ratio(initial_probs, target_probs):
  # Add tiny to initial_probs to avoid divide by zero.
  denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny)
  return target_probs / denom


def _estimate_data_distribution(c, num_examples_per_class_seen):
  """Estimates the data distribution as labels are seen.

  Args:
    c: The class labels. Type `int32`, shape `[batch_size]`.
    num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, containing
      counts.

  Returns:
    num_examples_per_class_seen: Updated counts. Type `int64`, shape
      `[num_classes]`.
    dist: The updated distribution. Type `float32`, shape `[num_classes]`.
  """
  num_classes = num_examples_per_class_seen.get_shape()[0]
  # Update the class counts based on the labels seen in the batch.
  num_examples_per_class_seen = math_ops.add(
      num_examples_per_class_seen,
      math_ops.reduce_sum(
          array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
  init_prob_estimate = math_ops.truediv(
      num_examples_per_class_seen,
      math_ops.reduce_sum(num_examples_per_class_seen))
  dist = math_ops.cast(init_prob_estimate, dtypes.float32)
  return num_examples_per_class_seen, dist


def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
  """Calculates the acceptance probabilities and mixing ratio.

  In this case, we assume that we can *either* sample from the original data
  distribution with probability `m`, or sample from a reshaped distribution
  that comes from rejection sampling on the original distribution. This
  rejection sampling is done on a per-class basis, with `a_i` representing the
  probability of accepting data from class `i`.

  This method is based on solving the following analysis for the reshaped
  distribution:

  Let F be the probability of a rejection (on any example).
  Let p_i be the proportion of examples in the data in class i (initial_probs).
  Let a_i be the rate at which the rejection sampler should *accept* class i.
  Let t_i be the target proportion of class i in the minibatches (target_probs).

  ```
  F = sum_i(p_i * (1-a_i))
    = 1 - sum_i(p_i * a_i)     using sum_i(p_i) = 1
  ```

  An example of class `i` is emitted when some number `k` of rejections occur
  first, then an example of class `i` is seen by the rejector and accepted.
  This can be written as follows:

  ```
  t_i = sum_k=0^inf(F^k * p_i * a_i)
      = p_i * a_i / (1 - F)           using geometric series identity, since 0 <= F < 1
      = p_i * a_i / sum_j(p_j * a_j)  using F from above
  ```

  Note that the following constraints hold:
  ```
  0 <= p_i <= 1, sum_i(p_i) = 1
  0 <= a_i <= 1
  0 <= t_i <= 1, sum_i(t_i) = 1
  ```

  A solution for a_i in terms of the other variables is the following:
  ```a_i = (t_i / p_i) / max_i[t_i / p_i]```

  If we try to minimize the amount of data rejected, we get the following:

  M_max = max_i [ t_i / p_i ]
  M_min = min_i [ t_i / p_i ]

  The desired probability of accepting data if it comes from class `i`:

  a_i = (t_i/p_i - m) / (M_max - m)

  The desired probability of pulling a data element from the original dataset,
  rather than the filtered one:

  m = M_min

  Args:
    initial_probs: A Tensor of the initial probability distribution, given or
      estimated.
    target_probs: A Tensor of the target probability distribution.

  Returns:
    (A 1D Tensor with the per-class acceptance probabilities, the desired
    probability of pulling from the original distribution.)
  """
  ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs)
  max_ratio = math_ops.reduce_max(ratio_l)
  min_ratio = math_ops.reduce_min(ratio_l)

  # Target prob to sample from original distribution.
  m = min_ratio

  # TODO(joelshor): Simplify fraction, if possible.
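  # Worked example with illustrative numbers (not taken from the code above):
  # if initial_probs = [0.8, 0.2] and target_probs = [0.5, 0.5], then
  # ratio_l = t_i / p_i = [0.625, 2.5], so max_ratio = 2.5 and m = 0.625.
  # The formula below gives a_i = [0.0, 1.0], i.e. the rejection branch keeps
  # only class 1; mixing it with the original stream with probability m yields
  # 0.625 * [0.8, 0.2] + 0.375 * [0.0, 1.0] = [0.5, 0.5], which matches the
  # target distribution.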
  a_i = (ratio_l - m) / (max_ratio - m)
  return a_i, m


class _TakeWhileDataset(UnaryUnchangedStructureDataset):
  """A dataset that stops iteration when `predicate` returns false."""

  def __init__(self, input_dataset, predicate, name=None):
    """See `take_while()` for details."""

    self._input_dataset = input_dataset
    wrapped_func = structured_function.StructuredFunctionWrapper(
        predicate, self._transformation_name(), dataset=self._input_dataset)

    if not wrapped_func.output_structure.is_compatible_with(
        tensor_spec.TensorSpec([], dtypes.bool)):
      raise ValueError(f"Invalid `predicate`. `predicate` must return a "
                       f"`tf.bool` scalar tensor but its return type is "
                       f"{wrapped_func.output_structure}.")

    self._predicate = wrapped_func
    self._name = name
    variant_tensor = ged_ops.take_while_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        other_arguments=self._predicate.function.captured_inputs,
        predicate=self._predicate.function,
        **self._common_args)
    super(_TakeWhileDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._predicate]

  def _transformation_name(self):
    return "Dataset.take_while()"


class _UniqueDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` containing the unique elements from its input."""

  def __init__(self, input_dataset, name=None):
    """See `unique()` for details."""
    self._input_dataset = input_dataset
    for ty in nest.flatten(get_legacy_output_types(input_dataset)):
      if ty not in (dtypes.int32, dtypes.int64, dtypes.string):
        raise TypeError(
            f"`unique()` does not support type {ty}, only `tf.int32`, "
            f"`tf.int64`, and `tf.string` are supported.")
    self._name = name
    variant_tensor = ged_ops.unique_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        **self._common_args)
    super(_UniqueDataset, self).__init__(input_dataset, variant_tensor)


class _SnapshotDataset(UnaryUnchangedStructureDataset):
  """A dataset that allows saving and re-use of already processed data."""

  def __init__(self,
               input_dataset,
               path,
               shard_func,
               compression=None,
               reader_func=None,
               pending_snapshot_expiry_seconds=None,
               use_legacy_function=False,
               name=None):

    if reader_func is None:
      reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
          lambda x: x,
          cycle_length=multiprocessing.cpu_count(),
          num_parallel_calls=AUTOTUNE)

    self._input_dataset = input_dataset
    self._path = path
    self._compression = compression

    self._reader_func = structured_function.StructuredFunctionWrapper(
        reader_func,
        self._transformation_name() + ".reader_func",
        # Dataset of datasets of input elements
        input_structure=DatasetSpec(DatasetSpec(input_dataset.element_spec)),
        use_legacy_function=use_legacy_function)
    self._shard_func = structured_function.StructuredFunctionWrapper(
        shard_func,
        self._transformation_name() + ".shard_func",
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)

    if ((not self._shard_func.output_structure.is_compatible_with(
        tensor_spec.TensorSpec([], dtypes.int32))) and
        (not self._shard_func.output_structure.is_compatible_with(
            tensor_spec.TensorSpec([],
dtypes.int64)))): 6237 raise TypeError(f"Invalid `shard_func`. `shard_func` must return " 6238 f"`tf.int64` scalar tensor but its return type is " 6239 f"{self._shard_func.output_structure}.") 6240 6241 self._name = name 6242 variant_tensor = ged_ops.snapshot_dataset_v2( 6243 input_dataset._variant_tensor, # pylint: disable=protected-access 6244 path, 6245 self._reader_func.function.captured_inputs, 6246 self._shard_func.function.captured_inputs, 6247 compression=compression, 6248 reader_func=self._reader_func.function, 6249 shard_func=self._shard_func.function, 6250 **self._common_args) 6251 super(_SnapshotDataset, self).__init__(input_dataset, variant_tensor) 6252 6253 def _functions(self): 6254 return [self._reader_func, self._shard_func] 6255 6256 def _transformation_name(self): 6257 return "Dataset.snapshot()" 6258 6259 6260class _ScanDataset(UnaryDataset): 6261 """A dataset that scans a function across its input.""" 6262 6263 def __init__(self, 6264 input_dataset, 6265 initial_state, 6266 scan_func, 6267 use_default_device=None, 6268 name=None): 6269 """See `scan()` for details.""" 6270 self._input_dataset = input_dataset 6271 self._initial_state = structure.normalize_element(initial_state) 6272 6273 # Compute initial values for the state classes, shapes and types based on 6274 # the initial state. The shapes may be refined by running `tf_scan_func` one 6275 # or more times below. 6276 self._state_structure = structure.type_spec_from_value(self._initial_state) 6277 6278 # Iteratively rerun the scan function until reaching a fixed point on 6279 # `self._state_shapes`. 6280 need_to_rerun = True 6281 while need_to_rerun: 6282 6283 wrapped_func = structured_function.StructuredFunctionWrapper( 6284 scan_func, 6285 self._transformation_name(), 6286 input_structure=(self._state_structure, input_dataset.element_spec), 6287 add_to_graph=False) 6288 if not (isinstance(wrapped_func.output_types, collections_abc.Sequence) 6289 and len(wrapped_func.output_types) == 2): 6290 raise TypeError(f"Invalid `scan_func`. `scan_func` should return a " 6291 f"pair consisting of new state and the output value " 6292 f"but its return type is " 6293 f"{wrapped_func.output_structure}.") 6294 6295 new_state_classes, self._output_classes = wrapped_func.output_classes 6296 6297 # Extract and validate class information from the returned values. 6298 new_state_classes, output_classes = wrapped_func.output_classes 6299 old_state_classes = nest.map_structure( 6300 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 6301 self._state_structure) 6302 for new_state_class, old_state_class in zip( 6303 nest.flatten(new_state_classes), nest.flatten(old_state_classes)): 6304 if not issubclass(new_state_class, old_state_class): 6305 raise TypeError(f"Invalid `scan_func`. The element classes for the " 6306 f"new state must match the initial state. Expected " 6307 f"{old_state_classes}, got {new_state_classes}.") 6308 6309 # Extract and validate type information from the returned values. 6310 new_state_types, output_types = wrapped_func.output_types 6311 old_state_types = nest.map_structure( 6312 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 6313 self._state_structure) 6314 for new_state_type, old_state_type in zip( 6315 nest.flatten(new_state_types), nest.flatten(old_state_types)): 6316 if new_state_type != old_state_type: 6317 raise TypeError(f"Invalid `scan_func`. 
The element types for the " 6318 f"new state must match the initial state. Expected " 6319 f"{old_state_types}, got {new_state_types}.") 6320 6321 # Extract shape information from the returned values. 6322 new_state_shapes, output_shapes = wrapped_func.output_shapes 6323 old_state_shapes = nest.map_structure( 6324 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 6325 self._state_structure) 6326 self._element_spec = structure.convert_legacy_structure( 6327 output_types, output_shapes, output_classes) 6328 6329 flat_state_shapes = nest.flatten(old_state_shapes) 6330 flat_new_state_shapes = nest.flatten(new_state_shapes) 6331 weakened_state_shapes = [ 6332 original.most_specific_compatible_shape(new) 6333 for original, new in zip(flat_state_shapes, flat_new_state_shapes) 6334 ] 6335 6336 need_to_rerun = False 6337 for original_shape, weakened_shape in zip(flat_state_shapes, 6338 weakened_state_shapes): 6339 if original_shape.ndims is not None and ( 6340 weakened_shape.ndims is None or 6341 original_shape.as_list() != weakened_shape.as_list()): 6342 need_to_rerun = True 6343 break 6344 6345 if need_to_rerun: 6346 # TODO(b/110122868): Support a "most specific compatible structure" 6347 # method for combining structures, to avoid using legacy structures 6348 # in this method. 6349 self._state_structure = structure.convert_legacy_structure( 6350 old_state_types, 6351 nest.pack_sequence_as(old_state_shapes, weakened_state_shapes), 6352 old_state_classes) 6353 6354 self._scan_func = wrapped_func 6355 self._scan_func.function.add_to_graph(ops.get_default_graph()) 6356 6357 self._name = name 6358 # pylint: disable=protected-access 6359 if use_default_device is not None: 6360 variant_tensor = ged_ops.scan_dataset( 6361 self._input_dataset._variant_tensor, 6362 structure.to_tensor_list(self._state_structure, self._initial_state), 6363 self._scan_func.function.captured_inputs, 6364 f=self._scan_func.function, 6365 preserve_cardinality=True, 6366 use_default_device=use_default_device, 6367 **self._common_args) 6368 else: 6369 variant_tensor = ged_ops.scan_dataset( 6370 self._input_dataset._variant_tensor, 6371 structure.to_tensor_list(self._state_structure, self._initial_state), 6372 self._scan_func.function.captured_inputs, 6373 f=self._scan_func.function, 6374 preserve_cardinality=True, 6375 **self._common_args) 6376 super(_ScanDataset, self).__init__(input_dataset, variant_tensor) 6377 6378 def _functions(self): 6379 return [self._scan_func] 6380 6381 @property 6382 def element_spec(self): 6383 return self._element_spec 6384 6385 def _transformation_name(self): 6386 return "Dataset.scan()" 6387 6388 6389class _DirectedInterleaveDataset(DatasetV2): 6390 """A substitute for `Dataset.interleave()` on a fixed list of datasets.""" 6391 6392 def __init__(self, selector_input, data_inputs, stop_on_empty_dataset=False): 6393 self._selector_input = selector_input 6394 self._data_inputs = list(data_inputs) 6395 self._stop_on_empty_dataset = stop_on_empty_dataset 6396 6397 spec = self._data_inputs[0].element_spec 6398 for i, data_input in enumerate(self._data_inputs[1:]): 6399 def common_supertype(a, b): 6400 result = a.most_specific_common_supertype([b]) 6401 if result is None: 6402 raise TypeError(f"No common supertype of {a} and {b}.") 6403 return result 6404 6405 try: 6406 spec = nest.map_structure(common_supertype, spec, 6407 data_input.element_spec) 6408 except (TypeError, ValueError) as e: 6409 raise TypeError(f"Invalid `datasets`. 
`datasets` must have compatible " 6410 f"element specs.\n Dataset 0 " 6411 f"element_spec={data_inputs[0].element_spec}.\n" 6412 f"Dataset {i+1} " 6413 f"element_spec={data_input.element_spec}.") from e 6414 self._element_spec = spec 6415 6416 # pylint: disable=protected-access 6417 variant_tensor = ( 6418 ged_ops.directed_interleave_dataset( 6419 self._selector_input._variant_tensor, 6420 [data_input._variant_tensor for data_input in self._data_inputs], 6421 stop_on_empty_dataset=self._stop_on_empty_dataset, 6422 **self._flat_structure)) 6423 6424 super(_DirectedInterleaveDataset, self).__init__(variant_tensor) 6425 6426 def _inputs(self): 6427 return [self._selector_input] + self._data_inputs 6428 6429 @property 6430 def element_spec(self): 6431 return self._element_spec 6432 6433 6434def _apply_rewrite(dataset, rewrite): 6435 # pylint: disable=protected-access 6436 return _VariantDataset( 6437 gen_dataset_ops.rewrite_dataset(dataset._variant_tensor, rewrite, 6438 **dataset._flat_structure), 6439 dataset.element_spec) 6440 6441 6442def _collect_resource_inputs(op): 6443 """Collects resource inputs for the given ops (and its variant inputs).""" 6444 6445 def _process(op_queue, seen_ops): 6446 """Processes the next element of the op queue. 6447 6448 Args: 6449 op_queue: Queue of Dataset operations to process. 6450 seen_ops: Already processed set of Operations. 6451 6452 Returns: 6453 A 2-tuple containing sets of resource handles. The first tuple entry 6454 contains read-only handles and the second entry contains read-write 6455 handles. 6456 """ 6457 6458 reads = [] 6459 writes = [] 6460 op = op_queue.pop() 6461 if op in seen_ops: 6462 return reads, writes 6463 seen_ops.add(op) 6464 # TODO(b/150139257): All resource inputs are in writes right now since we 6465 # have not updated the functional ops to set the special attribute that ACD 6466 # uses to figure out which of the op's inputs are read-only. 6467 reads, writes = acd_utils.get_read_write_resource_inputs(op) 6468 # Conservatively assume that any variant inputs are datasets. 
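    # Follow variant-dtype inputs backwards: variant edges are how dataset ops
    # connect to one another, so this walk also collects resources (e.g.
    # tables or variables captured by dataset functions) used by upstream
    # dataset ops.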
    op_queue.extend(t.op for t in op.inputs if t.dtype == dtypes.variant)
    return reads, writes

  op_queue = [op]
  seen_ops = set()
  all_reads = []
  all_writes = []
  while op_queue:
    reads, writes = _process(op_queue, seen_ops)
    all_reads.extend(reads)
    all_writes.extend(writes)

  return all_reads, all_writes


@auto_control_deps.register_acd_resource_resolver
def _resource_resolver(op, resource_reads, resource_writes):
  """Updates resource inputs for tf.data ops with indirect dependencies."""

  updated = False
  if op.type in [
      "DatasetToSingleElement", "DatasetToTFRecord", "ReduceDataset"
  ]:
    reads, writes = _collect_resource_inputs(op)
    for inp in reads:
      if inp not in resource_reads:
        updated = True
        resource_reads.add(inp)
    for inp in writes:
      if inp not in resource_writes:
        updated = True
        resource_writes.add(inp)

  if op.type in [
      "IteratorGetNext", "IteratorGetNextSync", "IteratorGetNextAsOptional"
  ]:
    iterator_resource = op.inputs[0]
    make_iterator_ops = [
        op for op in iterator_resource.consumers() if op.type == "MakeIterator"
    ]

    if len(make_iterator_ops) == 1:
      reads, writes = _collect_resource_inputs(make_iterator_ops[0])
      for inp in reads:
        if inp not in resource_reads:
          updated = True
          resource_reads.add(inp)
      for inp in writes:
        if inp not in resource_writes:
          updated = True
          resource_writes.add(inp)

  return updated


DEBUG_MODE = False


@tf_export("data.experimental.enable_debug_mode")
def enable_debug_mode():
  """Enables debug mode for tf.data.

  Example usage with pdb module:
  ```
  import tensorflow as tf
  import pdb

  tf.data.experimental.enable_debug_mode()

  def func(x):
    # Python 3.7 and older requires `pdb.Pdb(nosigint=True).set_trace()`
    pdb.set_trace()
    x = x + 1
    return x

  dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  dataset = dataset.map(func)

  for item in dataset:
    print(item)
  ```

  The effect of debug mode is two-fold:

  1) Any transformations that would introduce asynchrony, parallelism, or
  non-determinism to the input pipeline execution will be forced to execute
  synchronously, sequentially, and deterministically.

  2) Any user-defined functions passed into tf.data transformations such as
  `map` will be wrapped in `tf.py_function` so that their body is executed
  "eagerly" as a Python function as opposed to a traced TensorFlow graph, which
  is the default behavior. Note that even when debug mode is enabled, the
  user-defined function is still traced to infer the shape and type of its
  outputs; as a consequence, any `print` statements or breakpoints will be
  triggered once during the tracing before the actual execution of the input
  pipeline.

  NOTE: As the debug mode setting affects the construction of the tf.data input
  pipeline, it should be enabled before any tf.data definitions.

  Raises:
    ValueError: When invoked from graph mode.
  """
  if context.executing_eagerly():
    toggle_debug_mode(True)
  else:
    raise ValueError("`enable_debug_mode()` is only supported in eager mode.")


def toggle_debug_mode(debug_mode):
  global DEBUG_MODE
  DEBUG_MODE = debug_mode
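
# A minimal, illustrative sketch of how a few of the transformations defined
# above compose through their public `tf.data.Dataset` methods. The helper
# name `_example_pipeline` and the sample values are hypothetical and not part
# of the TensorFlow API; the function is never called by the library and only
# documents expected behavior under eager execution.
def _example_pipeline():
  """Sketches `unbatch`, `unique`, `take_while`, `scan`, `group_by_window`."""
  import tensorflow as tf  # pylint: disable=g-import-not-at-top

  # Optional: force synchronous, sequential execution and eager evaluation of
  # user-defined functions (see `enable_debug_mode` above). Must be called
  # before the pipeline is built.
  tf.data.experimental.enable_debug_mode()

  ds = tf.data.Dataset.from_tensor_slices([[1, 1], [2, 3], [2, 3]])
  ds = ds.unbatch()                    # 1, 1, 2, 3, 2, 3  (_UnbatchDataset)
  ds = ds.unique()                     # 1, 2, 3           (_UniqueDataset)
  ds = ds.take_while(lambda x: x < 3)  # 1, 2              (_TakeWhileDataset)
  # Running sum: the state starts at 0 and each step emits the updated sum.
  ds = ds.scan(0, lambda state, x: (state + x, state + x))  # 1, 3 (_ScanDataset)
  # Group by parity and emit each group as a single batch of up to two items.
  ds = ds.group_by_window(
      key_func=lambda x: tf.cast(x % 2, tf.int64),
      reduce_func=lambda key, window: window.batch(2),
      window_size=2)                   # [1, 3]       (_GroupByWindowDataset)
  return [element.numpy() for element in ds]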