# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SavedModel builder implementation."""

import functools
import os

from google.protobuf.any_pb2 import Any

from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export

# API label for SavedModel metrics.
_SAVE_BUILDER_LABEL = "save_v1_builder"


# Base class for the SavedModelBuilder that is only used by Tensorflow
# internally. Please use tf.compat.v1.saved_model.SavedModelBuilder instead.
@tf_export("__internal__.saved_model.SavedModelBuilder", v1=[])
class _SavedModelBuilder(object):
  """Builds the `SavedModel` protocol buffer and saves variables and assets.

  The `SavedModelBuilder` class provides the functionality to build a
  `SavedModel` protocol buffer. Specifically, this allows multiple meta
  graphs to be saved as part of a single language-neutral `SavedModel`,
  while sharing variables and assets.

  To build a SavedModel, the first meta graph must be saved with variables.
  Subsequent meta graphs will simply be saved with their graph definitions. If
  assets need to be saved and written or copied to disk, they can be provided
  when the meta graph def is added. If multiple meta graph defs are associated
  with an asset of the same name, only the first version is retained.

  Each meta graph added to the SavedModel must be annotated with tags. The tags
  provide a means to identify the specific meta graph to load and restore, along
  with the shared set of variables and assets.

  Typical usage for the `SavedModelBuilder`:

  ```python
  ...
  builder = tf.compat.v1.saved_model.Builder(export_dir)

  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    ...
    builder.add_meta_graph_and_variables(sess,
                                    ["foo-tag"],
                                    signature_def_map=foo_signatures,
                                    assets_list=foo_assets)
  ...

  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    ...
    builder.add_meta_graph(["bar-tag", "baz-tag"])
  ...

  builder.save()
  ```

  Note: This function will only be available through the v1 compatibility
  library as tf.compat.v1.saved_model.builder.SavedModelBuilder or
  tf.compat.v1.saved_model.Builder. Tensorflow 2.0 will introduce a new
  object-based method of creating SavedModels.
  """

  def __init__(self, export_dir):
    """Creates the builder and prepares the export directory.

    Args:
      export_dir: Directory the SavedModel will be written to. It is created if
        it does not exist; if it exists it must be empty.

    Raises:
      AssertionError: If `export_dir` already exists and is not empty.
    """
    self._saved_model = saved_model_pb2.SavedModel()
    self._saved_model.saved_model_schema_version = (
        constants.SAVED_MODEL_SCHEMA_VERSION)

    self._export_dir = export_dir
    if file_io.file_exists(export_dir):
      if file_io.list_directory(export_dir):
        raise AssertionError(
            f"Export directory {export_dir} already exists, and isn't empty. "
            "Please choose a different export directory, or delete all the "
            "contents of the specified directory.")
    else:
      file_io.recursive_create_dir(self._export_dir)

    # Boolean to track whether variables and assets corresponding to the
    # SavedModel have been saved. Specifically, the first meta graph to be added
    # MUST use the add_meta_graph_and_variables() API. Subsequent add operations
    # on the SavedModel MUST use the add_meta_graph() API which does not save
    # weights.
    self._has_saved_variables = False

  def _save_and_write_assets(self, meta_graph_def, assets_list=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Args:
      meta_graph_def: The meta graph def to which the assets will be added.
      assets_list: The list where the asset paths are setup.
    """
    # Creates a function that adds assets into the meta graph def.
    write_fn = functools.partial(_add_asset_to_metagraph, meta_graph_def)
    asset_filename_map = _maybe_save_assets(write_fn, assets_list)

    # Return if there are no assets to write.
    if not asset_filename_map:
      tf_logging.info("No assets to write.")
      return

    # Copy assets from source path to destination path.
    copy_assets_to_destination_dir(asset_filename_map, self._export_dir)

  def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map):
    """Tags the meta graph def and adds it to the SavedModel.

    Tags the meta graph def with the supplied tags, adds signature defs to it if
    provided and appends the meta graph def to the SavedModel proto.

    Args:
      meta_graph_def: The meta graph def to add to the SavedModel.
      tags: The set of tags to annotate the meta graph def with.
      signature_def_map: The map of signature defs to be added to the meta graph
        def.
    """
    meta_graph_def.meta_info_def.tags.extend(tags)

    if signature_def_map is not None:
      for key, signature_def in signature_def_map.items():
        meta_graph_def.signature_def[key].CopyFrom(signature_def)

    # The SavedModel keeps its own copy, so later changes to `meta_graph_def`
    # do not affect what gets written.
    proto_meta_graph_def = self._saved_model.meta_graphs.add()
    proto_meta_graph_def.CopyFrom(meta_graph_def)

  def _validate_tensor_info(self, tensor_info):
    """Validates the `TensorInfo` proto.

    Checks if the `encoding` (`name` or `coo_sparse` or `type_spec`) and
    `dtype` fields exist and are non-empty.

    Args:
      tensor_info: `TensorInfo` protocol buffer to validate.

    Raises:
      AssertionError: If the `encoding` or `dtype` fields of the supplied
        `TensorInfo` proto are not populated.
    """
    if tensor_info is None:
      raise AssertionError(
          "All TensorInfo protos used in the SignatureDefs must have the name "
          "and dtype fields set.")
    encoding = tensor_info.WhichOneof("encoding")
    if encoding is None:
      # TODO(soergel) validate each of the fields of coo_sparse
      raise AssertionError(
          f"Invalid `tensor_info`: {tensor_info}. All TensorInfo protos used "
          "in the SignatureDefs must have one of the 'encoding' fields (e.g., "
          "name or coo_sparse) set.")
    if encoding == "composite_tensor":
      # A composite tensor is validated through its components; its own dtype
      # field is not checked.
      for component in tensor_info.composite_tensor.components:
        self._validate_tensor_info(component)
    elif tensor_info.dtype == types_pb2.DT_INVALID:
      raise AssertionError(
          f"Invalid `tensor_info`: {tensor_info}. All TensorInfo protos used in"
          " the SignatureDefs must have the dtype field set.")

  def _validate_signature_def_map(self, signature_def_map):
    """Validates the `SignatureDef` entries in the signature def map.

    Validation of entries in the signature def map includes ensuring that the
    `name` and `dtype` fields of the TensorInfo protos of the `inputs` and
    `outputs` of each `SignatureDef` are populated. Also ensures that reserved
    SignatureDef keys for the initialization and train ops are not used.

    Args:
      signature_def_map: The map of signature defs to be validated.

    Raises:
      AssertionError: If a TensorInfo is not valid.
      KeyError: If a reserved signature key is used in the map.
    """
    for signature_def_key in signature_def_map:
      signature_def = signature_def_map[signature_def_key]
      inputs = signature_def.inputs
      outputs = signature_def.outputs
      for inputs_key in inputs:
        self._validate_tensor_info(inputs[inputs_key])
      for outputs_key in outputs:
        self._validate_tensor_info(outputs[outputs_key])
    if constants.INIT_OP_SIGNATURE_KEY in signature_def_map:
      raise KeyError(
          f"SignatureDef map key \"{constants.INIT_OP_SIGNATURE_KEY}\" is "
          "reserved for initialization. Please use a different key.")
    if constants.TRAIN_OP_SIGNATURE_KEY in signature_def_map:
      raise KeyError(
          f"SignatureDef map key \"{constants.TRAIN_OP_SIGNATURE_KEY}\" is "
          "reserved for the train op. Please use a different key.")

  def _maybe_create_saver(self, saver=None):
    """Creates a sharded saver if one does not already exist."""
    if not saver:
      # Initialize a saver to generate a sharded output for all saveables in the
      # current scope.
      saver = tf_saver.Saver(
          variables._all_saveable_objects(),  # pylint: disable=protected-access
          sharded=True,
          write_version=saver_pb2.SaverDef.V2,
          allow_empty=True)
    return saver

  def add_meta_graph(self,
                     tags,
                     signature_def_map=None,
                     assets_list=None,
                     clear_devices=False,
                     init_op=None,
                     train_op=None,
                     saver=None):
    """Adds the current meta graph to the SavedModel.

    Creates a Saver in the current scope and uses the Saver to export the meta
    graph def. Invoking this API requires the `add_meta_graph_and_variables()`
    API to have been invoked before.

    Args:
      tags: The set of tags to annotate the meta graph def with.
      signature_def_map: The map of signature defs to be added to the meta graph
        def.
      assets_list: Assets to be saved with SavedModel. Note
        that this list should be a subset of the assets saved as part of
        the first meta graph in the SavedModel.
      clear_devices: Set to true if the device info on the default graph should
        be cleared.
      init_op: Op or group of ops to execute when the graph is loaded. Note
        that when the init_op is specified it is run after the restore op at
        load-time.
      train_op: Op or group of ops that trains the model when run. This will
        not be run automatically when the graph is loaded, instead saved in
        a SignatureDef accessible through the exported MetaGraph.
      saver: An instance of tf.compat.v1.train.Saver that will be used to export
        the metagraph. If None, a sharded Saver that restores all variables will
        be used.

    Raises:
      AssertionError: If the variables for the SavedModel have not been saved
        yet, or if the graph already contains one or more legacy init ops.
    """
    if not self._has_saved_variables:
      raise AssertionError(
          "Graph state including variables and assets has not been saved yet. "
          "Please invoke `add_meta_graph_and_variables()` first.")

    # Work on a copy: the reserved init/train signatures are inserted below, and
    # writing them into the caller's dict would make a later call that reuses
    # the same map fail the reserved-key check during validation.
    signature_def_map = dict(signature_def_map) if signature_def_map else {}

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    self._validate_signature_def_map(signature_def_map)

    # Create a SignatureDef pointing to the graph initialization op, which will
    # be added to the MetaGraphDef.
    _add_op_to_signature_def_map(signature_def_map, init_op,
                                 constants.INIT_OP_SIGNATURE_KEY)
    _add_op_to_signature_def_map(signature_def_map, train_op,
                                 constants.TRAIN_OP_SIGNATURE_KEY)

    saver = self._maybe_create_saver(saver)

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    # Note: this method takes no `strip_default_attrs` argument; default-valued
    # attributes are always stripped here.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=True)

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(meta_graph_def, assets_list)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

  def add_meta_graph_and_variables(self,
                                   sess,
                                   tags,
                                   signature_def_map=None,
                                   assets_list=None,
                                   clear_devices=False,
                                   init_op=None,
                                   train_op=None,
                                   strip_default_attrs=False,
                                   saver=None):
    # pylint: disable=line-too-long
    """Adds the current meta graph to the SavedModel and saves variables.

    Creates a Saver to save the variables from the provided session. Exports the
    corresponding meta graph def. This function assumes that the variables to be
    saved have been initialized. For a given `SavedModelBuilder`, this API must
    be called exactly once and for the first meta graph to save. For subsequent
    meta graph defs to be added, the `add_meta_graph()` API must be used.

    Args:
      sess: The TensorFlow session from which to save the meta graph and
        variables.
      tags: The set of tags with which to save the meta graph.
      signature_def_map: The map of signature def map to add to the meta graph
        def.
      assets_list: Assets to be saved with SavedModel.
      clear_devices: Set to true if the device info on the default graph should
        be cleared.
      init_op: Op or group of ops to execute when the graph is loaded. Note
        that when the init_op is specified it is run after the restore op at
        load-time.
      train_op: Op or group of ops that trains the model when run. This will
        not be run automatically when the graph is loaded, instead saved in
        a SignatureDef accessible through the exported MetaGraph.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      saver: An instance of tf.compat.v1.train.Saver that will be used to export the
        metagraph and save variables. If None, a sharded Saver that restores
        all variables will be used.

    Raises:
      AssertionError: If variables have already been saved through this
        builder.
    """
    # pylint: enable=line-too-long
    if self._has_saved_variables:
      raise AssertionError("Graph state including variables and assets has "
                           "already been saved. Please invoke "
                           "`add_meta_graph()` instead.")

    # Work on a copy: the reserved init/train signatures are inserted below, and
    # writing them into the caller's dict would make a later call that reuses
    # the same map fail the reserved-key check during validation.
    signature_def_map = dict(signature_def_map) if signature_def_map else {}

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    self._validate_signature_def_map(signature_def_map)

    # Create a SignatureDef pointing to the graph initialization op, which will
    # be added to the MetaGraphDef.
    _add_op_to_signature_def_map(signature_def_map, init_op,
                                 constants.INIT_OP_SIGNATURE_KEY)
    _add_op_to_signature_def_map(signature_def_map, train_op,
                                 constants.TRAIN_OP_SIGNATURE_KEY)

    saved_model_utils.get_or_create_variables_dir(self._export_dir)
    variables_path = saved_model_utils.get_variables_path(self._export_dir)

    saver = self._maybe_create_saver(saver)

    # Save the variables. Also, disable writing the checkpoint state proto. The
    # file is not used during SavedModel loading. In addition, since a
    # SavedModel can be copied or moved, this avoids the checkpoint state to
    # become outdated.
    saver.save(sess, variables_path, write_meta_graph=False, write_state=False)

    # Export the meta graph def.

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(meta_graph_def, assets_list)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True

  def save(self, as_text=False):
    """Writes a `SavedModel` protocol buffer to disk.

    The function writes the SavedModel protocol buffer to the export directory
    in a serialized format.

    Args:
      as_text: Writes the SavedModel protocol buffer in text format to
        disk. Protocol buffers in text format are useful for debugging, but
        parsing fails when it encounters an unknown field and so is not forward
        compatible. This means changes to TensorFlow may prevent deployment of
        new text format SavedModels to existing serving binaries. Do not deploy
        `as_text` SavedModels to production.

    Returns:
      The path to which the SavedModel protocol buffer was written.
    """
    metrics.IncrementWriteApi(_SAVE_BUILDER_LABEL)
    # The directory is re-checked here and recreated if it is missing by the
    # time `save()` is called.
    if not file_io.file_exists(self._export_dir):
      file_io.recursive_create_dir(self._export_dir)

    if as_text:
      path = file_io.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
      file_io.write_string_to_file(path, str(self._saved_model))
    else:
      path = file_io.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
      file_io.write_string_to_file(
          path, self._saved_model.SerializeToString(deterministic=True))
    tf_logging.info("SavedModel written to: %s", compat.as_text(path))
    metrics.IncrementWrite(write_version="1")
    return path


@tf_export(v1=["saved_model.Builder", "saved_model.builder.SavedModelBuilder"])  # pylint: disable=missing-docstring
class SavedModelBuilder(_SavedModelBuilder):
  # The v1 API names the assets argument `assets_collection`; reuse the base
  # class documentation with that name substituted.
  __doc__ = _SavedModelBuilder.__doc__.replace("assets_list",
                                               "assets_collection")

  def __init__(self, export_dir):
    super(SavedModelBuilder, self).__init__(export_dir=export_dir)

  def _add_collections(self, assets_collection, main_op, train_op):
    """Add asset and op collections to be saved."""
    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(assets_collection)

    self._maybe_add_main_op(main_op)

    self._add_train_op(train_op)

  def _save_and_write_assets(self, assets_collection_to_add=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Unlike the base-class method of the same name, this variant takes no meta
    graph def: assets are recorded in a graph collection instead.

    Args:
      assets_collection_to_add: The collection where the asset paths are setup.
    """
    # Add assets to the collection with key `saved_model.ASSETS_KEY`, in the
    # graph.
    asset_filename_map = _maybe_save_assets(_add_asset_to_collection,
                                            assets_collection_to_add)

    # Return if there are no assets to write.
    if not asset_filename_map:
      tf_logging.info("No assets to write.")
      return

    # Copy assets from source path to destination path.
    copy_assets_to_destination_dir(asset_filename_map, self._export_dir)

  def _maybe_add_main_op(self, main_op):
    """Adds main op to the SavedModel.

    Args:
      main_op: Main op to run as part of graph initialization. If None, no main
        op will be added to the graph.

    Raises:
      TypeError: If the main op is provided but is not of type `Operation`.
      ValueError: if the Graph already contains an init op.
    """
    if main_op is None:
      return

    if not isinstance(main_op, ops.Operation):
      raise TypeError(f"Expected {main_op} to be an Operation but got type "
                      f"{type(main_op)} instead.")

    # Validate that no other init ops have been added to this graph already.
    # We check main_op and legacy_init_op for thoroughness and explicitness.
    for init_op_key in (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY):
      if ops.get_collection(init_op_key):
        raise ValueError(
            "Graph already contains one or more main ops under the "
            f"collection {init_op_key}.")

    ops.add_to_collection(constants.MAIN_OP_KEY, main_op)

  def _add_train_op(self, train_op):
    """Add train op to the SavedModel.

    Note that this functionality is in development, and liable to be
    moved elsewhere.

    Args:
      train_op: Op or group of ops that are used for training. These are stored
        as a collection with key TRAIN_OP_KEY, but not executed. If None,
        nothing is added.

    Raises:
      TypeError: If `train_op` is neither a `Tensor` nor an `Operation`.
    """
    if train_op is not None:
      if (not isinstance(train_op, ops.Tensor) and
          not isinstance(train_op, ops.Operation)):
        raise TypeError(f"`train_op` {train_op} needs to be a Tensor or Op.")
      ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)

  @deprecated_args(None,
                   "Pass your op to the equivalent parameter main_op instead.",
                   "legacy_init_op")
  def add_meta_graph(self,
                     tags,
                     signature_def_map=None,
                     assets_collection=None,
                     legacy_init_op=None,
                     clear_devices=False,
                     main_op=None,
                     strip_default_attrs=False,
                     saver=None):
    if not self._has_saved_variables:
      raise AssertionError(
          "Graph state including variables and assets has not been saved yet. "
          "Please invoke `add_meta_graph_and_variables()` first.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    signature_def_map = signature_def_map or {}
    self._validate_signature_def_map(signature_def_map)

    # legacy_init_op is deprecated, and going away in TF 2.0.
    # Re-mapping to main_op, as treatment is identical regardless.
    main_op = main_op if main_op is not None else legacy_init_op

    # Add assets and ops
    self._add_collections(assets_collection, main_op, None)

    saver = self._maybe_create_saver(saver)

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

  @deprecated_args(None,
                   "Pass your op to the equivalent parameter main_op instead.",
                   "legacy_init_op")
  def add_meta_graph_and_variables(self,
                                   sess,
                                   tags,
                                   signature_def_map=None,
                                   assets_collection=None,
                                   legacy_init_op=None,
                                   clear_devices=False,
                                   main_op=None,
                                   strip_default_attrs=False,
                                   saver=None):
    if self._has_saved_variables:
      raise AssertionError("Graph state including variables and assets has "
                           "already been saved. Please invoke "
                           "`add_meta_graph()` instead.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    signature_def_map = signature_def_map or {}
    self._validate_signature_def_map(signature_def_map)

    # legacy_init_op is deprecated, and going away in TF 2.0.
    # Re-mapping to main_op, as treatment is identical regardless. An explicit
    # None test (matching add_meta_graph) avoids truth-testing a graph object.
    main_op = main_op if main_op is not None else legacy_init_op

    # Add assets and ops
    self._add_collections(assets_collection, main_op, None)

    saved_model_utils.get_or_create_variables_dir(self._export_dir)
    variables_path = saved_model_utils.get_variables_path(self._export_dir)

    saver = self._maybe_create_saver(saver)

    # Save the variables. Also, disable writing the checkpoint state proto. The
    # file is not used during SavedModel loading. In addition, since a
    # SavedModel can be copied or moved, this avoids the checkpoint state to
    # become outdated.
    saver.save(sess, variables_path, write_meta_graph=False, write_state=False)

    # Export the meta graph def.

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True

  # Reuse the base-class docstrings, renaming the assets argument to match the
  # v1 signatures above.
  add_meta_graph.__doc__ = _SavedModelBuilder.add_meta_graph.__doc__.replace(
      "assets_list", "assets_collection")
  add_meta_graph_and_variables.__doc__ = \
      _SavedModelBuilder.add_meta_graph_and_variables.__doc__.replace(
          "assets_list", "assets_collection")


def _maybe_save_assets(write_fn, assets_to_add=None):
  """Saves assets to the meta graph.

  Args:
    write_fn: A function callback that writes assets into meta graph.
    assets_to_add: The list where the asset paths are setup.

  Returns:
    A dict of asset basenames for saving to the original full path to the asset.

  Raises:
    ValueError: Indicating an invalid filepath tensor.
  """
  # Map of target file names to original filenames
  asset_filename_map = {}

  if assets_to_add is None:
    tf_logging.info("No assets to save.")
    return asset_filename_map

  # Iterate over the supplied assets, build the `AssetFile` proto and add them
  # to the meta graph.
  for asset_tensor in assets_to_add:
    asset_source_filepath = _asset_path_from_tensor(asset_tensor)
    if not asset_source_filepath:
      raise ValueError(f"Asset filepath tensor {asset_tensor} is invalid.")

    asset_filename = get_asset_filename_to_add(
        asset_source_filepath, asset_filename_map)

    # Call the passed-in function that builds AssetFileDef proto and adds it
    # to either the collection or asset_file_def field of the meta graph.
    # Note that this should be done even when the file is a duplicate of an
    # already-added file, as the tensor reference should still exist.
    write_fn(asset_filename, asset_tensor)

    # In the cases where we are adding a duplicate, this will result in the
    # last of the filepaths being the one used for copying the file to the
    # SavedModel. Since the files in question are the same, it doesn't matter
    # either way.
    asset_filename_map[asset_filename] = asset_source_filepath

  tf_logging.info("Assets added to graph.")
  return asset_filename_map


def get_asset_filename_to_add(asset_filepath, asset_filename_map):
  """Get a unique basename to add to the SavedModel if this file is unseen.

  Assets come from users as full paths, and we save them out to the
  SavedModel as basenames. In some cases, the basenames collide. Here,
  we dedupe asset basenames by first checking if the file is the same,
  and, if different, generate and return an index-suffixed basename
  that can be used to add the asset to the SavedModel.

  Args:
    asset_filepath: the full path to the asset that is being saved
    asset_filename_map: a dict of filenames used for saving the asset in
      the SavedModel to full paths from which the filenames were derived.

  Returns:
    Uniquified filename string if the file is not a duplicate, or the original
    filename if the file has already been seen and saved.
  """
  asset_filename = os.path.basename(asset_filepath)

  if asset_filename not in asset_filename_map:
    # This is an unseen asset. Safe to add.
    return asset_filename

  other_asset_filepath = asset_filename_map[asset_filename]
  if other_asset_filepath == asset_filepath:
    # This is the same file, stored twice in the list. No need
    # to make unique.
    return asset_filename

  # Else, asset_filename is in the map, and the filepath is different. Dedupe.
  if not file_io.filecmp(asset_filepath, other_asset_filepath):
    # Files are different; dedupe filenames.
    return _get_unique_asset_filename(asset_filename, asset_filename_map)

  # Files are the same; don't make unique.
  return asset_filename


def _get_unique_asset_filename(asset_filename, asset_filename_map):
  """Returns a name derived from `asset_filename` that is not a key in the map.

  Candidates are built as `<asset_filename>_<i>` for i = 1, 2, ... until one is
  found that is absent from `asset_filename_map`.

  Args:
    asset_filename: The colliding basename.
    asset_filename_map: Dict whose keys are the names already taken.

  Returns:
    `asset_filename` itself if it is not in the map; otherwise the first free
    suffixed candidate, which is produced via `compat.as_bytes`.
  """
  i = 1
  unique_filename = asset_filename
  while unique_filename in asset_filename_map:
    unique_filename = compat.as_bytes("_").join(
        [compat.as_bytes(asset_filename), compat.as_bytes(str(i))])
    i += 1
  return unique_filename


def _asset_path_from_tensor(path_tensor):
  """Returns the filepath value stored in constant `path_tensor`.

  Args:
    path_tensor: Tensor of a file-path.

  Returns:
    The string value i.e. path of the tensor, if valid.

  Raises:
    TypeError if tensor does not match expected op type, dtype or value.
  """
  if not isinstance(path_tensor, ops.Tensor):
    raise TypeError(f"Asset path tensor {path_tensor} must be a Tensor.")
  if path_tensor.op.type != "Const":
    raise TypeError(f"Asset path tensor {path_tensor} must be of type constant. "
                    f"Has type {path_tensor.op.type} instead.")
  if path_tensor.dtype != dtypes.string:
    raise TypeError(f"Asset path tensor {path_tensor} must be of dtype string. "
                    f"Has type {path_tensor.dtype} instead.")
  str_values = path_tensor.op.get_attr("value").string_val
  if len(str_values) != 1:
    raise TypeError(f"Asset path tensor {path_tensor} must be a scalar.")
  return str_values[0]


def _add_asset_to_metagraph(meta_graph_def, asset_filename, asset_tensor):
  """Builds an asset proto and adds it to the meta graph def.

  Args:
    meta_graph_def: The meta graph def to which the asset will be added.
    asset_filename: The filename of the asset to be added.
    asset_tensor: The asset tensor used to populate the tensor info of the asset
      proto.
  """
  asset_proto = meta_graph_def.asset_file_def.add()
  asset_proto.filename = asset_filename
  asset_proto.tensor_info.name = asset_tensor.name


def copy_assets_to_destination_dir(asset_filename_map, destination_dir):
  """Copy all assets from source path to destination path.

  Args:
    asset_filename_map: Dict mapping the asset name to use inside the
      SavedModel to the source path to copy from.
    destination_dir: Export directory; files are copied into its assets
      directory, which is created if needed.
  """
  assets_destination_dir = saved_model_utils.get_or_create_assets_dir(
      destination_dir)

  # Copy each asset from source path to destination path.
  for asset_basename, asset_source_filepath in asset_filename_map.items():
    asset_destination_filepath = file_io.join(
        compat.as_bytes(assets_destination_dir),
        compat.as_bytes(asset_basename))

    # Only copy the asset file to the destination if it does not already
    # exist. This is to ensure that an asset with the same name defined as
    # part of multiple graphs is only copied the first time.
    if not file_io.file_exists(asset_destination_filepath):
      file_io.copy(asset_source_filepath, asset_destination_filepath)

  tf_logging.info("Assets written to: %s",
                  compat.as_text(assets_destination_dir))


def _add_asset_to_collection(asset_filename, asset_tensor):
  """Builds an asset proto and adds it to the asset collection of the graph.

  Args:
    asset_filename: The filename of the asset to be added.
    asset_tensor: The asset tensor used to populate the tensor info of the
      asset proto.
  """
  asset_proto = meta_graph_pb2.AssetFileDef()
  asset_proto.filename = asset_filename
  asset_proto.tensor_info.name = asset_tensor.name

  # The proto is packed into an `Any` before being stored in the collection.
  asset_any_proto = Any()
  asset_any_proto.Pack(asset_proto)
  ops.add_to_collection(constants.ASSETS_KEY, asset_any_proto)


def _add_op_to_signature_def_map(signature_def_map, op, key):
  """Stores a SignatureDef for `op` under `key`, if `op` is given.

  Args:
    signature_def_map: Dict to update in place.
    op: The op to build a signature for. If None, the map is left unchanged.
    key: The signature key to store the op's SignatureDef under.
  """
  if op is not None:
    signature_def_map[key] = signature_def_utils.op_signature_def(op, key)