# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tensor-like objects that are composed from tf.Tensors."""

import abc

# Imported only for its side effects (hence the pylint disable); presumably it
# loads the native TensorFlow library before `_pywrap_utils` is used -- confirm.
from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@tf_export("__internal__.CompositeTensor", v1=[])
class CompositeTensor(metaclass=abc.ABCMeta):
  """Abstract base class for Tensor-like objects that are composed from Tensors.

  Each `CompositeTensor` can be decomposed into a structured collection of
  component `tf.Tensor`s, and reconstructed from those components.

  The `tensorflow.python.util.nest` module has support for treating composite
  tensors as structure, which makes it easy to flatten and reconstruct
  composite tensors (or larger structures that contain composite tensors).
  E.g.:

  ```python
  ct = ...  # Create a composite tensor.
  flat_list_of_tensors = nest.flatten(ct, expand_composites=True)
  transformed_list_of_tensors = ...  # do something with the flat tensors.
  result = nest.pack_sequence_as(ct, transformed_list_of_tensors,
                                 expand_composites=True)
  ```
  """

  # `abc.abstractproperty` is deprecated; stacking `property` over
  # `abc.abstractmethod` is the documented equivalent and still prevents
  # instantiation of subclasses that do not define `_type_spec`.
  @property
  @abc.abstractmethod
  def _type_spec(self):
    """A `TypeSpec` describing the type of this value."""
    raise NotImplementedError(f"{type(self).__name__}._type_spec()")

  def _shape_invariant_to_type_spec(self, shape):
    """Returns a TypeSpec given a shape invariant (used by `tf.while_loop`).

    Args:
      shape: A `tf.TensorShape` object.  The shape invariant for this
        `CompositeTensor`, or `None` if a default shape invariant should be used
        (based on the value of this `CompositeTensor`).

    Returns:
      A `TypeSpec` for this `CompositeTensor` that is compatible with the
      given shape invariant.

    Raises:
      NotImplementedError: Always, in this base implementation. Subclasses
        that support shape invariants must override this method.
    """
    # New `CompositeTensor` subclasses generally do not need to implement
    # this -- this method is used for backwards compatibility.  Users of
    # tf.while_loop can specify a type by passing in TypeSpec instead.
    raise NotImplementedError(
        f"{type(self).__name__}._shape_invariant_to_type_spec")

  def _consumers(self):
    """Returns a list of `Operation`s that consume this `CompositeTensor`.

    The consumers of every component tensor are collected and de-duplicated.
    Components that have no `graph` attribute (or whose `graph` is `None`)
    are skipped.

    Returns:
      A list of unique `Operation`s, in the order they were first encountered.

    Raises:
      RuntimeError: If this method is called while executing eagerly.
    """
    # NOTE(review): the RuntimeError above is not raised by this method
    # itself, and graph-less components are filtered out below; confirm
    # whether `component.consumers()` can still raise it.
    consumers = nest.flatten([
        component.consumers()
        for component in nest.flatten(self, expand_composites=True)
        if getattr(component, "graph", None) is not None
    ])
    # `dict.fromkeys` de-duplicates like `set`, but keeps first-seen order so
    # the result is deterministic from run to run.
    return list(dict.fromkeys(consumers))

  def __tf_tracing_type__(self, context):
    """Returns the tracing type of this value, as given by its `TypeSpec`.

    Args:
      context: Tracing context, forwarded unchanged to
        `self._type_spec.__tf_tracing_type__`.

    Returns:
      Whatever `self._type_spec.__tf_tracing_type__(context)` returns.
    """
    return self._type_spec.__tf_tracing_type__(context)

  def _convert_variables_to_tensors(self):
    """Converts ResourceVariable components to Tensors.

    Override this method to explicitly convert ResourceVariables embedded in the
    CompositeTensor to Tensors. By default, it returns the CompositeTensor
    unchanged.

    Returns:
      A CompositeTensor with all its ResourceVariable components converted to
      Tensors.
    """
    return self


# Registers the class with the native utilities under the name
# "CompositeTensor" (presumably so native code can recognize instances).
_pywrap_utils.RegisterType("CompositeTensor", CompositeTensor)


def replace_composites_with_components(structure):
  """Recursively replaces CompositeTensors with their components.

  Args:
    structure: A `nest`-compatible structure, possibly containing composite
      tensors.

  Returns:
    A copy of `structure`, where each composite tensor has been replaced by
    its components.  The result will contain no composite tensors.
    Note that `nest.flatten(replace_composites_with_components(structure))`
    returns the same value as
    `nest.flatten(structure, expand_composites=True)`.
  """
  if isinstance(structure, CompositeTensor):
    # Components may themselves contain composite tensors, so recurse.
    return replace_composites_with_components(
        structure._type_spec._to_components(structure))  # pylint: disable=protected-access
  elif not nest.is_nested(structure):
    return structure
  else:
    # expand_composites=False: composites are handled by the branch above,
    # so the map must hand each composite to this function intact.
    return nest.map_structure(
        replace_composites_with_components, structure, expand_composites=False)


def convert_variables_to_tensors(composite_tensor):
  """Converts ResourceVariable components of `composite_tensor` to Tensors.

  Delegates to `composite_tensor._convert_variables_to_tensors()`.

  Args:
    composite_tensor: A `CompositeTensor`.

  Returns:
    The result of `composite_tensor._convert_variables_to_tensors()`; the
    base implementation returns `composite_tensor` unchanged.
  """
  return composite_tensor._convert_variables_to_tensors()  # pylint: disable=protected-access


# TODO(edloper): Can we replace convert_to_tensor_or_xyz with just
# convert_to_tensor_or_composite?  Alternatively, should composite tensors
# register a dispatch override for tf.convert_to_tensor?

# Note about the internal encoding of composite tensors when they are "lowered"
# from Python objects to tensors. The usual encoding is "component encoding",
# which uses the dense tensors that represent a composite tensor.
# A second encoding, "batchable tensor list encoding", is used by datasets
# and map_fn. In addition to supporting batching, it can also use ops for
# encoding and decoding, e.g. for encoding/decoding to/from a single variant
# that represents a composite tensor. Some internal properties of type specs
# for composite tensors use `flat` as a nickname for the
# "batchable tensor list encoding" (e.g. `flat_tensor_specs`).