# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for describing the structure of a `tf.data` type."""
import collections
import functools
import itertools

import wrapt

from tensorflow.python.data.util import nest
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


# pylint: disable=invalid-name
@tf_export(v1=["data.experimental.TensorStructure"])
@deprecation.deprecated(None, "Use `tf.TensorSpec` instead.")
def _TensorStructure(dtype, shape):
  """Returns a `TensorSpec` built from `shape` and `dtype`."""
  return tensor_spec.TensorSpec(shape, dtype)


@tf_export(v1=["data.experimental.SparseTensorStructure"])
@deprecation.deprecated(None, "Use `tf.SparseTensorSpec` instead.")
def _SparseTensorStructure(dtype, shape):
  """Returns a `SparseTensorSpec` built from `shape` and `dtype`."""
  return sparse_tensor.SparseTensorSpec(shape, dtype)


@tf_export(v1=["data.experimental.TensorArrayStructure"])
@deprecation.deprecated(None, "Use `tf.TensorArraySpec` instead.")
def _TensorArrayStructure(dtype, element_shape, dynamic_size, infer_shape):
  """Returns a `TensorArraySpec` built from the given arguments."""
  return tensor_array_ops.TensorArraySpec(element_shape, dtype,
                                          dynamic_size, infer_shape)


@tf_export(v1=["data.experimental.RaggedTensorStructure"])
@deprecation.deprecated(None, "Use `tf.RaggedTensorSpec` instead.")
def _RaggedTensorStructure(dtype, shape, ragged_rank):
  """Returns a `RaggedTensorSpec` built from the given arguments."""
  return ragged_tensor.RaggedTensorSpec(shape, dtype, ragged_rank)
# pylint: enable=invalid-name


# TODO(jsimsa): Remove the special-case for `TensorArray` pass-through once
# it is a subclass of `CompositeTensor`.
def normalize_element(element, element_signature=None):
  """Normalizes a nested structure of element components.

  * Components matching `SparseTensorSpec` are converted to `SparseTensor`.
  * Components matching `RaggedTensorSpec` are converted to `RaggedTensor`.
  * Components matching `VariableSpec` are converted to `Tensor`.
  * Components matching `DatasetSpec` or `TensorArraySpec` are passed through.
  * Components matching `NoneTensorSpec` are replaced by a `NoneTensor`.
  * `CompositeTensor` components are passed through.
  * All other components are converted to `Tensor`.

  Args:
    element: A nested structure of individual components.
    element_signature: (Optional.) A nested structure with one entry per
      component of `element`. Each entry is matched against the `TypeSpec`
      classes listed above; for components that fall through to the plain
      tensor conversion, the entry's `dtype` attribute (if it has one) is used
      as the dtype of the output tensor. When omitted, a spec is computed for
      each component with `type_spec_from_value(..., use_fallback=False)`.

  Returns:
    The normalized components, packed like `element_signature` if it was
    given and like `element` otherwise.
  """
  normalized_components = []
  if element_signature is None:
    components = nest.flatten(element)
    flattened_signature = [None] * len(components)
    pack_as = element
  else:
    flattened_signature = nest.flatten(element_signature)
    components = nest.flatten_up_to(element_signature, element)
    pack_as = element_signature
  with ops.name_scope("normalize_element"):
    # Imported here to avoid circular dependency.
    from tensorflow.python.data.ops import dataset_ops  # pylint: disable=g-import-not-at-top
    for i, (t, spec) in enumerate(zip(components, flattened_signature)):
      # Every conversion below names its output after the component index.
      name = f"component_{i}"
      try:
        if spec is None:
          spec = type_spec_from_value(t, use_fallback=False)
      except TypeError:
        # TypeError indicates it was not possible to compute a `TypeSpec` for
        # the value. As a fallback try converting the value to a tensor.
        normalized_components.append(ops.convert_to_tensor(t, name=name))
      else:
        if isinstance(spec, sparse_tensor.SparseTensorSpec):
          normalized_components.append(sparse_tensor.SparseTensor.from_value(t))
        elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
          normalized_components.append(
              ragged_tensor.convert_to_tensor_or_ragged_tensor(t, name=name))
        elif isinstance(
            spec, (tensor_array_ops.TensorArraySpec, dataset_ops.DatasetSpec)):
          normalized_components.append(t)
        elif isinstance(spec, NoneTensorSpec):
          normalized_components.append(NoneTensor())
        elif isinstance(spec, resource_variable_ops.VariableSpec):
          normalized_components.append(
              ops.convert_to_tensor(t, name=name, dtype=spec.dtype))
        elif isinstance(t, composite_tensor.CompositeTensor):
          normalized_components.append(t)
        else:
          # Not every spec exposes a `dtype`; `None` lets the conversion infer
          # it from the value.
          dtype = getattr(spec, "dtype", None)
          normalized_components.append(
              ops.convert_to_tensor(t, name=name, dtype=dtype))
  return nest.pack_sequence_as(pack_as, normalized_components)


