# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for serializing `Function`s."""

from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import nest


def _serialize_function_spec(function_spec):
  """Serialize a FunctionSpec object into its proto representation.

  Args:
    function_spec: The `FunctionSpec` to serialize. Its `fullargspec`,
      `is_method`, `input_signature` and `jit_compile` attributes are read.

  Returns:
    A `saved_object_graph_pb2.FunctionSpec` proto.

  Raises:
    NotImplementedError: If `function_spec` is a method whose full argspec has
      no named positional arguments (i.e. no named `self` argument).
  """
  if function_spec.is_method and not function_spec.fullargspec.args:
    raise NotImplementedError(
        "Cannot serialize a method function without a named "
        "'self' argument.")
  proto = saved_object_graph_pb2.FunctionSpec()

  # Intentionally skip encoding annotations of a function because function
  # annotations are mainly for optional type checking during development
  # and do not affect runtime behavior.
  # https://www.python.org/dev/peps/pep-3107/
  # https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
  proto.fullargspec.CopyFrom(
      nested_structure_coder.encode_structure(
          function_spec.fullargspec._replace(annotations={})))

  proto.is_method = function_spec.is_method
  proto.input_signature.CopyFrom(
      nested_structure_coder.encode_structure(function_spec.input_signature))

  # Map the tri-state `jit_compile` (None / True / False) onto the JitCompile
  # enum. See `tf.function` and the JitCompile proto for details.
  # NOTE(review): any other value maps to None via `.get`, which the proto
  # enum assignment will presumably reject.
  proto.jit_compile = {
      None: saved_object_graph_pb2.FunctionSpec.JitCompile.DEFAULT,
      True: saved_object_graph_pb2.FunctionSpec.JitCompile.ON,
      False: saved_object_graph_pb2.FunctionSpec.JitCompile.OFF,
  }.get(function_spec.jit_compile)

  return proto


def serialize_concrete_function(concrete_function, node_ids):
  """Build a SavedConcreteFunction.

  Args:
    concrete_function: The concrete function to serialize.
    node_ids: Mapping from each of `concrete_function.captured_inputs` to the
      node id recorded in the proto's `bound_inputs`.

  Returns:
    A `saved_object_graph_pb2.SavedConcreteFunction` proto holding the encoded
    input and output signatures and the bound input node ids.

  Raises:
    KeyError: If a captured input of `concrete_function` is missing from
      `node_ids`.
  """
  bound_inputs = []
  # An explicit loop (rather than a comprehension) keeps `capture` in scope
  # so the error message below can name the offending tensor.
  try:
    for capture in concrete_function.captured_inputs:
      bound_inputs.append(node_ids[capture])
  except KeyError:
    raise KeyError(
        f"Failed to add concrete function '{concrete_function.name}' to object-"
        f"based SavedModel as it captures tensor {capture!r} which is unsupported"
        " or not reachable from root. "
        "One reason could be that a stateful object or a variable that the "
        "function depends on is not assigned to an attribute of the serialized "
        "trackable object (see SaveTest.test_captures_unreachable_variable).")
  concrete_function_proto = saved_object_graph_pb2.SavedConcreteFunction()
  structured_outputs = func_graph_module.convert_structure_to_signature(
      concrete_function.structured_outputs)
  concrete_function_proto.canonicalized_input_signature.CopyFrom(
      nested_structure_coder.encode_structure(
          concrete_function.structured_input_signature))
  concrete_function_proto.output_signature.CopyFrom(
      nested_structure_coder.encode_structure(structured_outputs))
  concrete_function_proto.bound_inputs.extend(bound_inputs)
  return concrete_function_proto


