# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""A shim layer for working with functions exported/restored from saved models.

This functionality should ultimately be moved into a first-class core API.
"""

import gc
import warnings

import numpy

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import registration
from tensorflow.python.trackable import asset
from tensorflow.python.trackable import base as trackable
from tensorflow.python.trackable import resource


@registration.register_tf_serializable()
class TrackableConstant(trackable.Trackable):
  """Trackable wrapper for a constant tensor captured by a function.

  At export time the captured constant is copied into the exported graph;
  at load time it is rebuilt from the serialized operation proto.
  """
  __slots__ = ("capture", "function", "_exported_tensor")

  def __init__(self, capture, function):
    self.capture = capture
    self.function = function
    self._exported_tensor = None

  def _export_to_saved_model_graph(self, tensor_map, **unused_kwargs):
    """Copies the captured constant into the exported graph.

    Args:
      tensor_map: dict updated in place to map `self.capture` to the copied
        constant in the exported graph.
      **unused_kwargs: ignored.

    Returns:
      A single-element list containing the captured tensor.

    Raises:
      ValueError: if the capture's value cannot be determined statically.
    """
    value = tensor_util.constant_value(self.capture)
    if value is None:
      raise ValueError(
          f"Unable to save function {self.function.name} because it "
          f"captures graph tensor {self.capture} from a parent function which "
          "cannot be converted to a constant with `tf.get_static_value`.")

    shape = self.capture.shape
    uniform_array = (numpy.prod(shape.as_list()) > 1
                     and numpy.all(value == value.flat[0]))
    if uniform_array:
      # For the common case of a constant array filled with the same
      # value, rebuild the constant op specifically with the shape arg,
      # since otherwise the whole array is written into the node def,
      # causing performance and graph proto size issues (protos cannot be
      # bigger than 2GB).
      copied_tensor = constant_op.constant(
          value.flat[0], dtype=self.capture.dtype, shape=shape)
    else:
      copied_tensor = constant_op.constant(value)

    tensor_map[self.capture] = copied_tensor
    self._exported_tensor = copied_tensor
    return [self.capture]

  def _serialize_to_proto(self, object_proto=None, **kwargs):
    # Record which exported graph operation holds the constant's value.
    object_proto.constant.operation = self._exported_tensor.op.name

  @classmethod
  def _deserialize_from_proto(cls, object_proto, operation_attributes,
                              **kwargs):
    """Rebuilds the constant tensor from the serialized operation proto."""
    op_name = object_proto.constant.operation
    tensor_proto = operation_attributes[op_name]["value"].tensor
    ndarray = tensor_util.MakeNdarray(tensor_proto)
    if dtypes.as_dtype(tensor_proto.dtype) == dtypes.string:
      # String operations should be done on the CPU.
      with ops.device("CPU"):
        return constant_op.constant(ndarray)
    return constant_op.constant(ndarray)
91
92
# TODO(kathywu): Delete this class when ConcreteFunctions can be copied with new
# captures.
class ExportedConcreteFunction(trackable.Trackable):
  """A callable class that uses captures from the exported SavedModel graph."""
  __slots__ = ("function", "tensor_map")

  def __init__(self, function, tensor_map):
    self.function = function
    self.tensor_map = tensor_map

  def __call__(self, *args, **kwargs):
    # Canonicalize the caller's args against the original function spec, then
    # swap the function's eager captures for their exported-graph stand-ins.
    spec = self.function._function_spec
    _, _, filtered_flat_args = spec.canonicalize_function_inputs(args, kwargs)
    export_captures = _map_captures_to_created_tensors(
        self.function.graph.captures, self.tensor_map, self.function)
    return self.function._call_flat(filtered_flat_args, export_captures)
109
110
111def _map_captures_to_created_tensors(original_captures, tensor_map, function):
112  """Maps eager tensors captured by a function to Graph resources for export.
113
114  Args:
115    original_captures: A dictionary mapping from tensors captured by the
116      function to interior placeholders for those tensors (inside the function
117      body).
118    tensor_map: A dictionary mapping from resource tensors owned by the eager
119      context to resource tensors in the exported graph.
120    function: Function with the original captures. Only used when raising the
121      AssertionError.
122
123  Returns:
124    A list of stand-in tensors which belong to the exported graph, corresponding
125    to the function's captures.
126
127  Raises:
128    AssertionError: If the function references a resource which is not part of
129      `tensor_map`.
130  """
131  export_captures = []
132  for exterior, interior in original_captures:
133    mapped_resource = tensor_map.get(exterior, None)
134    if mapped_resource is None:
135      _raise_untracked_capture_error(function.name, exterior, interior)
136    export_captures.append(mapped_resource)
137  return export_captures
138
139
def _raise_untracked_capture_error(function_name, capture,
                                   internal_capture=None,
                                   node_path=None):
  """Raises AssertionError due to being unable to export a function."""
  parts = [
      "Tried to export a function which references an 'untracked' resource. "
      "TensorFlow objects (e.g. tf.Variable) captured by functions must be "
      "'tracked' by assigning them to an attribute of a tracked object or "
      "assigned to an attribute of the main object directly. See the "
      "information below:"
      f"\n\tFunction name = {function_name}"
  ]
  if node_path is not None:
    parts.append(f"\tPath to Function = {node_path}")
  parts.append(f"\tCaptured Tensor = {capture}")
  parts.append(f"\t{_get_trackable_parent_error_string(capture)}")
  if internal_capture is not None:
    parts.append(f"\tInternal Tensor = {internal_capture}")
  raise AssertionError("\n".join(parts))
160
161
162def _get_trackable_parent_error_string(capture):
163  """Gets error string with the capture's parent object."""
164  parent = getattr(capture, "_parent_trackable", None)
165  if parent is not None:
166    return f"Trackable referencing this tensor = {parent()}"
167
168  # Try to figure out where the resource came from by iterating over objects
169  # which reference it. This is slow and doesn't help us figure out how to
170  # match it to other objects when loading the SavedModel as a checkpoint,
171  # so we can't continue saving. But we can at least tell the user what
172  # needs attaching.
173  trackable_referrers = []
174  for primary_referrer in gc.get_referrers(capture):
175    if isinstance(primary_referrer, trackable.Trackable):
176      trackable_referrers.append(primary_referrer)
177    for secondary_referrer in gc.get_referrers(primary_referrer):
178      if isinstance(secondary_referrer, trackable.Trackable):
179        trackable_referrers.append(secondary_referrer)
180  return ("Trackable Python objects referring to this tensor "
181          "(from gc.get_referrers, limited to two hops) = [\n\t\t{}]"
182          .format("\n\t\t".join([repr(obj) for obj in trackable_referrers])))
183
184
def get_tensor_from_node(node):
  """Resolves a saved model graph node into a tensor to be captured.

  Args:
    node: a tensor, variable, or resource to be resolved into a capturable
      tensor

  Returns:
    A list of tensors.
  Raises:
    ValueError: if the node cannot be converted into a tensor.
  """
  with ops.init_scope():
    # TODO(b/210144904): Use __tf_tensor__ instead of `is_[...]` checks
    # Distributed values are passed through unchanged; their capture is
    # handled separately via __tf_experimental_restore_capture__.
    for distributed_flag in ("is_distributed_variable",
                             "is_distributed_table",
                             "is_sharded_variable"):
      if getattr(node, distributed_flag, False):
        return node
    if resource_variable_ops.is_resource_variable(node):
      return node.handle
    if isinstance(node, asset.Asset):
      return node.asset_path
    if tensor_util.is_tf_type(node):
      return node
    if isinstance(node, resource.CapturableResource):
      # Note: this executes restored functions in the CapturableResource.
      return node.resource_handle
    raise ValueError(f"Cannot convert node {node} to tensor.")