def convert_legacy_structure(output_types, output_shapes, output_classes):
  """Returns a `Structure` that represents the given legacy structure.

  This method provides a way to convert from the existing `Dataset` and
  `Iterator` structure-related properties to a `Structure` object. A "legacy"
  structure is represented by the `tf.data.Dataset.output_types`,
  `tf.data.Dataset.output_shapes`, and `tf.data.Dataset.output_classes`
  properties.

  TODO(b/110122868): Remove this function once `Structure` is used throughout
  `tf.data`.

  Args:
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of a structured value.
    output_shapes: A nested structure of `tf.TensorShape` objects
      corresponding to each component of a structured value.
    output_classes: A nested structure of Python `type` objects corresponding
      to each component of a structured value. Entries that are already
      `TypeSpec` instances are used as-is.

  Returns:
    A `Structure`, packed like `output_classes`.

  Raises:
    TypeError: If a structure cannot be built from the arguments, because one of
      the component classes in `output_classes` is not supported.
  """
  flat_types = nest.flatten(output_types)
  flat_shapes = nest.flatten(output_shapes)
  flat_classes = nest.flatten(output_classes)
  flat_ret = []
  for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes,
                                               flat_classes):
    if isinstance(flat_class, type_spec.TypeSpec):
      flat_ret.append(flat_class)
      continue
    # `issubclass` rejects non-class arguments with an opaque "arg 1 must be a
    # class" TypeError, so such entries are routed to the descriptive error
    # at the end of the chain instead.
    is_class = isinstance(flat_class, type)
    if is_class and issubclass(flat_class, sparse_tensor.SparseTensor):
      flat_ret.append(sparse_tensor.SparseTensorSpec(flat_shape, flat_type))
    elif is_class and issubclass(flat_class, ops.Tensor):
      flat_ret.append(tensor_spec.TensorSpec(flat_shape, flat_type))
    elif is_class and issubclass(flat_class, tensor_array_ops.TensorArray):
      # We sneaked the dynamic_size and infer_shape into the legacy shape.
      flat_ret.append(
          tensor_array_ops.TensorArraySpec(
              flat_shape[2:], flat_type,
              dynamic_size=tensor_shape.dimension_value(flat_shape[0]),
              infer_shape=tensor_shape.dimension_value(flat_shape[1])))
    else:
      # NOTE(mrry): Since legacy structures produced by iterators only
      # comprise Tensors, SparseTensors, and nests, we do not need to
      # support all structure types here.
      raise TypeError(
          "Could not build a structure for output class {}. Make sure any "
          "component class in `output_classes` inherits from one of the "
          "following classes: `tf.TypeSpec`, `tf.sparse.SparseTensor`, "
          "`tf.Tensor`, `tf.TensorArray`.".format(
              getattr(flat_class, "__name__", repr(flat_class))))

  return nest.pack_sequence_as(output_classes, flat_ret)


