xref: /aosp_15_r20/external/tensorflow/tensorflow/python/saved_model/function_serialization.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for serializing `Function`s."""
16
17from tensorflow.core.protobuf import saved_object_graph_pb2
18from tensorflow.python.eager import function as defun
19from tensorflow.python.framework import func_graph as func_graph_module
20from tensorflow.python.saved_model import nested_structure_coder
21from tensorflow.python.util import nest
22
23
24def _serialize_function_spec(function_spec):
25  """Serialize a FunctionSpec object into its proto representation."""
26  if function_spec.is_method and not function_spec.fullargspec.args:
27    raise NotImplementedError(
28        "Cannot serialize a method function without a named "
29        "'self' argument.")
30  proto = saved_object_graph_pb2.FunctionSpec()
31
32  # Intentionally skip encoding annotations of a function because function
33  # annotations are mainly for optional type checking during development
34  # and does not affect runtime behavior.
35  # https://www.python.org/dev/peps/pep-3107/
36  # https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
37  proto.fullargspec.CopyFrom(
38      nested_structure_coder.encode_structure(
39          function_spec.fullargspec._replace(annotations={})))
40
41  proto.is_method = function_spec.is_method
42  proto.input_signature.CopyFrom(
43      nested_structure_coder.encode_structure(function_spec.input_signature))
44
45  # See `tf.function` and the JitCompile proto for details.
46  proto.jit_compile = {
47      None: saved_object_graph_pb2.FunctionSpec.JitCompile.DEFAULT,
48      True: saved_object_graph_pb2.FunctionSpec.JitCompile.ON,
49      False: saved_object_graph_pb2.FunctionSpec.JitCompile.OFF,
50  }.get(function_spec.jit_compile)
51
52  return proto
53
54
def serialize_concrete_function(concrete_function, node_ids):
  """Build a `SavedConcreteFunction` proto for `concrete_function`.

  Args:
    concrete_function: The concrete function to serialize. Its `name`,
      `captured_inputs`, `structured_input_signature` and
      `structured_outputs` are read.
    node_ids: Mapping used to look up, for each captured input, the value
      recorded in the proto's `bound_inputs`.

  Returns:
    A `saved_object_graph_pb2.SavedConcreteFunction` with its canonicalized
    input signature, output signature and `bound_inputs` (one entry per
    captured input, in the same order) filled in.

  Raises:
    KeyError: If a captured input has no entry in `node_ids`. The original
      lookup error is attached as `__cause__`.
  """
  bound_inputs = []
  for capture in concrete_function.captured_inputs:
    # Guard only the lookup. That way a KeyError raised while computing
    # `captured_inputs` is not misreported, and `capture` is always bound
    # when the message below is formatted.
    try:
      node_id = node_ids[capture]
    except KeyError as err:
      raise KeyError(
          f"Failed to add concrete function '{concrete_function.name}' to "
          f"object-based SavedModel as it captures tensor {capture!r} which "
          "is unsupported or not reachable from root. "
          "One reason could be that a stateful object or a variable that the "
          "function depends on is not assigned to an attribute of the "
          "serialized trackable object "
          "(see SaveTest.test_captures_unreachable_variable).") from err
    bound_inputs.append(node_id)

  concrete_function_proto = saved_object_graph_pb2.SavedConcreteFunction()
  structured_outputs = func_graph_module.convert_structure_to_signature(
      concrete_function.structured_outputs)
  concrete_function_proto.canonicalized_input_signature.CopyFrom(
      nested_structure_coder.encode_structure(
          concrete_function.structured_input_signature))
  concrete_function_proto.output_signature.CopyFrom(
      nested_structure_coder.encode_structure(structured_outputs))
  concrete_function_proto.bound_inputs.extend(bound_inputs)
  return concrete_function_proto
79
80
def serialize_bare_concrete_function(concrete_function):
  """Build a `SavedBareConcreteFunction` proto for `concrete_function`.

  The function's name, positional-argument count and argument keywords are
  always recorded. A serialized function spec is attached only when the
  concrete function's `_pre_initialized_function_spec` is not None.

  Args:
    concrete_function: The concrete function to describe.

  Returns:
    A `saved_object_graph_pb2.SavedBareConcreteFunction` proto.
  """
  # pylint: disable=protected-access
  bare_proto = saved_object_graph_pb2.SavedBareConcreteFunction(
      concrete_function_name=concrete_function.name,
      allowed_positional_arguments=concrete_function._num_positional_args,
      argument_keywords=concrete_function._arg_keywords)
  pre_initialized_spec = concrete_function._pre_initialized_function_spec
  # pylint: enable=protected-access
  if pre_initialized_spec is not None:
    bare_proto.function_spec.CopyFrom(
        _serialize_function_spec(pre_initialized_spec))
  return bare_proto
94
95
def serialize_function(function, concrete_functions):
  """Build a `SavedFunction` proto.

  Args:
    function: Object whose `function_spec` is serialized into the proto.
    concrete_functions: Iterable of concrete functions whose `name`s are
      recorded, in iteration order, in the proto's `concrete_functions`.

  Returns:
    A `saved_object_graph_pb2.SavedFunction` proto.
  """
  saved_function = saved_object_graph_pb2.SavedFunction()
  saved_function.function_spec.CopyFrom(
      _serialize_function_spec(function.function_spec))
  saved_function.concrete_functions.extend(
      [concrete.name for concrete in concrete_functions])
  return saved_function
105
106
def wrap_cached_variables(concrete_function):
  """Wraps the concrete function if it uses cached read tensors.

  This function creates a new concrete function that captures variables
  instead of the cached read tensors.

  While the wrapper is being traced, entries in
  `concrete_function.graph._captures` whose external value has a
  `_cached_variable` attribute are temporarily replaced with a fresh
  `read_value()` made in the new outer graph. The original entries are
  always restored before this function returns or raises.

  Args:
    concrete_function: A concrete function that may capture cached read
      tensors.

  Returns:
    A concrete function that wraps the original concrete function, which
    captures variables instead. If the original function did not capture any
    cached values, then the function is not wrapped and the original object is
    returned.
  """
  outer_graph = func_graph_module.FuncGraph(
      "{}_no_cache".format(concrete_function.graph.name))
  captures = concrete_function.graph._captures  # pylint: disable=protected-access
  # id(capture) -> original entry in `captures`, so that every entry swapped
  # below can be put back (also on the error path).
  remapped_captures = {}

  try:
    # Update the external captures to use read tensors generated in the outer
    # graph.
    with outer_graph.as_default():
      for capture, placeholder in concrete_function.graph.captures:
        cached_variable = getattr(capture, "_cached_variable", None)
        if cached_variable is None:
          continue
        # `_cached_variable` is called to obtain the variable; presumably a
        # weak reference. NOTE(review): confirm, and whether it can be dead.
        cached_variable = cached_variable()
        new_cached_value = cached_variable.read_value()
        remapped_captures[id(capture)] = captures[id(capture)]
        captures[id(capture)] = (new_cached_value, placeholder)

    if not remapped_captures:
      return concrete_function

    inner_concrete = defun.ConcreteFunction(concrete_function.graph)

    def wrap_function(*args):
      return inner_concrete._call_flat(args, inner_concrete.captured_inputs)  # pylint:disable=protected-access

    args = nest.flatten(concrete_function.structured_input_signature,
                        expand_composites=True)
    func_graph_module.func_graph_from_py_func(
        None, wrap_function, args=tuple(args), kwargs={},
        func_graph=outer_graph)

    # Create concrete function, and copy the attributes necessary to serialize
    # the function.
    # pylint: disable=protected-access
    fn = defun.ConcreteFunction(
        outer_graph, spec=concrete_function._function_spec)
    fn._arg_keywords = concrete_function._arg_keywords
    fn._num_positional_args = concrete_function._num_positional_args
    fn._pre_initialized_function_spec = (
        concrete_function._pre_initialized_function_spec)
    # pylint: enable=protected-access
    return fn
  finally:
    # Return the captures to their original values, even if tracing or
    # function construction above raised.
    for key, original_capture in remapped_captures.items():
      captures[key] = original_capture
171