1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Import a TF v1-style SavedModel when executing eagerly.""" 16 17import functools 18 19from tensorflow.python.eager import context 20from tensorflow.python.eager import lift_to_graph 21from tensorflow.python.eager import wrap_function 22from tensorflow.python.framework import composite_tensor 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import func_graph 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import sparse_tensor 28from tensorflow.python.ops import array_ops 29from tensorflow.python.platform import tf_logging as logging 30from tensorflow.python.saved_model import function_deserialization 31from tensorflow.python.saved_model import loader_impl 32from tensorflow.python.saved_model import signature_serialization 33from tensorflow.python.saved_model.pywrap_saved_model import metrics 34from tensorflow.python.trackable import asset 35from tensorflow.python.trackable import autotrackable 36from tensorflow.python.trackable import resource 37from tensorflow.python.training import monitored_session 38from tensorflow.python.training import saver as tf_saver 39from tensorflow.python.util import nest 40 41# API label for SavedModel metrics. 
_LOAD_V1_V2_LABEL = "load_v1_in_v2"


class _Initializer(resource.CapturableResource):
  """Represents an initialization operation restored from a SavedModel.

  Without this object re-export of imported 1.x SavedModels would omit the
  original SavedModel's initialization procedure.

  Created when `tf.saved_model.load` loads a TF 1.x-style SavedModel with an
  initialization op. This object holds a function that runs the
  initialization. It does not require any manual user intervention;
  `tf.saved_model.save` will see this object and automatically add it to the
  exported SavedModel, and `tf.saved_model.load` runs the initialization
  function automatically.
  """

  def __init__(self, init_fn, asset_paths):
    """Stores the initialization function and the assets it consumes.

    Args:
      init_fn: Callable run by `_initialize`. It receives one positional
        argument per entry of `asset_paths`: that entry's `asset_path`.
      asset_paths: Sequence of `asset.Asset` objects whose paths are fed to
        `init_fn`, in order.
    """
    super().__init__()
    self._asset_paths = asset_paths
    self._init_fn = init_fn

  def _create_resource(self):
    # NOTE(review): the placeholder is named "unused_resource"; presumably it
    # exists only to satisfy the `CapturableResource` interface. Confirm.
    return array_ops.placeholder(
        dtype=dtypes.resource, shape=[], name="unused_resource")

  def _initialize(self):
    """Runs `init_fn`, feeding it the current path of every tracked asset."""
    return self._init_fn(*[path.asset_path for path in self._asset_paths])


class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
  """Loads a SavedModel without using Sessions."""

  def get_meta_graph_def_from_tags(self, tags):
    """Override to support implicit one-MetaGraph loading with tags=None.

    Args:
      tags: Tag set identifying the MetaGraph to load, or `None` to load the
        SavedModel's single MetaGraph.

    Returns:
      The selected `MetaGraphDef`.

    Raises:
      ValueError: If `tags` is `None` and the SavedModel does not contain
        exactly one MetaGraph.
    """
    if tags is not None:
      return super().get_meta_graph_def_from_tags(tags)
    meta_graphs = self._saved_model.meta_graphs
    if not meta_graphs:
      # Asking for `tags=` would not help here, so report the real problem.
      raise ValueError(
          "Importing a SavedModel with `tf.saved_model.load` requires at "
          "least one MetaGraph, but the SavedModel contains none.")
    if len(meta_graphs) != 1:
      tag_sets = [mg.meta_info_def.tags for mg in meta_graphs]
      raise ValueError(
          "Importing a SavedModel with `tf.saved_model.load` requires a "
          "`tags=` argument if there is more than one MetaGraph. Got "
          f"`tags=None`, but there are {len(meta_graphs)} "
          f"MetaGraphs in the SavedModel with tag sets: {tag_sets}. Pass a "
          "`tags=` argument to load this SavedModel.")
    return meta_graphs[0]

  def load_graph(self, returns, meta_graph_def):
    """Called from wrap_function to import `meta_graph_def`.

    Args:
      returns: Single-element list used as an out-parameter. `returns[0]` is
        set to the `Saver` produced by the import, which may be `None`.
      meta_graph_def: The `MetaGraphDef` to import into the current graph.
    """
    # pylint: disable=protected-access
    saver, _ = tf_saver._import_meta_graph_with_return_elements(
        meta_graph_def)
    # pylint: enable=protected-access
    returns[0] = saver

  def _extract_saver_restore(self, wrapped, saver):
    """Prunes `wrapped` down to a function that runs the saver's restore op.

    Args:
      wrapped: The `WrappedFunction` holding the imported graph.
      saver: The `Saver` created by `load_graph`, or `None`.

    Returns:
      `None` if `saver` is `None`. Otherwise, a pruned function that is fed
      the checkpoint filename tensor and fetches
      `[filename_tensor, restore_op]`.
    """
    if saver is None:
      return None
    saver_def = saver.saver_def
    filename_tensor = wrapped.graph.as_graph_element(
        saver_def.filename_tensor_name)
    # We both feed and fetch filename_tensor so we have an operation to use to
    # feed into variable initializers (only relevant for v1 graph building).
    return wrapped.prune(
        feeds=[filename_tensor],
        fetches=[filename_tensor,
                 wrapped.graph.as_graph_element(saver_def.restore_op_name)])

  def restore_variables(self, wrapped, restore_from_saver):
    """Restores variables from the checkpoint.

    Runs `restore_from_saver` on `self._variables_path`. When not executing
    eagerly outside functions, it also does the following:
    - adds the resulting op to the "saved_model_initializers" collection;
    - sets it as `_initializer_op` of every variable in the wrapped graph's
      GLOBAL_VARIABLES collection;
    - logs a warning if any of those variables still belong to the wrapped
      graph.

    Args:
      wrapped: The `WrappedFunction` holding the imported graph.
      restore_from_saver: Function returned by `_extract_saver_restore`, or
        `None`. When it is `None`, nothing is restored.
    """
    if restore_from_saver is None:
      return
    initializer, _ = restore_from_saver(
        constant_op.constant(self._variables_path))
    if ops.executing_eagerly_outside_functions():
      return
    # Add the initialization operation to the "saved_model_initializers"
    # collection in case we don't have any lifted variables to attach it to.
    ops.add_to_collection("saved_model_initializers", initializer)
    one_unlifted = False

    for variable in wrapped.graph.get_collection_ref(
        ops.GraphKeys.GLOBAL_VARIABLES):
      if variable.graph is wrapped.graph:
        # Still inside the wrapped function's graph, i.e. it was not lifted.
        one_unlifted = True
      # pylint: disable=protected-access
      variable._initializer_op = initializer
      # pylint: enable=protected-access
    if one_unlifted:
      logging.warning(
          "Some variables could not be lifted out of a loaded function. "
          "Please run "
          "`sess.run(tf.get_collection(\"saved_model_initializers\"))` to "
          "restore these variables.")

  def _extract_signatures(self, wrapped, meta_graph_def):
    """Creates ConcreteFunctions for signatures in `meta_graph_def`.

    Inputs are ordered by input key. Composite inputs are exposed as their
    flattened component tensors, with derived keyword names:
    - SparseTensor inputs become `<name>_indices`, `<name>_values` and
      `<name>_dense_shape`.
    - Other CompositeTensor inputs become `<name>_component_<i>`.

    Args:
      wrapped: The `WrappedFunction` holding the imported graph.
      meta_graph_def: The `MetaGraphDef` whose `signature_def` map is read.

    Returns:
      A dict mapping each signature key to its pruned concrete function.

    Raises:
      lift_to_graph.UnliftableError: If a signature output depends on a
        placeholder that is not one of the signature's inputs. The error
        message is prefixed with an explanation naming the signature.
    """
    signature_functions = {}
    for signature_key, signature_def in meta_graph_def.signature_def.items():
      if signature_def.inputs:
        input_items = sorted(
            signature_def.inputs.items(), key=lambda item: item[0])
        original_input_names, input_specs = zip(*input_items)
      else:
        original_input_names = []
        input_specs = []
      # TODO(b/205015292): Support optional arguments
      feeds = [
          wrap_function._get_element_from_tensor_info(input_spec, wrapped.graph)  # pylint: disable=protected-access
          for input_spec in input_specs
      ]
      input_names = []
      input_tensors = []
      for original_input_name, feed in zip(original_input_names, feeds):
        # SparseTensor is itself a CompositeTensor, so it must be checked
        # first to get its dedicated component names.
        if isinstance(feed, sparse_tensor.SparseTensor):
          # We have to give explicit names for SparseTensor arguments, because
          # these are not present in the TensorInfo.
          input_names.extend([
              f"{original_input_name}_indices",
              f"{original_input_name}_values",
              f"{original_input_name}_dense_shape",
          ])
          input_tensors.extend([feed.indices, feed.values, feed.dense_shape])
        elif isinstance(feed, composite_tensor.CompositeTensor):
          component_tensors = nest.flatten(feed, expand_composites=True)
          input_names.extend(
              f"{original_input_name}_component_{n}"
              for n in range(len(component_tensors)))
          input_tensors.extend(component_tensors)
        else:
          input_names.append(original_input_name)
          input_tensors.append(feed)
      fetches = dict(signature_def.outputs.items())
      try:
        signature_fn = wrapped.prune(feeds=feeds, fetches=fetches)
      except lift_to_graph.UnliftableError as ex:
        # Mutate the exception in place to add a bit more detail, then
        # re-raise it so the original type and traceback are preserved.
        args = ex.args
        # str() guards against a non-string first argument: concatenating it
        # would raise TypeError and hide the real error.
        detail = str(args[0]) if args else ""
        message = (
            ("A SavedModel signature needs an input for each placeholder the "
             "signature's outputs use. An output for signature '{}' depends on "
             "a placeholder which is not an input (i.e. the placeholder is not "
             "fed a value).\n\n").format(signature_key)
            + detail)
        ex.args = (message,) + args[1:]
        raise
      # pylint: disable=protected-access
      signature_fn._arg_keywords = input_names
      signature_fn._func_graph.structured_input_signature = (
          (),
          func_graph.convert_structure_to_signature(
              dict(zip(input_names, input_tensors))))

      # Allowing positional arguments does not create any ambiguity if there's
      # only one.
      signature_fn._num_positional_args = 1 if len(input_names) == 1 else 0
      # pylint: enable=protected-access
      signature_functions[signature_key] = signature_fn
    return signature_functions

  def load(self, tags):
    """Creates an object from the MetaGraph identified by `tags`.

    Args:
      tags: Tag set selecting the MetaGraph, or `None` to use the only one.

    Returns:
      An `AutoTrackable` root with these attributes:
      - `initializer`, `asset_paths`, `signatures`, `variables`,
        `tensorflow_version`, `tensorflow_git_version`, `graph` and `prune`;
      - `restore`, but only when the SavedModel has a saver.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    load_shared_name_suffix = f"_load_{ops.uid()}"
    functions = function_deserialization.load_function_def_library(
        meta_graph_def.graph_def.library,
        load_shared_name_suffix=load_shared_name_suffix)
    # Replace existing functions in the MetaGraphDef with renamed functions so
    # we don't have duplicates or name collisions.
    meta_graph_def.graph_def.library.Clear()
    for function in functions.values():
      meta_graph_def.graph_def.library.function.add().CopyFrom(
          function.function_def)
    # We've renamed functions and shared names. We need the same operation on
    # the GraphDef itself for consistency.
    for node_def in meta_graph_def.graph_def.node:
      function_deserialization.fix_node_def(node_def, functions,
                                            load_shared_name_suffix)

    # `load_graph` reports the Saver through this out-parameter.
    load_graph_returns = [None]
    wrapped = wrap_function.wrap_function(
        functools.partial(self.load_graph, load_graph_returns, meta_graph_def),
        signature=[])
    saver, = load_graph_returns
    restore_from_saver = self._extract_saver_restore(wrapped, saver)
    self.restore_variables(wrapped, restore_from_saver)
    with wrapped.graph.as_default():
      init_op = loader_impl.get_init_op(
          meta_graph_def) or monitored_session.Scaffold.default_local_init_op()
      # Add a dummy Tensor we know we can fetch to add control dependencies to.
      init_anchor = constant_op.constant(0., name="dummy_fetch")

    root = autotrackable.AutoTrackable()
    if restore_from_saver is not None:
      root.restore = (
          lambda path: restore_from_saver(constant_op.constant(path)))
    asset_feed_tensors = []
    asset_paths = []
    for tensor_name, value in loader_impl.get_asset_tensors(
        self._export_dir, meta_graph_def).items():
      asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
      asset_paths.append(asset.Asset(value))
    init_fn = wrapped.prune(
        feeds=asset_feed_tensors,
        fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)])
    initializer = _Initializer(init_fn, asset_paths)
    # pylint: disable=protected-access
    local_init_op, _ = initializer._initialize()
    # pylint: enable=protected-access
    with ops.init_scope():
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, local_init_op)
        for variable in wrapped.graph.get_collection_ref(
            ops.GraphKeys.LOCAL_VARIABLES):
          # pylint: disable=protected-access
          variable._initializer_op = local_init_op
          # pylint: enable=protected-access
    root.initializer = initializer
    root.asset_paths = asset_paths
    signature_functions = self._extract_signatures(wrapped, meta_graph_def)

    root.signatures = signature_serialization.create_signature_map(
        signature_functions)
    root.variables = list(wrapped.graph.variables)
    root.tensorflow_version = (
        meta_graph_def.meta_info_def.tensorflow_version)
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
    root.graph = wrapped.graph
    root.prune = wrapped.prune
    return root


def load(export_dir, tags):
  """Load a v1-style SavedModel as an object.

  Increments the read-API metric before loading. The read metric
  (write_version "1") is incremented only after a successful load.

  Args:
    export_dir: Directory containing the SavedModel.
    tags: Tag set selecting the MetaGraph, or `None` when the SavedModel has
      exactly one MetaGraph.

  Returns:
    The root object built by `_EagerSavedModelLoader.load`.
  """
  metrics.IncrementReadApi(_LOAD_V1_V2_LABEL)
  loader = _EagerSavedModelLoader(export_dir)
  result = loader.load(tags=tags)
  metrics.IncrementRead(write_version="1")
  return result