def _from_tensor_list_helper(decode_fn, element_spec, tensor_list):
  """Returns an element constructed from the given spec and tensor list.

  Args:
    decode_fn: Method that constructs an element component from the element spec
      component and a tensor list.
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    tensor_list: A list of tensors to use for constructing the value.

  Returns:
    An element constructed from the given spec and tensor list.

  Raises:
    ValueError: If the number of tensors needed to construct an element for
      the given spec does not match the given number of tensors.
  """

  # pylint: disable=protected-access

  flat_specs = nest.flatten(element_spec)
  flat_spec_lengths = [len(spec._flat_tensor_specs) for spec in flat_specs]
  expected_num_tensors = sum(flat_spec_lengths)
  if expected_num_tensors != len(tensor_list):
    raise ValueError("Expected {} tensors but got {}.".format(
        expected_num_tensors, len(tensor_list)))

  # Hand each component spec the consecutive slice of `tensor_list` it owns.
  i = 0
  flat_ret = []
  for component_spec, num_flat_values in zip(flat_specs, flat_spec_lengths):
    value = tensor_list[i:i + num_flat_values]
    flat_ret.append(decode_fn(component_spec, value))
    i += num_flat_values
  return nest.pack_sequence_as(element_spec, flat_ret)


def from_compatible_tensor_list(element_spec, tensor_list):
  """Returns an element constructed from the given spec and tensor list.

  Each component is decoded with its spec's `_from_compatible_tensor_list`.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    tensor_list: A list of tensors to use for constructing the value.

  Returns:
    An element constructed from the given spec and tensor list.

  Raises:
    ValueError: If the number of tensors needed to construct an element for
      the given spec does not match the given number of tensors.
  """

  # pylint: disable=protected-access
  # pylint: disable=g-long-lambda
  return _from_tensor_list_helper(
      lambda spec, value: spec._from_compatible_tensor_list(value),
      element_spec, tensor_list)


def from_tensor_list(element_spec, tensor_list):
  """Returns an element constructed from the given spec and tensor list.

  Each component is decoded with its spec's `_from_tensor_list`.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    tensor_list: A list of tensors to use for constructing the value.

  Returns:
    An element constructed from the given spec and tensor list.

  Raises:
    ValueError: If the number of tensors needed to construct an element for
      the given spec does not match the given number of tensors or the given
      spec is not compatible with the tensor list.
  """

  # pylint: disable=protected-access
  # pylint: disable=g-long-lambda
  return _from_tensor_list_helper(
      lambda spec, value: spec._from_tensor_list(value), element_spec,
      tensor_list)


def get_flat_tensor_specs(element_spec):
  """Returns a list of `tf.TypeSpec`s for the element tensor representation.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.

  Returns:
    A list of `tf.TypeSpec`s for the element tensor representation.
  """

  # pylint: disable=protected-access
  return list(
      itertools.chain.from_iterable(
          spec._flat_tensor_specs for spec in nest.flatten(element_spec)))


def get_flat_tensor_shapes(element_spec):
  """Returns a list of `tf.TensorShape`s for the element tensor representation.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.

  Returns:
    A list of `tf.TensorShape`s for the element tensor representation.
  """
  return [spec.shape for spec in get_flat_tensor_specs(element_spec)]


def get_flat_tensor_types(element_spec):
  """Returns a list of `tf.DType`s for the element tensor representation.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.

  Returns:
    A list of `tf.DType`s for the element tensor representation.
  """
  return [spec.dtype for spec in get_flat_tensor_specs(element_spec)]


def _to_tensor_list_helper(encode_fn, element_spec, element):
  """Returns a tensor list representation of the element.

  Args:
    encode_fn: Method that constructs a tensor list representation from the
      given element spec and element. It is called as
      `encode_fn(state, spec, component)`, where `state` is the list
      accumulated so far, and must return the new accumulated list.
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    element: The element to convert to tensor list representation.

  Returns:
    A tensor list representation of `element`.

  Raises:
    ValueError: If `element_spec` and `element` do not have the same number of
      elements or if the two structures are not nested in the same way.
    TypeError: If `element_spec` and `element` differ in the type of sequence
      in any of their substructures.
  """

  nest.assert_same_structure(element_spec, element)

  def reduce_fn(state, value):
    spec, component = value
    return encode_fn(state, spec, component)

  return functools.reduce(
      reduce_fn, zip(nest.flatten(element_spec), nest.flatten(element)), [])


