# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tensor-like objects that are composed from tf.Tensors."""

import abc

from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@tf_export("__internal__.CompositeTensor", v1=[])
class CompositeTensor(metaclass=abc.ABCMeta):
  """Abstract base class for Tensor-like objects that are composed from Tensors.

  Each `CompositeTensor` can be decomposed into a structured collection of
  component `tf.Tensor`s, and reconstructed from those components.

  The `tensorflow.python.util.nest` module has support for treating composite
  tensors as structure, which makes it easy to flatten and reconstruct
  composite tensors (or larger structures that contain composite tensors).
  E.g.:

  ```python
  ct = ...  # Create a composite tensor.
  flat_list_of_tensors = nest.flatten(ct, expand_composites=True)
  transformed_list_of_tensors = ...  # do something with the flat tensors.
  result = nest.pack_sequence_as(ct, transformed_list_of_tensors,
                                 expand_composites=True)
  ```
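
  For example, `tf.RaggedTensor` and `tf.SparseTensor` are both composite
  tensors, so the same pattern applies to them directly (a minimal
  illustration):

  ```python
  rt = tf.ragged.constant([[1, 2], [3]])
  flat = nest.flatten(rt, expand_composites=True)    # a flat list of tf.Tensors
  flat = [tf.identity(t) for t in flat]              # e.g. a pass-through transform
  result = nest.pack_sequence_as(rt, flat, expand_composites=True)
  # `result` is a tf.RaggedTensor with the same structure and values as `rt`.
  ```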
  """

  @abc.abstractproperty
  def _type_spec(self):
    """A `TypeSpec` describing the type of this value."""
    raise NotImplementedError(f"{type(self).__name__}._type_spec()")

  def _shape_invariant_to_type_spec(self, shape):
    """Returns a `TypeSpec` given a shape invariant (used by `tf.while_loop`).

    Args:
      shape: A `tf.TensorShape` object.  The shape invariant for this
        `CompositeTensor`, or `None` if a default shape invariant should be used
        (based on the value of this `CompositeTensor`).

    Returns:
      A `TypeSpec` for this `CompositeTensor` that is compatible with the shape
      invariant `shape`.
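
    For example, a subclass whose `TypeSpec` is parameterized by a shape and a
    dtype might implement this roughly as follows (a sketch only;
    `MyCompositeSpec` is a hypothetical spec class):

    ```python
    def _shape_invariant_to_type_spec(self, shape):
      return MyCompositeSpec(shape=shape, dtype=self.dtype)
    ```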
    """
    # New TypeSpec subclasses generally do not need to implement this --
    # this method is used for backwards compatibility.  Users of tf.while_loop
    # can specify a type by passing in a `TypeSpec` instead.
    raise NotImplementedError(
        f"{type(self).__name__}._shape_invariant_to_type_spec")

  def _consumers(self):
    """Returns a list of `Operation`s that consume this `CompositeTensor`.

    Returns:
      A list of `Operation`s.

    Raises:
      RuntimeError: If this method is called while executing eagerly.
    """
    consumers = nest.flatten([
        component.consumers()
        for component in nest.flatten(self, expand_composites=True)
        if getattr(component, "graph", None) is not None
    ])
    return list(set(consumers))

  def __tf_tracing_type__(self, context):
    return self._type_spec.__tf_tracing_type__(context)

  def _convert_variables_to_tensors(self):
    """Converts ResourceVariable components to Tensors.

    Override this method to explicitly convert ResourceVariables embedded in the
    CompositeTensor to Tensors. By default, it returns the CompositeTensor
    unchanged.

    Returns:
      A CompositeTensor with all its ResourceVariable components converted to
      Tensors.
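
    For example, a hypothetical subclass whose single component, `self._values`,
    may be a `tf.Variable` could override this roughly as follows (a sketch, not
    part of this API):

    ```python
    def _convert_variables_to_tensors(self):
      # Read the variable (if any) and rebuild the composite tensor from
      # plain Tensors.  `MyCompositeTensor` is a hypothetical subclass.
      return MyCompositeTensor(tf.convert_to_tensor(self._values))
    ```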
    """
    return self


_pywrap_utils.RegisterType("CompositeTensor", CompositeTensor)


def replace_composites_with_components(structure):
  """Recursively replaces CompositeTensors with their components.

  Args:
    structure: A `nest`-compatible structure, possibly containing composite
      tensors.

  Returns:
    A copy of `structure`, where each composite tensor has been replaced by
    its components.  The result will contain no composite tensors.
    Note that `nest.flatten(replace_composites_with_components(structure))`
    returns the same value as `nest.flatten(structure, expand_composites=True)`.
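
  For example (an illustrative sketch; `tf.sparse.SparseTensor` is one of the
  built-in composite tensors):

  ```python
  sp = tf.sparse.SparseTensor(indices=[[0, 0]], values=[1.0], dense_shape=[2, 2])
  result = replace_composites_with_components({"x": sp})
  # `result["x"]` now holds only the plain `tf.Tensor` components of `sp`;
  # no CompositeTensor remains anywhere in the returned structure.
  ```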
  """
  if isinstance(structure, CompositeTensor):
    return replace_composites_with_components(
        structure._type_spec._to_components(structure))  # pylint: disable=protected-access
  elif not nest.is_nested(structure):
    return structure
  else:
    return nest.map_structure(
        replace_composites_with_components, structure, expand_composites=False)


def convert_variables_to_tensors(composite_tensor):
  """Converts ResourceVariable components in `composite_tensor` to Tensors."""
  return composite_tensor._convert_variables_to_tensors()  # pylint: disable=protected-access


# TODO(edloper): Can we replace convert_to_tensor_or_xyz with just
# convert_to_tensor_or_composite?  Alternatively, should composite tensors
# register a dispatch override for tf.convert_to_tensor?

# Note about the internal encoding of composite tensors when they are "lowered"
# from Python objects to tensors. The usual encoding is "component encoding",
# which uses the dense tensors that represent a composite tensor.
# A second encoding, "batchable tensor list encoding", is used by datasets
# and map_fn; in addition to supporting batching, it can also use ops for
# encoding and decoding, e.g. for encoding/decoding to/from a single variant
# tensor that represents a composite tensor. Some internal properties of the
# type specs for composite tensors use `flat` as a nickname for
# "batchable tensor list encoding" (e.g. `flat_tensor_specs`).
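#
# For example (an illustrative sketch, not executed here): for a tf.SparseTensor
# `sp`, the component encoding is the flat list of its dense Tensors,
#
#     nest.flatten(sp, expand_composites=True)  # => [indices, values, dense_shape]
#
# while the batchable tensor list encoding used by tf.data packs `sp` into a
# single serialized `variant` tensor.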