215
216
def restore_captures(concrete_function, inputs):
  """Restore captures for the concrete function.

  Used at deserialization time.  For functions that are being deserialized,
  saved model restores objects that tensors were captured from, but functions
  only know about their tensors -- object information is destroyed by tracing.
  This additional logic extracts the tensors which the function originally
  captured.

  Args:
    concrete_function: the concrete function for which to restore captures
    inputs: a list tensors or other Python objects (such as variables) which
      contain tensors that were originally captured by the function
  """
  bound_inputs = [get_tensor_from_node(obj) for obj in inputs]
  bound_variables = [
      obj for obj in inputs
      if isinstance(obj, (variables_lib.Variable,
                          resource_variable_ops.BaseResourceVariable))
  ]
  # TODO(b/205010575): This is only injecting the captured inputs into the
  # concrete function, note that we did not modify the FuncGraph
  # itself.
  captured_inputs_list = []
  concrete_function.set_variables(bound_variables)
  if bound_inputs:
    # The placeholders for the function's captures are the trailing entries of
    # `concrete_function.inputs`, in the same order as `bound_inputs`.
    for bound_input, internal_capture in zip(
        bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
      # Distributed inputs have special logic for capturing, so we call their
      # custom restoration methods
      if hasattr(bound_input, "__tf_experimental_restore_capture__"):
        captured_inputs_list.append(
            bound_input.__tf_experimental_restore_capture__(
                concrete_function, internal_capture))
      else:
        captured_inputs_list.append(bound_input)
        concrete_function.graph.replace_capture(bound_input, internal_capture)
        if internal_capture.dtype == dtypes.resource:
          if resource_variable_ops.is_resource_variable(bound_input):
            try:
              handle = bound_input.handle
            except ValueError:
              # For mirrored variables we'll copy handle data for components
              # as they get captured.
              pass
            else:
              handle_data_util.copy_handle_data(handle, internal_capture)
          else:
            # TODO(b/213451747): Remove need to call copy_handle_data
            handle_data_util.copy_handle_data(bound_input, internal_capture)
        # Setting "captures" first means "capture" won't create a new
        # placeholder for this input.
        concrete_function.graph.capture(bound_input)

  # `__tf_experimental_restore_capture__` may have appended None (observed for
  # ShardedVariables loaded without a strategy); warn rather than fail.
  if any(inp is None for inp in captured_inputs_list):
    warnings.warn("Trying to load ShardedVariables using tf.saved_model.load. "
                  "This won't work if using a tf.distribute.Strategy, and may "
                  "use excess memory if not using a Strategy. Ignore this "
                  "warning if using tf.keras.models.load_model.")
  concrete_function.set_external_captures(captured_inputs_list)
277
278