def to_batched_tensor_list(element_spec, element):
  """Returns a tensor list representation of the element.

  Each component is encoded with its spec's `_to_batched_tensor_list`.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    element: The element to convert to tensor list representation.

  Returns:
    A tensor list representation of `element`.

  Raises:
    ValueError: If `element_spec` and `element` do not have the same number of
      elements or if the two structures are not nested in the same way or the
      rank of any of the tensors in the tensor list representation is 0.
    TypeError: If `element_spec` and `element` differ in the type of sequence
      in any of their substructures.
  """

  # pylint: disable=protected-access
  # pylint: disable=g-long-lambda
  return _to_tensor_list_helper(
      lambda state, spec, component: state + spec._to_batched_tensor_list(
          component), element_spec, element)


def to_tensor_list(element_spec, element):
  """Returns a tensor list representation of the element.

  Each component is encoded with its spec's `_to_tensor_list`.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    element: The element to convert to tensor list representation.

  Returns:
    A tensor list representation of `element`.

  Raises:
    ValueError: If `element_spec` and `element` do not have the same number of
      elements or if the two structures are not nested in the same way.
    TypeError: If `element_spec` and `element` differ in the type of sequence
      in any of their substructures.
  """

  # pylint: disable=protected-access
  # pylint: disable=g-long-lambda
  return _to_tensor_list_helper(
      lambda state, spec, component: state + spec._to_tensor_list(component),
      element_spec, element)


def are_compatible(spec1, spec2):
  """Indicates whether two type specifications are compatible.

  Two type specifications are compatible if they have the same nested structure
  and their individual components are pair-wise compatible (checked in both
  directions with `is_compatible_with`).

  Args:
    spec1: A `tf.TypeSpec` object to compare.
    spec2: A `tf.TypeSpec` object to compare.

  Returns:
    `True` if the two type specifications are compatible and `False` otherwise.
  """

  try:
    nest.assert_same_structure(spec1, spec2)
  except (TypeError, ValueError):
    # Either error means the two nests are not structured alike.
    return False

  return all(
      s1.is_compatible_with(s2) and s2.is_compatible_with(s1)
      for s1, s2 in zip(nest.flatten(spec1), nest.flatten(spec2)))


def type_spec_from_value(element, use_fallback=True):
  """Creates a type specification for the given value.

  Mappings, tuples (including namedtuples) and `attr.s` decorated classes are
  handled recursively; the recursive calls use the default `use_fallback`.

  Args:
    element: The element to create the type specification for.
    use_fallback: Whether to fall back to converting the element to a tensor
      in order to compute its `TypeSpec`.

  Returns:
    A nested structure of `TypeSpec`s that represents the type specification
    of `element`.

  Raises:
    TypeError: If a `TypeSpec` cannot be built for `element`, because its type
      is not supported.
  """
  spec = type_spec._type_spec_from_value(element)  # pylint: disable=protected-access
  if spec is not None:
    return spec

  if isinstance(element, collections_abc.Mapping):
    # We create a shallow copy in an attempt to preserve the key order.
    #
    # Note that we do not guarantee that the key order is preserved, which is
    # a limitation inherited from `copy()`. As a consequence, callers of
    # `type_spec_from_value` should not assume that the key order of a `dict`
    # in the returned nested structure matches the key order of the
    # corresponding `dict` in the input value.
    if isinstance(element, collections.defaultdict):
      # `defaultdict` takes its factory as the first constructor argument.
      ctor = lambda items: type(element)(element.default_factory, items)
    else:
      ctor = type(element)
    return ctor([(k, type_spec_from_value(v)) for k, v in element.items()])

  if isinstance(element, tuple):
    if hasattr(element, "_fields") and isinstance(
        element._fields, collections_abc.Sequence) and all(
            isinstance(f, str) for f in element._fields):
      if isinstance(element, wrapt.ObjectProxy):
        element_type = type(element.__wrapped__)
      else:
        element_type = type(element)
      # `element` is a namedtuple
      return element_type(*[type_spec_from_value(v) for v in element])
    # `element` is not a namedtuple
    return tuple([type_spec_from_value(v) for v in element])

  if hasattr(element.__class__, "__attrs_attrs__"):
    # `element` is an `attr.s` decorated class
    attrs = getattr(element.__class__, "__attrs_attrs__")
    return type(element)(*[
        type_spec_from_value(getattr(element, a.name)) for a in attrs
    ])

  if use_fallback:
    # As a fallback try converting the element to a tensor.
    try:
      tensor = ops.convert_to_tensor(element)
      spec = type_spec_from_value(tensor)
      if spec is not None:
        return spec
    except (ValueError, TypeError) as e:
      # Pass the arguments lazily so the message is only formatted when this
      # verbosity level is actually emitted.
      logging.vlog(3, "Failed to convert %r to tensor: %s",
                   type(element).__name__, e)

  raise TypeError("Could not build a `TypeSpec` for {} with type {}".format(
      element,
      type(element).__name__))


