# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental `dataset` API for parsing example."""
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util.tf_export import tf_export


class _ParseExampleDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that parses `example` dataset into a `dict` dataset.

  Each element of the input dataset must be compatible with a string vector
  (`TensorSpec([None], tf.string)`), i.e. a batch of serialized `Example`
  protos. Each output element is a `dict` keyed by the sparse, dense and
  ragged keys derived from `features`. `SparseFeature` and `RaggedFeature`
  entries are emitted in an intermediate form here; `parse_example_dataset`
  assembles them with a follow-up `map`.
  """

  def __init__(self, input_dataset, features, num_parallel_calls,
               deterministic):
    """Builds the parsing dataset on top of `input_dataset`.

    Args:
      input_dataset: A `Dataset` whose elements are string vectors of
        serialized `Example` protos.
      features: A `dict` mapping feature keys to `VarLenFeature`,
        `SparseFeature`, `FixedLenFeature`, `FixedLenSequenceFeature` or
        `RaggedFeature` values.
      num_parallel_calls: Number of parsing calls to run in parallel; passed
        straight to the `parse_example_dataset_v2` kernel.
      deterministic: `None`, `True` or `False`; translated to the string
        attribute `"default"`, `"true"` or `"false"` for the kernel.

    Raises:
      TypeError: if the elements of `input_dataset` are not compatible with a
        vector of strings.
    """
    self._input_dataset = input_dataset
    if not structure.are_compatible(
        input_dataset.element_spec,
        tensor_spec.TensorSpec([None], dtypes.string)):
      raise TypeError("Input dataset should be a dataset of vectors of "
                      f"strings. Instead it is `{input_dataset.element_spec}`.")
    self._num_parallel_calls = num_parallel_calls
    # The kernel takes a tri-state string attribute rather than an optional
    # boolean; "default" defers to the dataset's `deterministic` option.
    if deterministic is None:
      self._deterministic = "default"
    elif deterministic:
      self._deterministic = "true"
    else:
      self._deterministic = "false"
    # pylint: disable=protected-access
    # NOTE(review): presumably adjusts feature shapes to account for the
    # leading batch dimension of each input element -- confirm in parsing_ops.
    self._features = parsing_ops._prepend_none_dimension(features)
    # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature
    params = parsing_ops._ParseOpParams.from_features(self._features, [
        parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
        parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature,
        parsing_ops.RaggedFeature
    ])
    # pylint: enable=protected-access
    self._sparse_keys = params.sparse_keys
    self._sparse_types = params.sparse_types
    self._ragged_keys = params.ragged_keys
    self._ragged_value_types = params.ragged_value_types
    self._ragged_split_types = params.ragged_split_types
    self._dense_keys = params.dense_keys
    self._dense_defaults = params.dense_defaults_vec
    self._dense_shapes = params.dense_shapes_as_proto
    self._dense_types = params.dense_types
    input_dataset_shape = dataset_ops.get_legacy_output_shapes(
        self._input_dataset)

    self._element_spec = {}

    # Sparse outputs: the input batch shape plus one unknown-length dimension.
    for (key, value_type) in zip(params.sparse_keys, params.sparse_types):
      self._element_spec[key] = sparse_tensor.SparseTensorSpec(
          input_dataset_shape.concatenate([None]), value_type)

    # Dense outputs: the input batch shape followed by the feature's shape.
    for (key, value_type, dense_shape) in zip(params.dense_keys,
                                              params.dense_types,
                                              params.dense_shapes):
      self._element_spec[key] = tensor_spec.TensorSpec(
          input_dataset_shape.concatenate(dense_shape), value_type)

    # Ragged outputs: ragged_rank=1 over the input batch shape, with the
    # per-feature row-splits dtype.
    for (key, value_type, splits_type) in zip(params.ragged_keys,
                                              params.ragged_value_types,
                                              params.ragged_split_types):
      self._element_spec[key] = ragged_tensor.RaggedTensorSpec(
          input_dataset_shape.concatenate([None]), value_type, 1, splits_type)

    variant_tensor = (
        gen_experimental_dataset_ops.parse_example_dataset_v2(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._num_parallel_calls,
            self._dense_defaults,
            self._sparse_keys,
            self._dense_keys,
            self._sparse_types,
            self._dense_shapes,
            deterministic=self._deterministic,
            ragged_keys=self._ragged_keys,
            ragged_value_types=self._ragged_value_types,
            ragged_split_types=self._ragged_split_types,
            **self._flat_structure))
    super().__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    """The `dict` of type specs built in `__init__`, keyed by feature key."""
    return self._element_spec


# TODO(b/111553342): add arguments names and example names as well.
@tf_export("data.experimental.parse_example_dataset")
def parse_example_dataset(features, num_parallel_calls=1, deterministic=None):
  """A transformation that parses `Example` protos into a `dict` of tensors.

  The dataset this transformation is applied to must yield string vectors,
  each one a batch of serialized `Example` protos. For every such element the
  resulting dataset yields one `dict` of parsed values.

  This op parses serialized examples into a dictionary mapping keys to `Tensor`,
  `SparseTensor`, and `RaggedTensor` objects. `features` is a dict from keys to
  `VarLenFeature`, `RaggedFeature`, `SparseFeature`, `FixedLenFeature`, and
  `FixedLenSequenceFeature` objects. Each `VarLenFeature` and `SparseFeature`
  is mapped to a `SparseTensor`; each `RaggedFeature` is mapped to a
  `RaggedTensor`; and each `FixedLenFeature` and `FixedLenSequenceFeature` is
  mapped to a `Tensor`. See `tf.io.parse_example` for more details about
  feature dictionaries.

  Args:
    features: A `dict` mapping feature keys to `FixedLenFeature`,
      `FixedLenSequenceFeature`, `VarLenFeature`, `RaggedFeature`, and
      `SparseFeature` values.
    num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
      representing the number of parsing processes to call in parallel.
    deterministic: (Optional.) A boolean controlling whether determinism
      should be traded for performance by allowing elements to be produced out
      of order if some parsing calls complete faster than others. If
      `deterministic` is `None`, the
      `tf.data.Options.deterministic` dataset option (`True` by default) is used
      to decide whether to produce elements deterministically. The same setting
      is applied to the follow-up `map` that assembles `SparseFeature` and
      `RaggedFeature` values.

  Returns:
    A dataset transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: if features argument is None.
    TypeError: (when the returned function is applied) if the input dataset's
      elements are not vectors of strings.
  """
  if features is None:
    raise ValueError("Argument `features` is required, but not specified.")

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls,
                                       deterministic)
    # SparseFeature / RaggedFeature values come out of the kernel in an
    # intermediate form and need a second pass to be assembled.
    has_composite_features = any(
        isinstance(feature,
                   (parsing_ops.SparseFeature, parsing_ops.RaggedFeature))
        for feature in features.values())
    if has_composite_features:
      # pylint: disable=protected-access
      # pylint: disable=g-long-lambda
      out_dataset = out_dataset.map(
          lambda x: parsing_ops._construct_tensors_for_composite_features(
              features, x),
          num_parallel_calls=num_parallel_calls,
          # Honor the caller's determinism choice for the whole
          # transformation, not just the parsing kernel.
          deterministic=deterministic)
    return out_dataset

  return _apply_fn