# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""MetaGraph and related functions."""
import copy
from packaging import version as packaging_version  # pylint: disable=g-bad-import-order
import os.path
import re

from google.protobuf.any_pb2 import Any
from google.protobuf import text_format

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.eager import context
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import versions
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat


# Prefix to be added to unbound input names so they are easily identifiable.
_UNBOUND_INPUT_PREFIX = "$unbound_inputs_"

# List of collections that didn't register proto functions; as a result, in a
# previously exported meta_graph the items are of a different data type.
_COMPAT_COLLECTION_LIST = [ops.GraphKeys.LOCAL_VARIABLES,
                           ops.GraphKeys.MODEL_VARIABLES,
                           ops.GraphKeys.METRIC_VARIABLES]


def _node_def(from_node_def, export_scope, unbound_inputs,
              clear_devices=False):
  """Creates a `NodeDef` proto with export_scope stripped.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: An array of unbound input names if they exist.
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)
  for i, v in enumerate(node_def.input):
    if (export_scope and
        not node_def.input[i].lstrip("^").startswith(export_scope)):
      # Adds "$unbound_inputs_" prefix to the unbound name so they are easily
      # identifiable.
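      # For example (illustrative names): an out-of-scope input "foo/w" would
      # become "$unbound_inputs_foo/w", and a control input "^foo/init" would
      # become "^$unbound_inputs_foo/init".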
      node_def.input[i] = re.sub(r"([\^]|^)(.*)",
                                 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
                                 compat.as_str(v))
      unbound_inputs.append(node_def.input[i])
    else:
      node_def.input[i] = ops.strip_name_scope(v, export_scope)
  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))
  for k, v in from_node_def.attr.items():
    if k == "_class":
      new_s = [compat.as_bytes(
          ops.strip_name_scope(s, export_scope)) for s in v.list.s
               if not export_scope or
               compat.as_str(s).split("@")[1].startswith(export_scope)]
      node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=new_s)))
    elif node_def.op in ("Enter", "RefEnter") and k == "frame_name":
      if not export_scope or compat.as_str(v.s).startswith(export_scope):
        new_s = compat.as_bytes(ops.strip_name_scope(v.s, export_scope))
        node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=new_s))
    else:
      node_def.attr[k].CopyFrom(v)

  if clear_devices:
    node_def.device = ""

  return node_def


def _read_file(filename):
  """Reads a file containing `GraphDef` and returns the protocol buffer.

  Args:
    filename: `graph_def` filename including the path.

  Returns:
    A `GraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  graph_def = graph_pb2.GraphDef()
  if not file_io.file_exists(filename):
    raise IOError(f"File {filename} does not exist.")
  # First try to read it as a binary file.
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()
  try:
    graph_def.ParseFromString(file_content)
    return graph_def
  except Exception:  # pylint: disable=broad-except
    pass

  # Next try to read it as a text file.
  try:
    text_format.Merge(file_content, graph_def)
  except text_format.ParseError as e:
    raise IOError(f"Cannot parse file {filename}: {str(e)}.")

  return graph_def


def ops_used_by_graph_def(graph_def):
  """Collect the list of ops used by a graph.

  Does not validate that the ops are all registered.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    A list of strings, each naming an op used by the graph.
  """
  # Map function names to definitions.
  name_to_function = {}
  for fun in graph_def.library.function:
    name_to_function[fun.signature.name] = fun

  # Collect the list of op names. Since functions can reference functions, we
  # need a recursive traversal.
  used_ops = set()  # Includes both primitive ops and functions.
  functions_to_process = []  # A subset of used_ops.

  def mark_op_as_used(op):
    if op not in used_ops and op in name_to_function:
      functions_to_process.append(name_to_function[op])
    used_ops.add(op)

  def process_node(node):
    mark_op_as_used(node.op)
    if node.op in ["PartitionedCall", "StatefulPartitionedCall"]:
      mark_op_as_used(node.attr["f"].func.name)

  for node in graph_def.node:
    process_node(node)
  while functions_to_process:
    fun = functions_to_process.pop()
    for node in fun.node_def:
      process_node(node)

  return [op for op in used_ops if op not in name_to_function]


def stripped_op_list_for_graph(graph_def):
  """Collect the stripped OpDefs for ops used by a graph.

  This function computes the `stripped_op_list` field of `MetaGraphDef` and
  similar protos.
  The result can be communicated from the producer to the consumer, which can
  then use the C++ function `RemoveNewDefaultAttrsFromGraphDef` to improve
  forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of ops used by the graph.
  """
  # This is similar to StrippedOpListForGraph in C++, but unlike its
  # C++ counterpart, this version does not require all ops to be registered.
  # This is done to support Prelu fusion in tfjs.
  used_ops = ops_used_by_graph_def(graph_def)
  op_defs = []
  for op in sorted(used_ops):
    op_def = op_def_registry.get(op)
    if op_def is not None:
      op_defs.append(op_def)
  return op_def_pb2.OpList(op=op_defs)


def _get_kind_name(item):
  """Returns the kind name in CollectionDef.

  Args:
    item: A data item.

  Returns:
    The string representation of the kind in CollectionDef.
  """
  if isinstance(item, (str, bytes)):
    kind = "bytes_list"
  elif isinstance(item, int):
    kind = "int64_list"
  elif isinstance(item, float):
    kind = "float_list"
  elif isinstance(item, Any):
    kind = "any_list"
  else:
    kind = "node_list"
  return kind


SAVE_AND_RESTORE_OPS = ["SaveV2",
                        "Save", "SaveSlice",
                        "LegacySave", "LegacySaveSlice",
                        "RestoreV2",
                        "Restore", "RestoreSlice",
                        "LegacyRestore", "LegacyRestoreSlice"]


def _op_name(tensor_name):
  """Extract the Op name from a Tensor name.

  The Op name is everything before a colon, if present,
  not including any ^ prefix denoting a control dependency.

  Args:
    tensor_name: the full name of a Tensor in the graph.
  Returns:
    The name of the Op of which the given Tensor is an output.
  Raises:
    ValueError: if tensor_name is None or empty.
  """
  if not tensor_name:
    raise ValueError(
        f"Tensor name cannot be empty or None. Received: {tensor_name}.")

  # Control dependency inputs start with ^.
  if tensor_name.startswith("^"):
    tensor_name = tensor_name[1:]
  if ":" in tensor_name:
    op_name, _ = tensor_name.split(":")
    return op_name
  return tensor_name


def _get_scope(node_name):
  """Extract the scope name from a node name.

  The scope name is everything before the final slash,
  not including any ^ prefix denoting a control dependency.

  Args:
    node_name: the full name of an Op or a Tensor in the graph.
  Returns:
    The deepest named scope containing the node.
  Raises:
    ValueError: if node_name is None or empty.
  """
  if not node_name:
    raise ValueError(
        f"Node name cannot be empty or None. Received: {node_name}.")

  # Control dependency inputs start with ^.
  if node_name.startswith("^"):
    node_name = node_name[1:]
  if "/" in node_name:
    scope, _ = node_name.rsplit("/", 1)
    return scope

  return ""


def _find_extraneous_saver_nodes(graph_def, saver_def):
  """Identifies any nodes in the graph_def related to unused Savers.

  This approach assumes that each Saver is cleanly isolated in its own name
  scope, so we need only identify the scopes associated with extraneous Savers
  and return all the nodes in those scopes.

  Args:
    graph_def: a GraphDef proto to evaluate.
    saver_def: a SaverDef proto referencing Save/Restore ops to be retained.
  Returns:
    An iterable of node names that may be safely omitted.
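
  For example (illustrative): if `saver_def` references ops under a "save/"
  scope, the returned names would include everything under an unused
  "save_1/" scope (such as "save_1/SaveV2"), while nodes under "save/" are
  retained.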
  """
  # TODO(soergel): confirm that the assumption of scope isolation is valid.
  # If not, we need to walk up the graph from any restore_all nodes, and walk
  # down the graph from any Save/Restore nodes. I drafted that approach too,
  # but it seems unnecessarily complex given the name scope solution.

  # Load the graph DAG in minimal form, without initializing a full Graph
  # object.
  nodes = {
      node_def.name: (set(_op_name(x) for x in node_def.input), node_def.op)
      for node_def in graph_def.node
  }

  retain_scope_save = None
  retain_scope_restore = None
  # It's possible to have no saver if the graph has no Variables
  if saver_def is not None:
    save_op_name = _op_name(saver_def.save_tensor_name)
    restore_op_name = _op_name(saver_def.restore_op_name)

    # The save and restore scopes should always be the same, but if they differ
    # for some reason, we retain them both to be safe.
    retain_scope_restore = _get_scope(restore_op_name) + "/"
    retain_scope_save = _get_scope(save_op_name) + "/"

  all_saver_node_names = set(
      name for name, (_, op) in nodes.items() if op in SAVE_AND_RESTORE_OPS)

  all_saver_scopes = (
      set(_get_scope(x) for x in all_saver_node_names) - all_saver_node_names)
  all_saver_scopes = set(x + "/" for x in all_saver_scopes)

  extraneous_scopes = all_saver_scopes - set([retain_scope_save,
                                              retain_scope_restore])

  extraneous_node_names = set()
  for name, _ in nodes.items():
    for extraneous_scope in extraneous_scopes:
      if name.startswith(extraneous_scope):
        extraneous_node_names.add(name)
        break

  return extraneous_node_names


def _should_include_node(node_or_node_name, export_scope, exclude_nodes):
  """Returns `True` if a node should be included.

  Args:
    node_or_node_name: A node or `string` node name.
    export_scope: `string`. Name scope under which to extract the subgraph. The
      scope name will be stripped from the node definitions for easy import
      later into new name scopes.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      export, or None. Note no sanity-checking is done, so this list must be
      carefully constructed to avoid producing an invalid graph.

  Returns:
    `True` if the node should be included.
  """
  if not isinstance(node_or_node_name, str):
    try:
      node_name = node_or_node_name.name
    except AttributeError:
      # Keep the object that we don't know how to process.
      return True
  else:
    node_name = node_or_node_name

  if exclude_nodes and (node_or_node_name in exclude_nodes
                        or node_name in exclude_nodes):
    return False

  return (node_name.startswith(_UNBOUND_INPUT_PREFIX) or
          (not export_scope or node_name.startswith(export_scope)))


def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None, exclude_nodes=None,
                       override_contents=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
    graph: The `Graph` from which to get collections.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      collection, or None.
    override_contents: An iterable of values to place in the collection,
      ignoring the current values (if set).
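
  Example (illustrative):

    add_collection_def(meta_graph_def, ops.GraphKeys.GLOBAL_VARIABLES,
                       graph=graph)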
  """
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError(
        f"graph must be of type Graph. Received type: {type(graph)}.")

  if not isinstance(key, str) and not isinstance(key, bytes):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  if override_contents:
    collection_list = override_contents
  else:
    collection_list = graph.get_collection(key)

  # Remove nodes that should not be exported from the collection list.
  collection_list = [x for x in collection_list if
                     _should_include_node(x, export_scope, exclude_nodes)]
  if not collection_list:
    return

  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        # Additional type check to make sure the returned proto is indeed
        # what we expect.
        proto = to_proto(x, export_scope=export_scope)
        if proto:
          assert isinstance(proto, proto_type)
          getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        for x in collection_list:
          if not export_scope or x.name.startswith(export_scope):
            getattr(col_def, kind).value.append(
                ops.strip_name_scope(x.name, export_scope))
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend([x for x in collection_list])
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Issue encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef. Note this is a warning "
                    "and probably safe to ignore.\n%s", key, str(e))
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return


def _is_default_attr_value(op_def, attr_name, attr_value):
  """Checks if given attribute matches the default value in the op def."""
  for attr_def in op_def.attr:
    if attr_def.name == attr_name:
      if not attr_def.HasField("default_value"):
        return False
      # c_api.EqualAttrValueWrapper returns an empty string
      # if both arguments represent an equivalent AttrValue instance.
      return not c_api.EqualAttrValueWrapper(
          attr_value.SerializeToString(),
          attr_def.default_value.SerializeToString())
  return False


def strip_graph_default_valued_attrs(meta_graph_def):
  """Strips default valued attributes for node defs in given MetaGraphDef.

  This method also sets `meta_info_def.stripped_default_attrs` in the given
  `MetaGraphDef` proto to True.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer

  Returns:
    None.
  """
  # Map function op names to their function definitions.
  op_name_to_function = {}
  for function_def in meta_graph_def.graph_def.library.function:
    op_name_to_function[function_def.signature.name] = function_def

  def _strip_node_default_valued_attrs(node_def):
    """Removes default valued attributes from a single node def."""
    if node_def.op in op_name_to_function:
      return

    op_def = op_def_registry.get(node_def.op)
    if op_def is None:
      return

    attrs_to_strip = set()
    for attr_name, attr_value in node_def.attr.items():
      if _is_default_attr_value(op_def, attr_name, attr_value):
        attrs_to_strip.add(attr_name)

    for attr in attrs_to_strip:
      del node_def.attr[attr]

  # Process all NodeDef instances in graph_def.
  for node_def in meta_graph_def.graph_def.node:
    _strip_node_default_valued_attrs(node_def)

  # Process all NodeDef instances in graph_def.library.function.
  for function_def in meta_graph_def.graph_def.library.function:
    for function_node_def in function_def.node_def:
      _strip_node_default_valued_attrs(function_node_def)

  # Tell consumers of this graph that default valued attrs have been stripped.
  meta_graph_def.meta_info_def.stripped_default_attrs = True


def create_meta_graph_def(meta_info_def=None,
                          graph_def=None,
                          saver_def=None,
                          collection_list=None,
                          graph=None,
                          export_scope=None,
                          exclude_nodes=None,
                          clear_extraneous_savers=False,
                          strip_default_attrs=False):
  # pylint: disable=line-too-long
  """Constructs and returns a `MetaGraphDef` protocol buffer.

  Args:
    meta_info_def: `MetaInfoDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect.
    graph: The `Graph` to create `MetaGraphDef` out of.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from all
      collections, or None.
    clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS
      collection. Note this method does not alter the graph, so any
      extraneous Save/Restore ops should have been removed already, as needed.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
      removed from the NodeDefs. For a detailed guide, see
      [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

  Returns:
    MetaGraphDef protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct proto buffer type.
  """
  # pylint: enable=line-too-long
  # Type check.
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError(
        f"graph must be of type Graph. Received type: {type(graph)}.")
  if meta_info_def and not isinstance(meta_info_def,
                                      meta_graph_pb2.MetaGraphDef.MetaInfoDef):
    raise TypeError(
        "meta_info_def must be of type MetaInfoDef. "
        f"Received type: {type(meta_info_def)}.")
  if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError(
        "graph_def must be of type GraphDef. "
        f"Received type: {type(graph_def)}.")
  if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
    raise TypeError(
        "saver_def must be of type SaverDef. "
        f"Received type: {type(saver_def)}.")

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Creates a MetaGraphDef proto.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  # Adds meta_info_def.
  if not meta_info_def:
    meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()

  # Set the tf version strings to the current tf build.
  meta_info_def.tensorflow_version = versions.__version__
  meta_info_def.tensorflow_git_version = versions.__git_version__
  meta_graph_def.meta_info_def.MergeFrom(meta_info_def)

  # Adds graph_def or the default.
  if not graph_def:
    meta_graph_def.graph_def.MergeFrom(graph.as_graph_def(add_shapes=True))
  else:
    meta_graph_def.graph_def.MergeFrom(graph_def)

  # Fills in meta_info_def.stripped_op_list using the ops from graph_def.
  # pylint: disable=g-explicit-length-test
  if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0:
    meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
        stripped_op_list_for_graph(meta_graph_def.graph_def))
  # pylint: enable=g-explicit-length-test

  # Strip default valued attributes in graph_def.
  if strip_default_attrs:
    strip_graph_default_valued_attrs(meta_graph_def)

  # Adds saver_def.
  if saver_def:
    meta_graph_def.saver_def.MergeFrom(saver_def)

  # Adds collection_list.
  if collection_list is not None:
    clist = collection_list
  else:
    clist = graph.get_all_collection_keys()

  for ctype in clist:
    if clear_extraneous_savers and ctype == ops.GraphKeys.SAVERS:
      # Avoid importing Saver here
      from_proto = ops.get_from_proto_function(ctype)
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes,
                         override_contents=[from_proto(saver_def)])
    else:
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes)
  return meta_graph_def


def read_meta_graph_file(filename):
  """Reads a file containing `MetaGraphDef` and returns the protocol buffer.

  Args:
    filename: `meta_graph_def` filename including the path.

  Returns:
    A `MetaGraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  if not file_io.file_exists(filename):
    raise IOError(f"File does not exist. Received: {filename}.")
  # First try to read it as a binary file.
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()
  try:
    meta_graph_def.ParseFromString(file_content)
    return meta_graph_def
  except Exception:  # pylint: disable=broad-except
    pass

  # Next try to read it as a text file.
  try:
    text_format.Merge(file_content.decode("utf-8"), meta_graph_def)
  except text_format.ParseError as e:
    raise IOError(f"Cannot parse file {filename}: {str(e)}.")

  return meta_graph_def


def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If the
  argument is a file containing a `MetaGraphDef` protocol buffer, it
  constructs a protocol buffer from the file content.
  The function then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`
    and `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A
      collection named c (i.e. whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
  return import_scoped_meta_graph_with_return_elements(
      meta_graph_or_file, clear_devices, graph, import_scope, input_map,
      unbound_inputs_col_name, restore_collections_predicate)[0]


def import_scoped_meta_graph_with_return_elements(
    meta_graph_or_file,
    clear_devices=False,
    graph=None,
    import_scope=None,
    input_map=None,
    unbound_inputs_col_name="unbound_inputs",
    restore_collections_predicate=(lambda key: True),
    return_elements=None):
  """Imports graph from `MetaGraphDef` and returns vars and return elements.

  This function takes a `MetaGraphDef` protocol buffer as input. If the
  argument is a file containing a `MetaGraphDef` protocol buffer, it
  constructs a protocol buffer from the file content. The function then adds
  all the nodes from the `graph_def` field to the current graph, recreates the
  desired collections, and returns a dictionary of all the Variables imported
  into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`
    and `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A
      collection named c (i.e. whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.
    return_elements: A list of strings containing operation names in the
      `MetaGraphDef` that will be returned as `Operation` objects; and/or
      tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.

  Returns:
    A tuple of (
      dictionary of all the `Variables` imported into the name scope,
      list of `Operation` or `Tensor` objects from the `return_elements` list).

  Raises:
    ValueError: If the graph_def contains unbound inputs.

  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        if field.value and (
            not input_map or
            sorted([compat.as_str(v) for v in field.value]) !=
            sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." % ",".join(
                               compat.as_str(v)
                               for v in field.value
                               if not input_map or v not in input_map))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications from the nodes. This helps
    # to make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    imported_return_elements = importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list,
        return_elements=return_elements)

    # TensorFlow versions before 1.9 (not inclusive) exported SavedModels
    # without a VariableDef.trainable field set.
    tf_version = meta_graph_def.meta_info_def.tensorflow_version
    if not tf_version:
      variables_have_trainable = True
    else:
      variables_have_trainable = (
          packaging_version.parse(tf_version) >= packaging_version.parse("1.9"))

    # Sort collections so we see TRAINABLE_VARIABLES first and can default
    # these variables to trainable if the value is not set in their
    # VariableDef.
    sorted_collections = []
    if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
      sorted_collections.append(
          (ops.GraphKeys.TRAINABLE_VARIABLES,
           meta_graph_def.collection_def[ops.GraphKeys.TRAINABLE_VARIABLES]))
    for key, value in sorted(meta_graph_def.collection_def.items()):
      if key != ops.GraphKeys.TRAINABLE_VARIABLES:
        sorted_collections.append((key, value))

    # Restores all the other collections.
    variable_objects = {}
    for key, col_def in sorted_collections:
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)

      # Temporary change to allow the TFMA evaluator to read metric variables
      # saved as a bytes list.
      # TODO(kathywu): Remove this hack once cl/248406059 has been submitted.
      if key == ops.GraphKeys.METRIC_VARIABLES:
        # Metric variables will use the same proto functions as GLOBAL_VARIABLES
        from_proto = ops.get_from_proto_function(ops.GraphKeys.GLOBAL_VARIABLES)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              if not variables_have_trainable:
                # If the VariableDef proto does not contain a "trainable"
                # property because it was exported before that property was
                # added, we default it to whether the variable is in the
                # TRAINABLE_VARIABLES collection. We've sorted
                # TRAINABLE_VARIABLES to be first, so trainable variables will
                # be created from that collection.
                proto.trainable = (key == ops.GraphKeys.TRAINABLE_VARIABLES)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'bytes_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list, imported_return_elements


def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
  """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into a
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, a string collection
      with the given name will be added to the returned `MetaGraphDef`,
      containing the names of tensors that must be remapped when importing the
      `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    saver_def: `SaverDef` protocol buffer.
    clear_extraneous_savers: Remove any Saver-related information from the
      graph (both Save/Restore ops and SaverDefs) that are not associated
      with the provided SaverDef.
    strip_default_attrs: Set to true if default valued attributes must be
      removed while exporting the GraphDef.
    save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
      written to the same directory as `filename` with the file extension
      replaced by `.debug`.
    **kwargs: Optional keyed arguments, including meta_info_def and
      collection_list.

  Returns:
    A `MetaGraphDef` proto and dictionary of `Variables` in the exported
    name scope.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    ValueError: When executing in Eager mode and either `graph_def` or `graph`
      is undefined.
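
  Example (illustrative sketch; the scope name and file path are hypothetical):

    meta_graph_def, var_list = export_scoped_meta_graph(
        filename="/tmp/model.meta", export_scope="model", as_text=True)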
  """
  if context.executing_eagerly() and not (graph_def is not None and
                                          graph is not None):
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  graph = graph or ops.get_default_graph()

  exclude_nodes = None
  unbound_inputs = []
  if export_scope or clear_extraneous_savers or clear_devices:
    if graph_def:
      new_graph_def = graph_pb2.GraphDef()
      new_graph_def.versions.CopyFrom(graph_def.versions)
      new_graph_def.library.CopyFrom(graph_def.library)

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)

      for node_def in graph_def.node:
        if _should_include_node(node_def.name, export_scope, exclude_nodes):
          new_node_def = _node_def(node_def, export_scope, unbound_inputs,
                                   clear_devices=clear_devices)
          new_graph_def.node.extend([new_node_def])
      graph_def = new_graph_def
    else:
      # Only do this complicated work if we want to remove a name scope.
      graph_def = graph_pb2.GraphDef()
      # pylint: disable=protected-access
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      bytesize = 0

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
                                                     saver_def)

      for key in sorted(graph._nodes_by_id):
        if _should_include_node(graph._nodes_by_id[key].name,
                                export_scope,
                                exclude_nodes):
          value = graph._nodes_by_id[key]
          # pylint: enable=protected-access
          node_def = _node_def(value.node_def, export_scope, unbound_inputs,
                               clear_devices=clear_devices)
          graph_def.node.extend([node_def])
          if value.outputs:
            assert "_output_shapes" not in graph_def.node[-1].attr
            graph_def.node[-1].attr["_output_shapes"].list.shape.extend([
                output.get_shape().as_proto() for output in value.outputs])
          bytesize += value.node_def.ByteSize()
          if bytesize >= (1 << 31) or bytesize < 0:
            raise ValueError(
                "GraphDef cannot be larger than 2GB. "
                f"Received size: {bytesize}.")

      graph._copy_functions_to_graph_def(graph_def, bytesize)  # pylint: disable=protected-access

    # It's possible that not all the inputs are in the export_scope.
    # If we would like such information included in the exported meta_graph,
    # add them to a special unbound_inputs collection.
    if unbound_inputs_col_name:
      # Clears the unbound_inputs collections.
      graph.clear_collection(unbound_inputs_col_name)
      for k in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, k)

  var_list = {}
  variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                   scope=export_scope)
  for v in variables:
    if _should_include_node(v, export_scope, exclude_nodes):
      var_list[ops.strip_name_scope(v.name, export_scope)] = v

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      exclude_nodes=exclude_nodes,
      clear_extraneous_savers=clear_extraneous_savers,
      saver_def=saver_def,
      strip_default_attrs=strip_default_attrs,
      **kwargs)

  if filename:
    graph_io.write_graph(
        scoped_meta_graph_def,
        os.path.dirname(filename),
        os.path.basename(filename),
        as_text=as_text)
    if save_debug_info:
      name, _ = os.path.splitext(filename)
      debug_filename = "{name}{ext}".format(name=name, ext=".debug")

      # Gets the operation from the graph by the name.
      # Excludes variable nodes, so only the nodes in the frozen models are
      # included.
      # TODO(liufengdb): fix this for functions.
      ops_to_export = []
      for node in scoped_meta_graph_def.graph_def.node:
        scoped_op_name = ops.prepend_name_scope(node.name, export_scope)
        ops_to_export.append(("", graph.get_operation_by_name(scoped_op_name)))

      graph_debug_info = error_interpolation.create_graph_debug_info_def(
          ops_to_export)

      graph_io.write_graph(
          graph_debug_info,
          os.path.dirname(debug_filename),
          os.path.basename(debug_filename),
          as_text=as_text)

  return scoped_meta_graph_def, var_list


def copy_scoped_meta_graph(from_scope, to_scope,
                           from_graph=None, to_graph=None):
  """Copies a sub-meta_graph from one scope to another.

  Args:
    from_scope: `String` name scope containing the subgraph to be copied.
    to_scope: `String` name scope under which the copied subgraph will reside.
    from_graph: Optional `Graph` from which to copy the subgraph. If `None`,
      the default graph is used.
    to_graph: Optional `Graph` to which to copy the subgraph. If `None`, the
      default graph is used.

  Returns:
    A dictionary of `Variables` that have been copied into `to_scope`.

  Raises:
    ValueError: If `from_scope` and `to_scope` are the same while
      `from_graph` and `to_graph` are also the same.
  """
  from_graph = from_graph or ops.get_default_graph()
  to_graph = to_graph or ops.get_default_graph()

  if from_graph == to_graph and from_scope == to_scope:
    raise ValueError("'from_scope' and 'to_scope' need to be different "
                     "when performing copy in the same graph. "
                     f"Received: 'from_graph': {from_graph}, "
                     f"'to_graph': {to_graph}, "
                     f"'from_scope': {from_scope}, 'to_scope': {to_scope}.")

  orig_meta_graph, var_list = export_scoped_meta_graph(
      export_scope=from_scope, graph=from_graph)
  var_list = import_scoped_meta_graph(orig_meta_graph,
                                      graph=to_graph,
                                      import_scope=to_scope)
  return var_list
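

# Usage sketch (illustrative only, not executed; the scope names and file path
# below are hypothetical):
#
#   # Export the subgraph under "model/" to disk, stripping the scope prefix.
#   meta_graph_def, var_list = export_scoped_meta_graph(
#       filename="/tmp/model.meta", export_scope="model")
#
#   # Later, re-import it into a fresh graph under a new name scope.
#   with ops.Graph().as_default():
#     imported_vars = import_scoped_meta_graph(
#         "/tmp/model.meta", import_scope="new_model")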