# TODO(b/149584798): Move this to framework and add tests for non-tf.data
# functionality.
class NoneTensor(composite_tensor.CompositeTensor):
  """Composite tensor representation for `None` value."""

  @property
  def _type_spec(self):
    return NoneTensorSpec()


# TODO(b/149584798): Move this to framework and add tests for non-tf.data
# functionality.
@type_spec.register("tf.NoneTensorSpec")
class NoneTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for `None` value.

  The spec carries no state: it serializes to an empty tuple, has no component
  specs, and encodes its value as an empty tensor list.
  """

  # -- Construction ----------------------------------------------------------

  @staticmethod
  def from_value(value):
    """Returns a fresh `NoneTensorSpec`; `value` is not inspected."""
    return NoneTensorSpec()

  # -- Core `TypeSpec` contract ----------------------------------------------

  @property
  def value_type(self):
    """The value class associated with this spec (`NoneTensor`)."""
    return NoneTensor

  def _serialize(self):
    """Returns an empty tuple."""
    return ()

  @property
  def _component_specs(self):
    """Returns an empty list."""
    return []

  def _to_components(self, value):
    """Returns an empty list; `value` is not inspected."""
    return []

  def _from_components(self, components):
    """Returns `None`; `components` is not inspected."""
    return None

  # -- Tensor-list encoding --------------------------------------------------

  def _to_tensor_list(self, value):
    """Returns an empty list; `value` is not inspected."""
    return []

  def _to_batched_tensor_list(self, value):
    """Returns an empty list; `value` is not inspected."""
    return []

  # -- Batching --------------------------------------------------------------

  def _batch(self, batch_size):
    """Returns a fresh `NoneTensorSpec`; `batch_size` is not used."""
    return NoneTensorSpec()

  def _unbatch(self):
    """Returns a fresh `NoneTensorSpec`."""
    return NoneTensorSpec()

  # -- Legacy structure accessors: each returns the spec itself --------------

  def _to_legacy_output_types(self):
    return self

  def _to_legacy_output_shapes(self):
    return self

  def _to_legacy_output_classes(self):
    return self

  def most_specific_compatible_shape(self, other):
    """Returns `self` when `other` has exactly the same type.

    Raises:
      ValueError: If `other` is of a different type than `self`.
    """
    if type(self) is type(other):
      return self
    raise ValueError("No `TypeSpec` is compatible with both {} and {}".format(
        self, other))


# Let `None` values resolve to `NoneTensorSpec` when a spec is computed from a
# value.
type_spec.register_type_spec_from_value_converter(
    type(None), NoneTensorSpec.from_value)