xref: /aosp_15_r20/external/tensorflow/tensorflow/python/data/util/structure.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for describing the structure of a `tf.data` type."""
16import collections
17import functools
18import itertools
19
20import wrapt
21
22from tensorflow.python.data.util import nest
23from tensorflow.python.framework import composite_tensor
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import sparse_tensor
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.framework import tensor_spec
28from tensorflow.python.framework import type_spec
29from tensorflow.python.ops import resource_variable_ops
30from tensorflow.python.ops import tensor_array_ops
31from tensorflow.python.ops.ragged import ragged_tensor
32from tensorflow.python.platform import tf_logging as logging
33from tensorflow.python.util import deprecation
34from tensorflow.python.util.compat import collections_abc
35from tensorflow.python.util.tf_export import tf_export
36
37
38# pylint: disable=invalid-name
@tf_export(v1=["data.experimental.TensorStructure"])
@deprecation.deprecated(None, "Use `tf.TensorSpec` instead.")
def _TensorStructure(dtype, shape):
  """Builds a `TensorSpec` from arguments given in `(dtype, shape)` order.

  Args:
    dtype: Forwarded as the second positional argument of `TensorSpec`.
    shape: Forwarded as the first positional argument of `TensorSpec`.

  Returns:
    The `tensor_spec.TensorSpec` constructed from `shape` and `dtype`.
  """
  # Note the swapped argument order relative to this function's signature.
  spec = tensor_spec.TensorSpec(shape, dtype)
  return spec
43
44
@tf_export(v1=["data.experimental.SparseTensorStructure"])
@deprecation.deprecated(None, "Use `tf.SparseTensorSpec` instead.")
def _SparseTensorStructure(dtype, shape):
  """Builds a `SparseTensorSpec` from arguments given in `(dtype, shape)` order.

  Args:
    dtype: Forwarded as the second positional argument of `SparseTensorSpec`.
    shape: Forwarded as the first positional argument of `SparseTensorSpec`.

  Returns:
    The `sparse_tensor.SparseTensorSpec` constructed from `shape` and `dtype`.
  """
  # Note the swapped argument order relative to this function's signature.
  spec = sparse_tensor.SparseTensorSpec(shape, dtype)
  return spec
49
50
@tf_export(v1=["data.experimental.TensorArrayStructure"])
@deprecation.deprecated(None, "Use `tf.TensorArraySpec` instead.")
def _TensorArrayStructure(dtype, element_shape, dynamic_size, infer_shape):
  """Builds a `TensorArraySpec` from the legacy argument order.

  Args:
    dtype: Forwarded as the second positional argument of `TensorArraySpec`.
    element_shape: Forwarded as the first positional argument.
    dynamic_size: Forwarded as the third positional argument.
    infer_shape: Forwarded as the fourth positional argument.

  Returns:
    The `tensor_array_ops.TensorArraySpec` constructed from the arguments.
  """
  # `element_shape` and `dtype` are swapped relative to this signature.
  spec = tensor_array_ops.TensorArraySpec(
      element_shape, dtype, dynamic_size, infer_shape)
  return spec
56
57
@tf_export(v1=["data.experimental.RaggedTensorStructure"])
@deprecation.deprecated(None, "Use `tf.RaggedTensorSpec` instead.")
def _RaggedTensorStructure(dtype, shape, ragged_rank):
  """Builds a `RaggedTensorSpec` from the legacy argument order.

  Args:
    dtype: Forwarded as the second positional argument of `RaggedTensorSpec`.
    shape: Forwarded as the first positional argument of `RaggedTensorSpec`.
    ragged_rank: Forwarded as the third positional argument.

  Returns:
    The `ragged_tensor.RaggedTensorSpec` constructed from the arguments.
  """
  # `shape` and `dtype` are swapped relative to this function's signature.
  spec = ragged_tensor.RaggedTensorSpec(shape, dtype, ragged_rank)
  return spec
62# pylint: enable=invalid-name
63
64
65# TODO(jsimsa): Remove the special-case for `TensorArray` pass-through once
66# it is a subclass of `CompositeTensor`.
def normalize_element(element, element_signature=None):
  """Converts every component of a nested element to a normalized form.

  Each flattened component is handled according to its spec, which is taken
  from `element_signature` when one is given and is otherwise computed from
  the component value:

  * `SparseTensorSpec`: converted with `SparseTensor.from_value`.
  * `RaggedTensorSpec`: converted with `convert_to_tensor_or_ragged_tensor`.
  * `TensorArraySpec` or `DatasetSpec`: the component is passed through.
  * `NoneTensorSpec`: replaced by a new `NoneTensor`.
  * `VariableSpec`: converted to a `Tensor` using the spec's dtype.
  * Any other component that is a `CompositeTensor`: passed through.
  * Everything else: converted to a `Tensor`, using the spec's `dtype`
    attribute when it has one.

  If no spec can be computed for a component (`type_spec_from_value` raises
  `TypeError`), the component is converted to a `Tensor` directly.

  Args:
    element: A nested structure of individual components.
    element_signature: (Optional.) A nested structure of per-component
      specifications corresponding to `element`. The body dispatches on
      `TypeSpec` subclasses and reads an optional `dtype` attribute from each
      entry. When given, it also determines how `element` is flattened and how
      the result is packed. Entries that are `None` are treated as missing and
      the spec is computed from the value instead.

  Returns:
    The normalized components, packed like `element_signature` if it was
    given and like `element` otherwise.
  """
  if element_signature is not None:
    structure = element_signature
    flat_specs = nest.flatten(element_signature)
    flat_values = nest.flatten_up_to(element_signature, element)
  else:
    structure = element
    flat_values = nest.flatten(element)
    flat_specs = [None] * len(flat_values)

  results = []
  with ops.name_scope("normalize_element"):
    # Imported here to avoid circular dependency.
    from tensorflow.python.data.ops import dataset_ops  # pylint: disable=g-import-not-at-top
    for index, (value, spec) in enumerate(zip(flat_values, flat_specs)):
      name = "component_%d" % index

      if spec is None:
        try:
          spec = type_spec_from_value(value, use_fallback=False)
        except TypeError:
          # No `TypeSpec` could be computed for the value; fall back to a
          # plain tensor conversion and move on to the next component.
          results.append(ops.convert_to_tensor(value, name=name))
          continue

      if isinstance(spec, sparse_tensor.SparseTensorSpec):
        converted = sparse_tensor.SparseTensor.from_value(value)
      elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
        converted = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            value, name=name)
      elif isinstance(
          spec, (tensor_array_ops.TensorArraySpec, dataset_ops.DatasetSpec)):
        converted = value
      elif isinstance(spec, NoneTensorSpec):
        converted = NoneTensor()
      elif isinstance(spec, resource_variable_ops.VariableSpec):
        converted = ops.convert_to_tensor(value, name=name, dtype=spec.dtype)
      elif isinstance(value, composite_tensor.CompositeTensor):
        converted = value
      else:
        # Not every spec carries a dtype; `None` lets the conversion infer it.
        converted = ops.convert_to_tensor(
            value, name=name, dtype=getattr(spec, "dtype", None))
      results.append(converted)

  return nest.pack_sequence_as(structure, results)
132
133
def convert_legacy_structure(output_types, output_shapes, output_classes):
  """Builds a nested structure of `TypeSpec`s from a legacy structure.

  A "legacy" structure is the triple of `tf.data.Dataset.output_types`,
  `tf.data.Dataset.output_shapes`, and `tf.data.Dataset.output_classes`
  properties. This function converts such a triple into the equivalent
  `Structure` object.

  TODO(b/110122868): Remove this function once `Structure` is used throughout
  `tf.data`.

  Args:
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of a structured value.
    output_shapes: A nested structure of `tf.TensorShape` objects
      corresponding to each component of a structured value.
    output_classes: A nested structure of Python `type` objects (or `TypeSpec`
      instances, which are used as-is) corresponding to each component of a
      structured value. The result is packed like this argument.

  Returns:
    A `Structure`.

  Raises:
    TypeError: If a structure cannot be built from the arguments, because one of
      the component classes in `output_classes` is not supported.
  """

  def component_spec(dtype, shape, cls):
    """Returns the spec for one flattened (dtype, shape, class) triple."""
    if isinstance(cls, type_spec.TypeSpec):
      # Already a spec; nothing to convert.
      return cls
    if issubclass(cls, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensorSpec(shape, dtype)
    if issubclass(cls, ops.Tensor):
      return tensor_spec.TensorSpec(shape, dtype)
    if issubclass(cls, tensor_array_ops.TensorArray):
      # We sneaked the dynamic_size and infer_shape into the legacy shape:
      # they occupy the first two dimensions, the rest is the element shape.
      return tensor_array_ops.TensorArraySpec(
          shape[2:], dtype,
          dynamic_size=tensor_shape.dimension_value(shape[0]),
          infer_shape=tensor_shape.dimension_value(shape[1]))
    # NOTE(mrry): Since legacy structures produced by iterators only
    # comprise Tensors, SparseTensors, and nests, we do not need to
    # support all structure types here.
    raise TypeError(
        "Could not build a structure for output class {}. Make sure any "
        "component class in `output_classes` inherits from one of the "
        "following classes: `tf.TypeSpec`, `tf.sparse.SparseTensor`, "
        "`tf.Tensor`, `tf.TensorArray`.".format(cls.__name__))

  triples = zip(
      nest.flatten(output_types),
      nest.flatten(output_shapes),
      nest.flatten(output_classes))
  specs = [component_spec(dtype, shape, cls) for dtype, shape, cls in triples]
  return nest.pack_sequence_as(output_classes, specs)
191
192
def _from_tensor_list_helper(decode_fn, element_spec, tensor_list):
  """Decodes a flat tensor list into an element shaped like `element_spec`.

  The tensor list is cut into consecutive slices, one per flattened component
  of `element_spec`, where each slice has as many tensors as that component's
  `_flat_tensor_specs`. Each slice is decoded with `decode_fn`.

  Args:
    decode_fn: Callable taking `(component_spec, tensors)` and returning the
      decoded element component.
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    tensor_list: A list of tensors to use for constructing the value.

  Returns:
    An element constructed from the given spec and tensor list.

  Raises:
    ValueError: If the number of tensors needed to construct an element for
      the given spec does not match the given number of tensors.
  """

  # pylint: disable=protected-access

  component_specs = nest.flatten(element_spec)
  widths = [len(spec._flat_tensor_specs) for spec in component_specs]

  expected = sum(widths)
  actual = len(tensor_list)
  if expected != actual:
    raise ValueError("Expected {} tensors but got {}.".format(
        expected, actual))

  decoded = []
  start = 0
  for component_spec, width in zip(component_specs, widths):
    stop = start + width
    decoded.append(decode_fn(component_spec, tensor_list[start:stop]))
    start = stop
  return nest.pack_sequence_as(element_spec, decoded)
226
227
def from_compatible_tensor_list(element_spec, tensor_list):
  """Decodes a tensor list using each spec's `_from_compatible_tensor_list`.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    tensor_list: A list of tensors to use for constructing the value.

  Returns:
    An element constructed from the given spec and tensor list.

  Raises:
    ValueError: If the number of tensors needed to construct an element for
      the given spec does not match the given number of tensors.
  """

  def decode(spec, tensors):
    # pylint: disable=protected-access
    return spec._from_compatible_tensor_list(tensors)

  return _from_tensor_list_helper(decode, element_spec, tensor_list)
249
250
def from_tensor_list(element_spec, tensor_list):
  """Decodes a tensor list using each spec's `_from_tensor_list`.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    tensor_list: A list of tensors to use for constructing the value.

  Returns:
    An element constructed from the given spec and tensor list.

  Raises:
    ValueError: If the number of tensors needed to construct an element for
      the given spec does not match the given number of tensors or the given
      spec is not compatible with the tensor list.
  """

  def decode(spec, tensors):
    # pylint: disable=protected-access
    return spec._from_tensor_list(tensors)

  return _from_tensor_list_helper(decode, element_spec, tensor_list)
273
274
def get_flat_tensor_specs(element_spec):
  """Returns a list of `tf.TypeSpec`s for the element tensor representation.

  The result is the concatenation, in flattened order, of the
  `_flat_tensor_specs` of every component of `element_spec`.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.

  Returns:
    A list of `tf.TypeSpec`s for the element tensor representation.
  """

  # pylint: disable=protected-access
  flat_tensor_specs = []
  for component_spec in nest.flatten(element_spec):
    flat_tensor_specs.extend(component_spec._flat_tensor_specs)
  return flat_tensor_specs
290
291
def get_flat_tensor_shapes(element_spec):
  """Returns a list of `tf.TensorShape`s for the element tensor representation.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.

  Returns:
    A list with the `shape` of every flat tensor spec of `element_spec`, in
    the order produced by `get_flat_tensor_specs`.
  """
  shapes = []
  for flat_spec in get_flat_tensor_specs(element_spec):
    shapes.append(flat_spec.shape)
  return shapes
303
304
def get_flat_tensor_types(element_spec):
  """Returns a list of `tf.DType`s for the element tensor representation.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.

  Returns:
    A list with the `dtype` of every flat tensor spec of `element_spec`, in
    the order produced by `get_flat_tensor_specs`.
  """
  dtypes = []
  for flat_spec in get_flat_tensor_specs(element_spec):
    dtypes.append(flat_spec.dtype)
  return dtypes
316
317
def _to_tensor_list_helper(encode_fn, element_spec, element):
  """Folds an element into a tensor list, one component at a time.

  Starting from an empty list, `encode_fn` is applied to every
  `(spec, component)` pair of the flattened `element_spec` and `element`; each
  call receives the list built so far and returns the next one.

  Args:
    encode_fn: Callable taking `(state, spec, component)`, where `state` is
      the tensor list accumulated so far, and returning the updated list.
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    element: The element to convert to tensor list representation.

  Returns:
    A tensor list representation of `element`.

  Raises:
    ValueError: If `element_spec` and `element` do not have the same number of
      elements or if the two structures are not nested in the same way.
    TypeError: If `element_spec` and `element` differ in the type of sequence
      in any of their substructures.
  """

  nest.assert_same_structure(element_spec, element)

  flat_specs = nest.flatten(element_spec)
  flat_components = nest.flatten(element)

  tensor_list = []
  for spec, component in zip(flat_specs, flat_components):
    tensor_list = encode_fn(tensor_list, spec, component)
  return tensor_list
346
347
def to_batched_tensor_list(element_spec, element):
  """Encodes an element using each spec's `_to_batched_tensor_list`.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    element: The element to convert to tensor list representation.

  Returns:
    A tensor list representation of `element`.

  Raises:
    ValueError: If `element_spec` and `element` do not have the same number of
      elements or if the two structures are not nested in the same way or the
      rank of any of the tensors in the tensor list representation is 0.
    TypeError: If `element_spec` and `element` differ in the type of sequence
      in any of their substructures.
  """

  def encode(state, spec, component):
    # Concatenation builds a new list, so `state` itself is never mutated.
    # pylint: disable=protected-access
    return state + spec._to_batched_tensor_list(component)

  return _to_tensor_list_helper(encode, element_spec, element)
372
373
def to_tensor_list(element_spec, element):
  """Encodes an element using each spec's `_to_tensor_list`.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    element: The element to convert to tensor list representation.

  Returns:
    A tensor list representation of `element`.

  Raises:
    ValueError: If `element_spec` and `element` do not have the same number of
      elements or if the two structures are not nested in the same way.
    TypeError: If `element_spec` and `element` differ in the type of sequence
      in any of their substructures.
  """

  def encode(state, spec, component):
    # Concatenation builds a new list, so `state` itself is never mutated.
    # pylint: disable=protected-access
    return state + spec._to_tensor_list(component)

  return _to_tensor_list_helper(encode, element_spec, element)
397
398
def are_compatible(spec1, spec2):
  """Indicates whether two type specifications are compatible.

  Two type specifications are compatible if they have the same nested structure
  and their individual components are pair-wise compatible in both
  directions.

  Args:
    spec1: A `tf.TypeSpec` object to compare.
    spec2: A `tf.TypeSpec` object to compare.

  Returns:
    `True` if the two type specifications are compatible and `False` otherwise.
  """

  try:
    nest.assert_same_structure(spec1, spec2)
  except (TypeError, ValueError):
    # A structural mismatch means the specs cannot be compatible.
    return False

  component_pairs = zip(nest.flatten(spec1), nest.flatten(spec2))
  return all(
      left.is_compatible_with(right) and right.is_compatible_with(left)
      for left, right in component_pairs)
424
425
def type_spec_from_value(element, use_fallback=True):
  """Creates a type specification for the given value.

  The value is first offered to `type_spec._type_spec_from_value`. If that
  yields no spec, mappings, tuples (including namedtuples) and `attr.s`
  decorated instances are rebuilt as containers of the same type holding the
  specs of their members. As a last resort, and only if `use_fallback` is
  true, the value is converted to a tensor and the spec of that tensor is
  returned.

  Args:
    element: The element to create the type specification for.
    use_fallback: Whether to fall back to converting the element to a tensor
      in order to compute its `TypeSpec`. This flag applies to `element`
      itself; members of containers are processed with the default value.

  Returns:
    A nested structure of `TypeSpec`s that represents the type specification
    of `element`.

  Raises:
    TypeError: If a `TypeSpec` cannot be built for `element`, because its type
      is not supported.
  """
  direct_spec = type_spec._type_spec_from_value(element)  # pylint: disable=protected-access
  if direct_spec is not None:
    return direct_spec

  if isinstance(element, collections_abc.Mapping):
    # Rebuild a mapping of the same type from `(key, spec)` pairs, listed in
    # the iteration order of the input. Callers should not assume that the key
    # order of the returned mapping matches the key order of the input: that
    # depends on how the mapping type orders the items it is constructed from.
    item_specs = [
        (key, type_spec_from_value(value)) for key, value in element.items()
    ]
    mapping_type = type(element)
    if isinstance(element, collections.defaultdict):
      # `defaultdict` takes its factory as the first constructor argument.
      return mapping_type(element.default_factory, item_specs)
    return mapping_type(item_specs)

  if isinstance(element, tuple):
    field_names = getattr(element, "_fields", None)
    is_namedtuple = (
        isinstance(field_names, collections_abc.Sequence) and
        all(isinstance(name, str) for name in field_names))
    if not is_namedtuple:
      return tuple([type_spec_from_value(item) for item in element])
    # For a `wrapt` proxy, construct the type of the wrapped namedtuple rather
    # than the proxy type.
    if isinstance(element, wrapt.ObjectProxy):
      namedtuple_type = type(element.__wrapped__)
    else:
      namedtuple_type = type(element)
    return namedtuple_type(*[type_spec_from_value(item) for item in element])

  if hasattr(element.__class__, "__attrs_attrs__"):
    # `element` is an `attr.s` decorated class; rebuild it from the specs of
    # its attributes, passed positionally in declaration order.
    attribute_specs = [
        type_spec_from_value(getattr(element, attribute.name))
        for attribute in element.__class__.__attrs_attrs__
    ]
    return type(element)(*attribute_specs)

  if use_fallback:
    # As a fallback try converting the element to a tensor.
    try:
      fallback_spec = type_spec_from_value(ops.convert_to_tensor(element))
    except (ValueError, TypeError) as e:
      logging.vlog(
          3, "Failed to convert %r to tensor: %s" % (type(element).__name__, e))
    else:
      if fallback_spec is not None:
        return fallback_spec

  raise TypeError(
      f"Could not build a `TypeSpec` for {element} "
      f"with type {type(element).__name__}")
494
495
496# TODO(b/149584798): Move this to framework and add tests for non-tf.data
497# functionality.
class NoneTensor(composite_tensor.CompositeTensor):
  """Composite tensor representation for `None` value."""

  @property
  def _type_spec(self):
    """Returns a new `NoneTensorSpec` describing this value."""
    spec = NoneTensorSpec()
    return spec
504
505
506# TODO(b/149584798): Move this to framework and add tests for non-tf.data
507# functionality.
@type_spec.register("tf.NoneTensorSpec")
class NoneTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for `None` value.

  The spec carries no state: it serializes to an empty tuple, has no
  component specs, and encodes values as empty tensor lists.
  """

  @property
  def value_type(self):
    """The Python type described by this spec, i.e. `NoneTensor`."""
    return NoneTensor

  def _serialize(self):
    """Returns an empty tuple."""
    return ()

  @property
  def _component_specs(self):
    """Returns an empty list."""
    return []

  def _to_components(self, value):
    """Returns an empty list; `value` is not inspected."""
    return []

  def _from_components(self, components):
    """Returns `None`; `components` is not inspected."""
    return None

  def _to_tensor_list(self, value):
    """Returns an empty list; `value` is not inspected."""
    return []

  @staticmethod
  def from_value(value):
    """Returns a new `NoneTensorSpec`; `value` is not inspected."""
    return NoneTensorSpec()

  def _batch(self, batch_size):
    """Returns a new `NoneTensorSpec`; `batch_size` is not used."""
    return NoneTensorSpec()

  def _unbatch(self):
    """Returns a new `NoneTensorSpec`."""
    return NoneTensorSpec()

  def _to_batched_tensor_list(self, value):
    """Returns an empty list; `value` is not inspected."""
    return []

  def _to_legacy_output_types(self):
    """Returns this spec itself as its legacy output type."""
    return self

  def _to_legacy_output_shapes(self):
    """Returns this spec itself as its legacy output shape."""
    return self

  def _to_legacy_output_classes(self):
    """Returns this spec itself as its legacy output class."""
    return self

  def most_specific_compatible_shape(self, other):
    """Returns `self` if `other` has exactly the same type as `self`.

    Args:
      other: The object to compare against.

    Returns:
      This spec.

    Raises:
      ValueError: If `type(other)` is not `type(self)`.
    """
    if type(self) is type(other):
      return self
    raise ValueError(
        f"No `TypeSpec` is compatible with both {self} and {other}")
559
560
# Register `NoneTensorSpec.from_value` as the `TypeSpec` converter for values
# whose type is `type(None)`, i.e. for the value `None`.
type_spec.register_type_spec_from_value_converter(type(None),
                                                  NoneTensorSpec.from_value)
563