def serialize_bare_concrete_function(concrete_function):
  """Build a SavedBareConcreteFunction.

  Args:
    concrete_function: The concrete function to serialize. Its `name`,
      `_num_positional_args`, `_arg_keywords` and
      `_pre_initialized_function_spec` attributes are read.

  Returns:
    A `saved_object_graph_pb2.SavedBareConcreteFunction` proto. Its
    `function_spec` field is populated only when the concrete function has a
    pre-initialized function spec.
  """
  # pylint: disable=protected-access
  proto = saved_object_graph_pb2.SavedBareConcreteFunction(
      concrete_function_name=concrete_function.name,
      allowed_positional_arguments=concrete_function._num_positional_args,
      argument_keywords=concrete_function._arg_keywords)
  if concrete_function._pre_initialized_function_spec is not None:
    proto.function_spec.CopyFrom(
        _serialize_function_spec(
            concrete_function._pre_initialized_function_spec))
  return proto
  # pylint: enable=protected-access


def serialize_function(function, concrete_functions):
  """Build a SavedFunction proto.

  Args:
    function: The function whose `function_spec` is serialized.
    concrete_functions: Iterable of concrete functions; only their `name`s are
      recorded in the proto.

  Returns:
    A `saved_object_graph_pb2.SavedFunction` proto.
  """
  proto = saved_object_graph_pb2.SavedFunction()

  function_spec_proto = _serialize_function_spec(function.function_spec)
  proto.function_spec.CopyFrom(function_spec_proto)
  for concrete_function in concrete_functions:
    proto.concrete_functions.append(concrete_function.name)
  return proto


def wrap_cached_variables(concrete_function):
  """Wraps the concrete function if it uses cached read tensors.

  This function creates a new concrete function that captures variables
  instead of the cached read tensors.

  The captures of `concrete_function.graph` are rewritten in place while the
  wrapper is traced. They are always restored before this function exits,
  including when an exception is raised part-way through.

  Args:
    concrete_function: A concrete function that may capture cached read
      tensors.

  Returns:
    A concrete function that wraps the original concrete function, which
    captures variables instead. If the original function did not capture any
    cached values, then the function is not wrapped and the original object is
    returned.
  """
  outer_graph = func_graph_module.FuncGraph(
      "{}_no_cache".format(concrete_function.graph.name))
  # The inner graph's capture table, keyed by `id(external_tensor)`. Its
  # values are `(external_tensor, placeholder)` pairs, as the assignment
  # below writes. It is mutated in place and restored in `finally`.
  captures = concrete_function.graph._captures  # pylint: disable=protected-access
  # Original entries of `captures` that were overwritten, keyed the same way.
  remapped_captures = {}

  try:
    # Update the external captures to use read tensors generated in the outer
    # graph.
    with outer_graph.as_default():
      for capture, placeholder in concrete_function.graph.captures:
        cached_variable = getattr(capture, "_cached_variable", None)
        if cached_variable is None:
          continue
        # `_cached_variable` is called to recover the variable; presumably it
        # is a weak reference (TODO confirm).
        cached_variable = cached_variable()
        new_cached_value = cached_variable.read_value()
        remapped_captures[id(capture)] = captures[id(capture)]
        captures[id(capture)] = (new_cached_value, placeholder)

    # Nothing was remapped, so there are no cached reads to replace.
    if not remapped_captures:
      return concrete_function

    inner_concrete = defun.ConcreteFunction(concrete_function.graph)

    def wrap_function(*args):
      return inner_concrete._call_flat(args, inner_concrete.captured_inputs)  # pylint:disable=protected-access

    args = nest.flatten(concrete_function.structured_input_signature,
                        expand_composites=True)
    func_graph_module.func_graph_from_py_func(
        None, wrap_function, args=tuple(args), kwargs={},
        func_graph=outer_graph)

    # Create concrete function, and copy the attributes necessary to serialize
    # the function.
    # pylint: disable=protected-access
    fn = defun.ConcreteFunction(
        outer_graph, spec=concrete_function._function_spec)
    fn._arg_keywords = concrete_function._arg_keywords
    fn._num_positional_args = concrete_function._num_positional_args
    fn._pre_initialized_function_spec = (
        concrete_function._pre_initialized_function_spec)
    # pylint: enable=protected-access
    return fn
  finally:
    # Return the captures to their original values, even on failure. This
    # keeps the caller's concrete function from being left capturing tensors
    # that were created inside `outer_graph`.
    for key, capture in remapped_captures.items():
      captures[key] = capture