# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""A shim layer for working with functions exported/restored from saved models.

This functionality should ultimately be moved into a first-class core API.
"""

import gc
import warnings

import numpy

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import registration
from tensorflow.python.trackable import asset
from tensorflow.python.trackable import base as trackable
from tensorflow.python.trackable import resource


@registration.register_tf_serializable()
class TrackableConstant(trackable.Trackable):
  """Trackable class for captured constants."""
  __slots__ = ("capture", "function", "_exported_tensor")

  def __init__(self, capture, function):
    self.capture = capture
    self.function = function
    self._exported_tensor = None

  def _export_to_saved_model_graph(self, tensor_map, **unused_kwargs):
    capture_constant_value = tensor_util.constant_value(self.capture)
    if capture_constant_value is None:
      raise ValueError(
          f"Unable to save function {self.function.name} because it "
          f"captures graph tensor {self.capture} from a parent function which "
          "cannot be converted to a constant with `tf.get_static_value`.")

    if numpy.prod(self.capture.shape.as_list()) > 1 and numpy.all(
        capture_constant_value == capture_constant_value.flat[0]):
      # For the common case of a constant array filled with the same
      # value, rebuild the constant op specifically with the shape arg,
      # since otherwise the whole array is written into the node def,
      # causing performance and graph proto size issues (protos cannot be
      # bigger than 2GB).
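      # For example, a float32 constant of shape [10000, 10000] holding a
      # single repeated value is serialized as one scalar plus a shape
      # attribute rather than 10**8 individual elements in the node def.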
      copied_tensor = constant_op.constant(
          capture_constant_value.flat[0],
          dtype=self.capture.dtype,
          shape=self.capture.shape)
    else:
      copied_tensor = constant_op.constant(capture_constant_value)

    tensor_map[self.capture] = copied_tensor
    self._exported_tensor = copied_tensor
    return [self.capture]

  def _serialize_to_proto(self, object_proto=None, **kwargs):
    object_proto.constant.operation = self._exported_tensor.op.name

  @classmethod
  def _deserialize_from_proto(cls, object_proto, operation_attributes,
                              **kwargs):
    tensor_proto = (
        operation_attributes[object_proto.constant.operation]["value"].tensor)
    ndarray = tensor_util.MakeNdarray(tensor_proto)
    if dtypes.as_dtype(tensor_proto.dtype) == dtypes.string:
      with ops.device("CPU"):
        # String operations should be done on the CPU.
        imported_constant = constant_op.constant(ndarray)
    else:
      imported_constant = constant_op.constant(ndarray)
    return imported_constant


# TODO(kathywu): Delete this class when ConcreteFunctions can be copied with new
# captures.
class ExportedConcreteFunction(trackable.Trackable):
  """A callable class that uses captures from the exported SavedModel graph."""
  __slots__ = ("function", "tensor_map")

  def __init__(self, function, tensor_map):
    self.function = function
    self.tensor_map = tensor_map

  def __call__(self, *args, **kwargs):
    _, _, filtered_flat_args = (
        self.function._function_spec.canonicalize_function_inputs(args, kwargs))
    export_captures = _map_captures_to_created_tensors(
        self.function.graph.captures, self.tensor_map, self.function)
    return self.function._call_flat(filtered_flat_args, export_captures)


def _map_captures_to_created_tensors(original_captures, tensor_map, function):
  """Maps eager tensors captured by a function to Graph resources for export.

  Args:
    original_captures: A dictionary mapping from tensors captured by the
      function to interior placeholders for those tensors (inside the function
      body).
    tensor_map: A dictionary mapping from resource tensors owned by the eager
      context to resource tensors in the exported graph.
    function: Function with the original captures. Only used when raising the
      AssertionError.

  Returns:
    A list of stand-in tensors which belong to the exported graph, corresponding
    to the function's captures.

  Raises:
    AssertionError: If the function references a resource which is not part of
      `tensor_map`.
  """
  export_captures = []
  for exterior, interior in original_captures:
    mapped_resource = tensor_map.get(exterior, None)
    if mapped_resource is None:
      _raise_untracked_capture_error(function.name, exterior, interior)
    export_captures.append(mapped_resource)
  return export_captures


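# A minimal sketch of user code that triggers `_raise_untracked_capture_error`
# below (illustrative only; `v`, `Model`, and the save path are hypothetical
# names, not part of this module):
#
#   v = tf.Variable(1.0)  # never assigned to a tracked attribute
#
#   class Model(tf.Module):
#
#     @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
#     def __call__(self, x):
#       return x + v  # the traced function captures `v`
#
#   tf.saved_model.save(Model(), "/tmp/model")  # raises AssertionError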
def _raise_untracked_capture_error(function_name, capture,
                                   internal_capture=None,
                                   node_path=None):
  """Raises AssertionError due to being unable to export a function."""
  msg = ("Tried to export a function which references an 'untracked' resource. "
         "TensorFlow objects (e.g. tf.Variable) captured by functions must be "
         "'tracked' by assigning them to an attribute of a tracked object or "
         "to an attribute of the main object directly. See the "
         "information below:"
         f"\n\tFunction name = {function_name}")

  if node_path is not None:
    msg += f"\n\tPath to Function = {node_path}"

  msg += f"\n\tCaptured Tensor = {capture}"
  msg += f"\n\t{_get_trackable_parent_error_string(capture)}"

  if internal_capture is not None:
    msg += f"\n\tInternal Tensor = {internal_capture}"
  raise AssertionError(msg)


def _get_trackable_parent_error_string(capture):
  """Gets an error string describing the capture's parent object."""
  parent = getattr(capture, "_parent_trackable", None)
  if parent is not None:
    return f"Trackable referencing this tensor = {parent()}"

  # Try to figure out where the resource came from by iterating over objects
  # which reference it. This is slow and doesn't help us figure out how to
  # match it to other objects when loading the SavedModel as a checkpoint,
  # so we can't continue saving. But we can at least tell the user what
  # needs attaching.
  trackable_referrers = []
  for primary_referrer in gc.get_referrers(capture):
    if isinstance(primary_referrer, trackable.Trackable):
      trackable_referrers.append(primary_referrer)
    for secondary_referrer in gc.get_referrers(primary_referrer):
      if isinstance(secondary_referrer, trackable.Trackable):
        trackable_referrers.append(secondary_referrer)
  return ("Trackable Python objects referring to this tensor "
          "(from gc.get_referrers, limited to two hops) = [\n\t\t{}]"
          .format("\n\t\t".join([repr(obj) for obj in trackable_referrers])))


def get_tensor_from_node(node):
  """Resolves a saved model graph node into a tensor to be captured.

  Args:
    node: a tensor, variable, or resource to be resolved into a capturable
      tensor

  Returns:
    The node resolved to a capturable tensor, or the node itself for
    distributed and sharded values, which have their own capture logic.

  Raises:
    ValueError: if the node cannot be converted into a tensor.
  """
  with ops.init_scope():
    # TODO(b/210144904): Use __tf_tensor__ instead of `is_[...]` checks
    if getattr(node, "is_distributed_variable", False):
      return node
    elif getattr(node, "is_distributed_table", False):
      return node
    elif getattr(node, "is_sharded_variable", False):
      return node
    elif resource_variable_ops.is_resource_variable(node):
      return node.handle
    elif isinstance(node, asset.Asset):
      return node.asset_path
    elif tensor_util.is_tf_type(node):
      return node
    elif isinstance(node, resource.CapturableResource):
      # Note: this executes restored functions in the CapturableResource.
      return node.resource_handle
    raise ValueError(f"Cannot convert node {node} to tensor.")


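# For reference, the resolution above maps, for example:
#   a resource variable (`tf.Variable`)  -> its `.handle` resource tensor
#   an `asset.Asset`                     -> its `.asset_path` tensor
#   a plain `tf.Tensor`                  -> returned unchanged
#   a `resource.CapturableResource`      -> its `.resource_handle`
#   distributed/sharded variables        -> returned as-is for the custom
#                                           capture handling in
#                                           `restore_captures` below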
def restore_captures(concrete_function, inputs):
  """Restore captures for the concrete function.

  Used at deserialization time. For functions that are being deserialized,
  saved model restores objects that tensors were captured from, but functions
  only know about their tensors -- object information is destroyed by tracing.
  This additional logic extracts the tensors which the function originally
  captured.

  Args:
    concrete_function: the concrete function for which to restore captures
    inputs: a list of tensors or other Python objects (such as variables) which
      contain tensors that were originally captured by the function
  """
  bound_inputs = [get_tensor_from_node(obj) for obj in inputs]
  bound_variables = [
      obj for obj in inputs
      if isinstance(obj, (variables_lib.Variable,
                          resource_variable_ops.BaseResourceVariable))
  ]
  # TODO(b/205010575): This only injects the captured inputs into the concrete
  # function; the FuncGraph itself is not modified.
  captured_inputs_list = []
  concrete_function.set_variables(bound_variables)
  if bound_inputs:
    for bound_input, internal_capture in zip(
        bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
      # Distributed inputs have special logic for capturing, so we call their
      # custom restoration methods.
      if hasattr(bound_input, "__tf_experimental_restore_capture__"):
        captured_inputs_list.append(
            bound_input.__tf_experimental_restore_capture__(
                concrete_function, internal_capture))
      else:
        captured_inputs_list.append(bound_input)
        concrete_function.graph.replace_capture(bound_input, internal_capture)
        if internal_capture.dtype == dtypes.resource:
          if resource_variable_ops.is_resource_variable(bound_input):
            try:
              handle = bound_input.handle
            except ValueError:
              # For mirrored variables we'll copy handle data for components
              # as they get captured.
              pass
            else:
              handle_data_util.copy_handle_data(handle, internal_capture)
          else:
            # TODO(b/213451747): Remove need to call copy_handle_data
            handle_data_util.copy_handle_data(bound_input, internal_capture)
        # Setting "captures" first means "capture" won't create a new
        # placeholder for this input.
        concrete_function.graph.capture(bound_input)

  if any(inp is None for inp in captured_inputs_list):
    warnings.warn("Trying to load ShardedVariables using tf.saved_model.load. "
                  "This won't work if using a tf.distribute.Strategy, and may "
                  "use excess memory if not using a Strategy. Ignore this "
                  "warning if using tf.keras.models.load_model.")
  concrete_function.set_external_captures(captured_inputs_list)
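

# A minimal sketch of how `restore_captures` is typically driven during
# `tf.saved_model.load` (illustrative only; `concrete_fn`, `saved_fn_proto`,
# and `get_node` are hypothetical stand-ins for the loader's restored
# ConcreteFunction, the matching SavedConcreteFunction proto, and its
# node-id -> trackable lookup):
#
#   bound_objects = [get_node(node_id)
#                    for node_id in saved_fn_proto.bound_inputs]
#   restore_captures(concrete_fn, bound_objects)
#
# Afterwards the function captures live tensors (e.g. restored variable
# handles) in place of the placeholder captures recorded at save time.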