xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/meta_graph.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""MetaGraph and related functions."""
17import copy
18from packaging import version as packaging_version  # pylint: disable=g-bad-import-order
19import os.path
20import re
21
22from google.protobuf.any_pb2 import Any
23from google.protobuf import text_format
24
25from tensorflow.core.framework import attr_value_pb2
26from tensorflow.core.framework import graph_pb2
27from tensorflow.core.framework import op_def_pb2
28from tensorflow.core.protobuf import meta_graph_pb2
29from tensorflow.core.protobuf import saver_pb2
30from tensorflow.python.client import pywrap_tf_session as c_api
31from tensorflow.python.eager import context
32from tensorflow.python.framework import error_interpolation
33from tensorflow.python.framework import graph_io
34from tensorflow.python.framework import importer
35from tensorflow.python.framework import op_def_registry
36from tensorflow.python.framework import ops
37from tensorflow.python.framework import versions
38from tensorflow.python.lib.io import file_io
39from tensorflow.python.platform import tf_logging as logging
40from tensorflow.python.util import compat
41
42
# Marker inserted into the names of inputs that point outside the export
# scope, so that such unbound inputs are easy to recognize after export.
_UNBOUND_INPUT_PREFIX = "$unbound_inputs_"

# Collections that never registered proto conversion functions. Because of
# that, items stored under these keys in previously exported meta graphs have
# a different data type.
_COMPAT_COLLECTION_LIST = [
    ops.GraphKeys.LOCAL_VARIABLES,
    ops.GraphKeys.MODEL_VARIABLES,
    ops.GraphKeys.METRIC_VARIABLES,
]
51
52
def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False):
  """Create a `NodeDef` proto with export_scope stripped.

  Inputs whose name (ignoring leading "^") does not start with `export_scope`
  are treated as unbound. `_UNBOUND_INPUT_PREFIX` is inserted into them
  (after a leading "^", if any), and the rewritten names are appended to
  `unbound_inputs`.

  The scope is stripped from:
    * all other inputs,
    * the node name,
    * the `_class` attribute entries that lie inside `export_scope`,
    * the `frame_name` of `Enter`/`RefEnter` nodes when it lies inside
      `export_scope`.

  `_class` entries outside the scope are dropped. A `frame_name` outside the
  scope is kept unchanged.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer. It is deep-copied
      and not modified.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: A list, extended in place with the rewritten names of the
      unbound inputs.
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)
  for i, v in enumerate(node_def.input):
    if export_scope and not v.lstrip("^").startswith(export_scope):
      # Adds "$unbound_inputs_" prefix to the unbound name so they are easily
      # identifiable; a leading "^" (control input) stays in front of it.
      node_def.input[i] = re.sub(r"([\^]|^)(.*)",
                                 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
                                 compat.as_str(v))
      unbound_inputs.append(node_def.input[i])
    else:
      node_def.input[i] = ops.strip_name_scope(v, export_scope)
  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))
  for k, v in from_node_def.attr.items():
    if k == "_class":
      # Keep only entries whose text after the first "@" lies inside
      # export_scope. Entries without "@" count as outside; partition()
      # yields "" for them, whereas indexing split("@")[1] would raise.
      new_s = [compat.as_bytes(ops.strip_name_scope(s, export_scope))
               for s in v.list.s
               if not export_scope or
               compat.as_str(s).partition("@")[2].startswith(export_scope)]
      node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=new_s)))
    elif node_def.op in ("Enter", "RefEnter") and k == "frame_name":
      if not export_scope or compat.as_str(v.s).startswith(export_scope):
        new_frame = compat.as_bytes(ops.strip_name_scope(v.s, export_scope))
        node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=new_frame))
      else:
        # The frame lies outside export_scope: keep its original name.
        node_def.attr[k].CopyFrom(v)
    else:
      node_def.attr[k].CopyFrom(v)

  if clear_devices:
    node_def.device = ""

  return node_def
99
100
def _read_file(filename):
  """Reads a file containing `GraphDef` and returns the protocol buffer.

  The content is first parsed as a binary `GraphDef`. If that fails, it is
  parsed as a text-format `GraphDef`.

  Args:
    filename: `graph_def` filename including the path.

  Returns:
    A `GraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  if not file_io.file_exists(filename):
    raise IOError(f"File {filename} does not exist.")
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()

  # First try to read it as a binary file.
  graph_def = graph_pb2.GraphDef()
  try:
    graph_def.ParseFromString(file_content)
    return graph_def
  except Exception:  # pylint: disable=broad-except
    pass

  # Next try to read it as a text file. Use a fresh message, since a failed
  # binary decode may have left the previous one partially populated.
  graph_def = graph_pb2.GraphDef()
  try:
    text_format.Merge(file_content, graph_def)
  except (text_format.ParseError, UnicodeDecodeError) as e:
    # Content that is neither a binary proto nor valid UTF-8 text is also
    # reported as IOError, matching the documented contract.
    raise IOError(f"Cannot parse file {filename}: {str(e)}.") from e

  return graph_def
132
133
def ops_used_by_graph_def(graph_def):
  """Collect the list of ops used by a graph.

  Does not validate that the ops are all registered.

  Nodes in `graph_def.node` are scanned. The bodies of library functions are
  scanned transitively when a function is used either directly as a node's op
  or as the `f` attribute of a `PartitionedCall`/`StatefulPartitionedCall`
  node. Names of library functions are not included in the result.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    A sorted list of strings, each naming an op used by the graph.
  """
  # Map function names to definitions (a later duplicate name wins).
  name_to_function = {
      fun.signature.name: fun for fun in graph_def.library.function}

  # Collect the op names. Since functions can reference functions, we need a
  # traversal over a worklist of function bodies.
  used_ops = set()  # Includes both primitive ops and functions.
  functions_to_process = []  # Function bodies still to be scanned.

  def mark_op_as_used(op):
    if op not in used_ops and op in name_to_function:
      functions_to_process.append(name_to_function[op])
    used_ops.add(op)

  def process_node(node):
    mark_op_as_used(node.op)
    # Test membership first: indexing a protobuf message map with a missing
    # key inserts a default entry, which would mutate the caller's GraphDef
    # and record "" as an op.
    if (node.op in ("PartitionedCall", "StatefulPartitionedCall") and
        "f" in node.attr):
      mark_op_as_used(node.attr["f"].func.name)

  for node in graph_def.node:
    process_node(node)
  while functions_to_process:
    fun = functions_to_process.pop()
    for node in fun.node_def:
      process_node(node)

  # Sorted so the result does not depend on set iteration order.
  return sorted(op for op in used_ops if op not in name_to_function)
173
174
def stripped_op_list_for_graph(graph_def):
  """Collect the stripped OpDefs for ops used by a graph.

  The result is what goes into the `stripped_op_list` field of `MetaGraphDef`
  and similar protos. A producer can hand it to a consumer, which can then
  use the C++ function `RemoveNewDefaultAttrsFromGraphDef` to improve forwards
  compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` holding, ordered by op name, the registered `OpDef` of every
    op used by the graph. Ops without a registered `OpDef` are left out.
  """
  # Counterpart of StrippedOpListForGraph in C++. Unlike the C++ version, this
  # one does not require every op to be registered; unregistered ops are
  # simply skipped. This is done to support Prelu fusion in tfjs.
  lookups = (op_def_registry.get(op_name)
             for op_name in sorted(ops_used_by_graph_def(graph_def)))
  return op_def_pb2.OpList(
      op=[op_def for op_def in lookups if op_def is not None])
199
200
def _get_kind_name(item):
  """Picks the `CollectionDef` kind field able to hold `item`.

  The checks run in this order:
    * `str`/`bytes` -> "bytes_list"
    * `int` (which includes `bool`) -> "int64_list"
    * `float` -> "float_list"
    * protobuf `Any` -> "any_list"
    * anything else -> "node_list"

  Args:
    item: A data item.

  Returns:
    The name of the `CollectionDef` kind, as a string.
  """
  if isinstance(item, (str, bytes)):
    return "bytes_list"
  if isinstance(item, int):
    return "int64_list"
  if isinstance(item, float):
    return "float_list"
  return "any_list" if isinstance(item, Any) else "node_list"
221
222
# Op types treated as Save/Restore ops when locating the name scopes that
# belong to Savers.
SAVE_AND_RESTORE_OPS = [
    # Save ops.
    "SaveV2",
    "Save",
    "SaveSlice",
    "LegacySave",
    "LegacySaveSlice",
    # Restore ops.
    "RestoreV2",
    "Restore",
    "RestoreSlice",
    "LegacyRestore",
    "LegacyRestoreSlice",
]
229
230
def _op_name(tensor_name):
  """Extract the Op name from a Tensor name.

  The Op name is everything before the first colon, if present, not including
  any ^ prefix denoting a control dependency. For example, "^foo", "foo" and
  "foo:0" all yield "foo".

  Args:
    tensor_name: the full name of a Tensor in the graph.
  Returns:
    The name of the Op of which the given Tensor is an output.
  Raises:
    ValueError: if tensor_name is None or empty.
  """
  if not tensor_name:
    raise ValueError(
        f"Tensor name cannot be empty or None. Received: {tensor_name}.")

  # Control dependency inputs start with ^.
  if tensor_name.startswith("^"):
    tensor_name = tensor_name[1:]
  # Only the first colon matters, so names containing several colons
  # (e.g. "node:output:0") are handled as well.
  return tensor_name.partition(":")[0]
255
256
def _get_scope(node_name):
  """Return the name scope that directly contains a node.

  The scope is the part of the name before its last "/", ignoring a leading
  "^" control-dependency marker. A name without any "/" lives in the root
  scope and yields "".

  Args:
    node_name: the full name of an Op or a Tensor in the graph.
  Returns:
    The deepest named scope containing the node.
  Raises:
    ValueError: if node_name is None or empty.
  """
  if not node_name:
    raise ValueError(
        f"Node name cannot be empty or None. Received: {node_name}.")

  # Drop a single leading "^" (control dependency marker).
  bare_name = node_name[1:] if node_name.startswith("^") else node_name
  # rpartition returns ("", "", bare_name) when there is no "/" at all.
  return bare_name.rpartition("/")[0]
282
283
def _find_extraneous_saver_nodes(graph_def, saver_def):
  """Identifies any nodes in the graph_def related to unused Savers.

  This approach assumes that each Saver is cleanly isolated in its own name
  scope, so we need only identify the scopes associated with extraneous Savers
  and return all the nodes in those scopes.

  A saver scope is the scope of any node whose op is in
  `SAVE_AND_RESTORE_OPS`. If `saver_def` is given, the scopes of its save
  tensor and restore op are retained. Every node whose name starts with one
  of the remaining scopes (followed by "/") is reported.

  Args:
    graph_def: a GraphDef proto to evaluate.
    saver_def: a SaverDef proto referencing Save/Restore ops to be retained,
      or None.
  Returns:
    A set of node names that may be safely omitted.
  """
  # TODO(soergel): confirm that the assumption of scope isolation is valid.
  # If not, we need to walk up the graph from any restore_all nodes, and walk
  # down the graph from any Save/Restore nodes.  I drafted that approach too,
  # but it seems unnecessarily complex given the name scope solution.

  # The scope-based approach only needs node names and op types.
  op_by_node_name = {node_def.name: node_def.op for node_def in graph_def.node}

  retain_scopes = set()
  # It's possible to have no saver if the graph has no Variables
  if saver_def is not None:
    save_op_name = _op_name(saver_def.save_tensor_name)
    restore_op_name = _op_name(saver_def.restore_op_name)

    # The save and restore scopes should always be the same, but if they differ
    # for some reason, we retain them both to be safe.
    retain_scopes = {_get_scope(restore_op_name) + "/",
                     _get_scope(save_op_name) + "/"}

  all_saver_node_names = {
      name for name, op in op_by_node_name.items()
      if op in SAVE_AND_RESTORE_OPS}

  all_saver_scopes = (
      {_get_scope(name) for name in all_saver_node_names} -
      all_saver_node_names)
  extraneous_scopes = tuple(
      {scope + "/" for scope in all_saver_scopes} - retain_scopes)

  # str.startswith accepts a tuple of prefixes; an empty tuple matches nothing.
  return {name for name in op_by_node_name
          if name.startswith(extraneous_scopes)}
338
339
def _should_include_node(node_or_node_name, export_scope, exclude_nodes):
  """Decides whether a collection item should be exported.

  Items that are neither strings nor have a `.name` attribute are always kept.
  An item is dropped when it, or its name, is in `exclude_nodes`. Otherwise it
  is kept if any of these holds:
    * its name carries the unbound-input prefix,
    * no `export_scope` is given,
    * its name starts with `export_scope`.

  Args:
    node_or_node_name: A node or `string` node name.
    export_scope: `string`. Name scope under which to extract the subgraph. The
      scope name will be stripped from the node definitions for easy import
      later into new name scopes.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      export, or None.  Note no sanity-checking is done, so this list must be
      carefully constructed to avoid producing an invalid graph.

  Returns:
    `True` if the node should be included.
  """
  if isinstance(node_or_node_name, str):
    node_name = node_or_node_name
  else:
    try:
      node_name = node_or_node_name.name
    except AttributeError:
      # Objects we cannot name cannot be filtered, so they are kept.
      return True

  if exclude_nodes:
    if node_or_node_name in exclude_nodes or node_name in exclude_nodes:
      return False

  if node_name.startswith(_UNBOUND_INPUT_PREFIX):
    return True
  return not export_scope or node_name.startswith(export_scope)
370
371
def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None, exclude_nodes=None,
                       override_contents=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Items rejected by `export_scope`/`exclude_nodes` are dropped first, and
  nothing is added when no items remain. Serialization problems are logged as
  a warning and any partially written collection is removed; they are not
  raised.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string. Keys that are not
      `str`/`bytes` are skipped with a warning.
    graph: The `Graph` from which to get collections. Defaults to the default
      graph.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      collection, or None.
    override_contents: An iterable of values to place in the collection,
      ignoring the current values (if set). A falsy value (None or empty)
      means the graph's collection is used instead.

  Raises:
    TypeError: If `graph` is given but is not a `Graph`.
  """
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError(
        f"graph must be of type Graph. Received type: {type(graph)}.")

  if not isinstance(key, (str, bytes)):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  if override_contents:
    collection_list = override_contents
  else:
    collection_list = graph.get_collection(key)

  # Remove nodes that should not be exported from the collection list.
  collection_list = [x for x in collection_list if
                     _should_include_node(x, export_scope, exclude_nodes)]
  if not collection_list:
    return

  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        proto = to_proto(x, export_scope=export_scope)
        if proto:
          # Make sure the returned proto is indeed what we expect. This is an
          # explicit check rather than `assert` so it still runs under
          # `python -O` and gives a descriptive message in the warning below.
          if not isinstance(proto, proto_type):
            raise TypeError(
                f"to_proto function for collection {key!r} returned "
                f"{type(proto)}, expected {proto_type}.")
          getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        for x in collection_list:
          if not export_scope or x.name.startswith(export_scope):
            getattr(col_def, kind).value.append(
                ops.strip_name_scope(x.name, export_scope))
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend(collection_list)
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Issue encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef. Note this is a warning "
                    "and probably safe to ignore.\n%s", key, str(e))
    # Do not leave a partially filled collection behind.
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
445
446
def _is_default_attr_value(op_def, attr_name, attr_value):
  """Checks whether `attr_value` is the default for `attr_name` in `op_def`.

  Only the first attr of `op_def` named `attr_name` is consulted. Returns
  False when `op_def` has no such attr or that attr has no default value.
  """
  attr_def = next((a for a in op_def.attr if a.name == attr_name), None)
  if attr_def is None or not attr_def.HasField("default_value"):
    return False
  # c_api.EqualAttrValueWrapper returns an empty string when both serialized
  # AttrValues are equivalent, hence the negation.
  difference = c_api.EqualAttrValueWrapper(
      attr_value.SerializeToString(),
      attr_def.default_value.SerializeToString())
  return not difference
459
460
def strip_graph_default_valued_attrs(meta_graph_def):
  """Removes attributes equal to their op's default from every NodeDef.

  Both the top-level nodes of `meta_graph_def.graph_def` and the nodes inside
  each function of its library are processed. A node is left untouched when
  its op names a library function or has no registered `OpDef`. Afterwards
  `meta_info_def.stripped_default_attrs` is set to True.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer, modified in place.

  Returns:
    None.
  """
  graph_def = meta_graph_def.graph_def
  library_functions = graph_def.library.function
  # Ops that are really calls to library functions have no registry OpDef to
  # compare against.
  function_names = {fdef.signature.name for fdef in library_functions}

  def _strip(node_def):
    if node_def.op in function_names:
      return
    op_def = op_def_registry.get(node_def.op)
    if op_def is None:
      return
    # Collect first, then delete, so the attr map is not mutated mid-iteration.
    defaulted = [name for name, value in node_def.attr.items()
                 if _is_default_attr_value(op_def, name, value)]
    for name in defaulted:
      del node_def.attr[name]

  for node_def in graph_def.node:
    _strip(node_def)
  for fdef in library_functions:
    for node_def in fdef.node_def:
      _strip(node_def)

  # Tell consumers of this graph that default valued attrs have been stripped.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
506
507
def create_meta_graph_def(meta_info_def=None,
                          graph_def=None,
                          saver_def=None,
                          collection_list=None,
                          graph=None,
                          export_scope=None,
                          exclude_nodes=None,
                          clear_extraneous_savers=False,
                          strip_default_attrs=False):
  # pylint: disable=line-too-long
  """Constructs and returns a `MetaGraphDef` protocol buffer.

  The given `meta_info_def`, `graph_def` and `saver_def` protos are merged into
  a new `MetaGraphDef` and are not modified. The `tensorflow_version` and
  `tensorflow_git_version` fields of the result are always set to the current
  build.

  Args:
    meta_info_def: `MetaInfoDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer. If not given, `graph` is serialized
      (with shapes added).
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect. If None, all collection
      keys of `graph` are used.
    graph: The `Graph` to create `MetaGraphDef` out of. Defaults to the default
      graph.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from all
      collections, or None.
    clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS
        collection, keeping only `saver_def`; when `saver_def` is None the
        SAVERS collection is omitted. Note this method does not alter the
        graph, so any extraneous Save/Restore ops should have been removed
        already, as needed.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

  Returns:
    MetaGraphDef protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct proto buffer type.
  """
  # pylint: enable=line-too-long
  # Type check.
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError(
        f"graph must be of type Graph. Received type: {type(graph)}.")
  if meta_info_def and not isinstance(meta_info_def,
                                      meta_graph_pb2.MetaGraphDef.MetaInfoDef):
    raise TypeError(
        "meta_info_def must be of type MetaInfoDef. "
        f"Received type: {type(meta_info_def)}.")
  if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError(
        "graph_def must be of type GraphDef. "
        f"Received type: {type(graph_def)}.")
  if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
    raise TypeError(
        f"saver_def must be of type SaverDef. "
        f"Received type: {type(saver_def)}.")

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Creates a MetaGraphDef proto.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  # Adds meta_info_def. The version strings are written into the new proto
  # after merging, so the caller's `meta_info_def` is left untouched.
  if meta_info_def:
    meta_graph_def.meta_info_def.MergeFrom(meta_info_def)

  # Set the tf version strings to the current tf build.
  meta_graph_def.meta_info_def.tensorflow_version = versions.__version__
  meta_graph_def.meta_info_def.tensorflow_git_version = (
      versions.__git_version__)

  # Adds graph_def or the default.
  if not graph_def:
    meta_graph_def.graph_def.MergeFrom(graph.as_graph_def(add_shapes=True))
  else:
    meta_graph_def.graph_def.MergeFrom(graph_def)

  # Fills in meta_info_def.stripped_op_list using the ops from graph_def.
  # pylint: disable=g-explicit-length-test
  if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0:
    meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
        stripped_op_list_for_graph(meta_graph_def.graph_def))
  # pylint: enable=g-explicit-length-test

  # Strip default valued attributes in graph_def.
  if strip_default_attrs:
    strip_graph_default_valued_attrs(meta_graph_def)

  # Adds saver_def.
  if saver_def:
    meta_graph_def.saver_def.MergeFrom(saver_def)

  # Adds collection_list.
  if collection_list is not None:
    clist = collection_list
  else:
    clist = graph.get_all_collection_keys()

  for ctype in clist:
    if clear_extraneous_savers and ctype == ops.GraphKeys.SAVERS:
      if saver_def is None:
        # No saver is retained, so every saver in the collection is
        # extraneous: omit the collection instead of calling from_proto(None).
        continue
      # Avoid importing Saver here
      from_proto = ops.get_from_proto_function(ctype)
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes,
                         override_contents=[from_proto(saver_def)])
    else:
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes)
  return meta_graph_def
617
618
def read_meta_graph_file(filename):
  """Reads a file containing `MetaGraphDef` and returns the protocol buffer.

  The content is first parsed as a binary `MetaGraphDef`. If that fails, it is
  decoded as UTF-8 and parsed as a text-format `MetaGraphDef`.

  Args:
    filename: `meta_graph_def` filename including the path.

  Returns:
    A `MetaGraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  if not file_io.file_exists(filename):
    raise IOError(f"File does not exist. Received: {filename}.")
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()

  # First try to read it as a binary file.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  try:
    meta_graph_def.ParseFromString(file_content)
    return meta_graph_def
  except Exception:  # pylint: disable=broad-except
    pass

  # Next try to read it as a text file. Use a fresh message, since a failed
  # binary decode may have left the previous one partially populated.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  try:
    text_format.Merge(file_content.decode("utf-8"), meta_graph_def)
  except (text_format.ParseError, UnicodeDecodeError) as e:
    # Content that is neither a binary proto nor valid UTF-8 text is also
    # reported as IOError, matching the documented contract.
    raise IOError(f"Cannot parse file {filename}: {str(e)}.") from e

  return meta_graph_def
650
651
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  `meta_graph_or_file` may be a `MetaGraphDef` protocol buffer, or the name of
  a file containing one, in which case the proto is read from that file. All
  nodes of its `graph_def` are added to the target graph and the selected
  collections are recreated. A dictionary of all the Variables imported into
  the name scope is returned.

  Together with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  This is a thin wrapper around `import_scoped_meta_graph_with_return_elements`
  that discards the imported return elements.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e., whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs, or if eager
      execution is enabled.
  """
  imported = import_scoped_meta_graph_with_return_elements(
      meta_graph_or_file,
      clear_devices=clear_devices,
      graph=graph,
      import_scope=import_scope,
      input_map=input_map,
      unbound_inputs_col_name=unbound_inputs_col_name,
      restore_collections_predicate=restore_collections_predicate)
  # Element 0 is the variables dictionary; the return elements (element 1)
  # are not requested by this wrapper.
  return imported[0]
703
704
705def import_scoped_meta_graph_with_return_elements(
706    meta_graph_or_file,
707    clear_devices=False,
708    graph=None,
709    import_scope=None,
710    input_map=None,
711    unbound_inputs_col_name="unbound_inputs",
712    restore_collections_predicate=(lambda key: True),
713    return_elements=None):
714  """Imports graph from `MetaGraphDef` and returns vars and return elements.
715
716  This function takes a `MetaGraphDef` protocol buffer as input. If
717  the argument is a file containing a `MetaGraphDef` protocol buffer ,
718  it constructs a protocol buffer from the file content. The function
719  then adds all the nodes from the `graph_def` field to the
720  current graph, recreates the desired collections, and returns a dictionary of
721  all the Variables imported into the name scope.
722
723  In combination with `export_scoped_meta_graph()`, this function can be used to
724
725  * Serialize a graph along with other Python objects such as `QueueRunner`,
726    `Variable` into a `MetaGraphDef`.
727
728  * Restart training from a saved graph and checkpoints.
729
730  * Run inference from a saved graph and checkpoints.
731
732  Args:
733    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
734      the path) containing a `MetaGraphDef`.
735    clear_devices: Boolean which controls whether to clear device information
736      from graph_def. Default false.
737    graph: The `Graph` to import into. If `None`, use the default graph.
738    import_scope: Optional `string`. Name scope into which to import the
739      subgraph. If `None`, the graph is imported to the root name scope.
740    input_map: A dictionary mapping input names (as strings) in `graph_def` to
741      `Tensor` objects. The values of the named input tensors in the imported
742      graph will be re-mapped to the respective `Tensor` values.
743    unbound_inputs_col_name: Collection name for looking up unbound inputs.
744    restore_collections_predicate: a predicate on collection names. A collection
745      named c (i.e whose key is c) will be restored iff
746      1) `restore_collections_predicate(c)` is True, and
747      2) `c != unbound_inputs_col_name`.
748    return_elements:  A list of strings containing operation names in the
749      `MetaGraphDef` that will be returned as `Operation` objects; and/or
750      tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.
751
752  Returns:
753    A tuple of (
754      dictionary of all the `Variables` imported into the name scope,
755      list of `Operation` or `Tensor` objects from the `return_elements` list).
756
757  Raises:
758    ValueError: If the graph_def contains unbound inputs.
759
760  """
761  if context.executing_eagerly():
762    raise ValueError("Exporting/importing meta graphs is not supported when "
763                     "eager execution is enabled.")
764  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
765    meta_graph_def = meta_graph_or_file
766  else:
767    meta_graph_def = read_meta_graph_file(meta_graph_or_file)
768
769  if unbound_inputs_col_name:
770    for key, col_def in meta_graph_def.collection_def.items():
771      if key == unbound_inputs_col_name:
772        kind = col_def.WhichOneof("kind")
773        field = getattr(col_def, kind)
774        if field.value and (
775            not input_map or
776            sorted([compat.as_str(v) for v in field.value]) !=
777            sorted(input_map)):
778          raise ValueError("Graph contains unbound inputs: %s. Must "
779                           "provide these inputs through input_map." % ",".join(
780                               compat.as_str(v)
781                               for v in field.value
782                               if not input_map or v not in input_map))
783        break
784
785  # Sets graph to default graph if it's not passed in.
786  graph = graph or ops.get_default_graph()
787
788  # Gathers the list of nodes we are interested in.
789  with graph.as_default():
790    producer_op_list = None
791    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
792      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
793    input_graph_def = meta_graph_def.graph_def
794    # Remove all the explicit device specifications for this node. This helps to
795    # make the graph more portable.
796    if clear_devices:
797      for node in input_graph_def.node:
798        node.device = ""
799
800    scope_to_prepend_to_names = graph.unique_name(
801        import_scope or "", mark_as_used=False)
802
803    imported_return_elements = importer.import_graph_def(
804        input_graph_def,
805        name=(import_scope or scope_to_prepend_to_names),
806        input_map=input_map,
807        producer_op_list=producer_op_list,
808        return_elements=return_elements)
809
810    # TensorFlow versions before 1.9 (not inclusive) exported SavedModels
811    # without a VariableDef.trainable field set.
812    tf_version = meta_graph_def.meta_info_def.tensorflow_version
813    if not tf_version:
814      variables_have_trainable = True
815    else:
816      variables_have_trainable = (
817          packaging_version.parse(tf_version) >= packaging_version.parse("1.9"))
818
819    # Sort collections so we see TRAINABLE_VARIABLES first and can default these
820    # variables to trainable if the value is not set in their VariableDef.
821    sorted_collections = []
822    if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
823      sorted_collections.append(
824          (ops.GraphKeys.TRAINABLE_VARIABLES,
825           meta_graph_def.collection_def[ops.GraphKeys.TRAINABLE_VARIABLES]))
826    for key, value in sorted(meta_graph_def.collection_def.items()):
827      if key != ops.GraphKeys.TRAINABLE_VARIABLES:
828        sorted_collections.append((key, value))
829
830    # Restores all the other collections.
831    variable_objects = {}
832    for key, col_def in sorted_collections:
833      # Don't add unbound_inputs to the new graph.
834      if key == unbound_inputs_col_name:
835        continue
836      if not restore_collections_predicate(key):
837        continue
838
839      kind = col_def.WhichOneof("kind")
840      if kind is None:
841        logging.error("Cannot identify data type for collection %s. Skipping.",
842                      key)
843        continue
844      from_proto = ops.get_from_proto_function(key)
845
846      # Temporary change to allow the TFMA evaluator to read metric variables
847      # saved as a bytes list.
848      # TODO(kathywu): Remove this hack once cl/248406059 has been submitted.
849      if key == ops.GraphKeys.METRIC_VARIABLES:
850        # Metric variables will use the same proto functions as GLOBAL_VARIABLES
851        from_proto = ops.get_from_proto_function(ops.GraphKeys.GLOBAL_VARIABLES)
852      if from_proto and kind == "bytes_list":
853        proto_type = ops.get_collection_proto_type(key)
854        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
855          for value in col_def.bytes_list.value:
856            variable = variable_objects.get(value, None)
857            if variable is None:
858              proto = proto_type()
859              proto.ParseFromString(value)
860              if not variables_have_trainable:
861                # If the VariableDef proto does not contain a "trainable"
862                # property because it was exported before that property was
863                # added, we default it to whether the variable is in the
864                # TRAINABLE_VARIABLES collection. We've sorted
865                # TRAINABLE_VARIABLES to be first, so trainable variables will
866                # be created from that collection.
867                proto.trainable = (key == ops.GraphKeys.TRAINABLE_VARIABLES)
868              variable = from_proto(
869                  proto, import_scope=scope_to_prepend_to_names)
870              variable_objects[value] = variable
871            graph.add_to_collection(key, variable)
872        else:
873          for value in col_def.bytes_list.value:
874            proto = proto_type()
875            proto.ParseFromString(value)
876            graph.add_to_collection(
877                key, from_proto(
878                    proto, import_scope=scope_to_prepend_to_names))
879      else:
880        field = getattr(col_def, kind)
881        if key in _COMPAT_COLLECTION_LIST:
882          logging.warning(
883              "The saved meta_graph is possibly from an older release:\n"
884              "'%s' collection should be of type 'byte_list', but instead "
885              "is of type '%s'.", key, kind)
886        if kind == "node_list":
887          for value in field.value:
888            col_op = graph.as_graph_element(
889                ops.prepend_name_scope(value, scope_to_prepend_to_names))
890            graph.add_to_collection(key, col_op)
891        elif kind == "int64_list":
892          # NOTE(opensource): This force conversion is to work around the fact
893          # that Python2 distinguishes between int and long, while Python3 has
894          # only int.
895          for value in field.value:
896            graph.add_to_collection(key, int(value))
897        else:
898          for value in field.value:
899            graph.add_to_collection(
900                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))
901
902    var_list = {}
903    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
904                                     scope=scope_to_prepend_to_names)
905    for v in variables:
906      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v
907
908  return var_list, imported_return_elements
909
910
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
  """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer. When given, and any of
      `export_scope`, `clear_extraneous_savers` or `clear_devices` is set, its
      nodes (rather than those of `graph`) are filtered and rewritten.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, a string collection
      with the given name will be added to the returned `MetaGraphDef`,
      containing the names of tensors that must be remapped when importing the
      `MetaGraphDef`. When any of `export_scope`, `clear_extraneous_savers` or
      `clear_devices` is set, the collection with this name on `graph` itself
      is cleared and repopulated.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    saver_def: `SaverDef` protocol buffer.
    clear_extraneous_savers: Remove any Saver-related information from the
        graph (both Save/Restore ops and SaverDefs) that are not associated
        with the provided SaverDef.
    strip_default_attrs: Set to true if default valued attributes must be
      removed while exporting the GraphDef.
    save_debug_info: If `True` (and `filename` is given), also save the
      GraphDebugInfo to a separate file in the same directory as `filename`.
      Its name is `filename` with the extension replaced by `.debug`
      (e.g. `model.meta` -> `model.debug`).
    **kwargs: Optional keyed arguments, including meta_info_def and
        collection_list.

  Returns:
    A tuple of (`MetaGraphDef` proto, dictionary mapping variable names with
    `export_scope` stripped to the `Variables` in the exported name scope).

  Raises:
    ValueError: When the `GraphDef` built from `graph` is larger than 2GB.
      This check is only performed when no `graph_def` is passed and one of
      `export_scope`, `clear_extraneous_savers` or `clear_devices` is set.
    ValueError: When executing in Eager mode and either `graph_def` or `graph`
      is `None`.
  """
  if context.executing_eagerly() and not (graph_def is not None and
                                          graph is not None):
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  graph = graph or ops.get_default_graph()

  exclude_nodes = None
  # Filled in by `_node_def` with names of inputs coming from outside
  # `export_scope`.
  unbound_inputs = []
  if export_scope or clear_extraneous_savers or clear_devices:
    if graph_def:
      # Rewrite the caller-supplied GraphDef: keep versions and function
      # library, and filter/rewrite its nodes.
      new_graph_def = graph_pb2.GraphDef()
      new_graph_def.versions.CopyFrom(graph_def.versions)
      new_graph_def.library.CopyFrom(graph_def.library)

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)

      for node_def in graph_def.node:
        if _should_include_node(node_def.name, export_scope, exclude_nodes):
          new_node_def = _node_def(node_def, export_scope, unbound_inputs,
                                   clear_devices=clear_devices)
          new_graph_def.node.extend([new_node_def])
      graph_def = new_graph_def
    else:
      # No GraphDef was supplied: build one node by node from `graph`,
      # filtering by scope and excluded savers, and recording each node's
      # output shapes.
      graph_def = graph_pb2.GraphDef()
      # pylint: disable=protected-access
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      bytesize = 0

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
                                                     saver_def)

      # Iterate in node-id order so the output is deterministic.
      for key in sorted(graph._nodes_by_id):
        if _should_include_node(graph._nodes_by_id[key].name,
                                export_scope,
                                exclude_nodes):
          value = graph._nodes_by_id[key]
          # pylint: enable=protected-access
          node_def = _node_def(value.node_def, export_scope, unbound_inputs,
                               clear_devices=clear_devices)
          graph_def.node.extend([node_def])
          if value.outputs:
            assert "_output_shapes" not in graph_def.node[-1].attr
            graph_def.node[-1].attr["_output_shapes"].list.shape.extend([
                output.get_shape().as_proto() for output in value.outputs])
          bytesize += value.node_def.ByteSize()
          # Serialized protos are limited to 2GB.
          if bytesize >= (1 << 31) or bytesize < 0:
            raise ValueError(
                "GraphDef cannot be larger than 2GB. "
                f"Received size: {bytesize}.")

      graph._copy_functions_to_graph_def(graph_def, bytesize)  # pylint: disable=protected-access

    # It's possible that not all the inputs are in the export_scope.
    # If we would like such information included in the exported meta_graph,
    # add them to a special unbound_inputs collection.
    if unbound_inputs_col_name:
      # Clears the unbound_inputs collections.
      graph.clear_collection(unbound_inputs_col_name)
      for k in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, k)

  var_list = {}
  variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                   scope=export_scope)
  for v in variables:
    if _should_include_node(v, export_scope, exclude_nodes):
      var_list[ops.strip_name_scope(v.name, export_scope)] = v

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      exclude_nodes=exclude_nodes,
      clear_extraneous_savers=clear_extraneous_savers,
      saver_def=saver_def,
      strip_default_attrs=strip_default_attrs,
      **kwargs)

  if filename:
    graph_io.write_graph(
        scoped_meta_graph_def,
        os.path.dirname(filename),
        os.path.basename(filename),
        as_text=as_text)
    if save_debug_info:
      # The original extension is dropped and replaced by ".debug".
      name, _ = os.path.splitext(filename)
      debug_filename = "{name}{ext}".format(name=name, ext=".debug")

      # Look up the operation for every node in the exported GraphDef in
      # `graph` by its scoped name. No nodes are filtered out here.
      # NOTE(review): if a caller-supplied `graph_def` contains nodes that are
      # absent from `graph`, this lookup presumably fails -- confirm.
      # TODO(liufengdb): fix this for functions.
      ops_to_export = []
      for node in scoped_meta_graph_def.graph_def.node:
        scoped_op_name = ops.prepend_name_scope(node.name, export_scope)
        ops_to_export.append(("", graph.get_operation_by_name(scoped_op_name)))

      graph_debug_info = error_interpolation.create_graph_debug_info_def(
          ops_to_export)

      graph_io.write_graph(
          graph_debug_info,
          os.path.dirname(debug_filename),
          os.path.basename(debug_filename),
          as_text=as_text)

  return scoped_meta_graph_def, var_list
1076
1077
def copy_scoped_meta_graph(from_scope, to_scope,
                           from_graph=None, to_graph=None):
  """Copies a sub-meta_graph from one scope to another.

  The subgraph under `from_scope` is exported with `export_scoped_meta_graph`,
  which strips `from_scope` from node names. It is then imported into
  `to_graph` under `to_scope` with `import_scoped_meta_graph`.

  Note: when `from_scope` is non-empty, the export clears and repopulates the
  "unbound_inputs" collection of `from_graph`.

  Args:
    from_scope: `String` name scope containing the subgraph to be copied.
    to_scope: `String` name scope under which the copied subgraph will reside.
    from_graph: Optional `Graph` from which to copy the subgraph. If `None`, the
      default graph is used.
    to_graph: Optional `Graph` to which to copy the subgraph. If `None`, the
      default graph is used.

  Returns:
    A dictionary of `Variables` that has been copied into `to_scope`.

  Raises:
    ValueError: If `from_scope` and `to_scope` are the same while
      `from_graph` and `to_graph` are also the same.
  """
  from_graph = from_graph or ops.get_default_graph()
  to_graph = to_graph or ops.get_default_graph()

  # Copying onto the exact same scope of the same graph would collide with
  # the existing nodes.
  if from_graph == to_graph and from_scope == to_scope:
    raise ValueError("'from_scope' and 'to_scope' need to be different "
                     "when performing copy in the same graph. "
                     f"Received: 'from_graph': {from_graph}, "
                     f"'to_graph': {to_graph}, "
                     f"'from_scope': {from_scope}, 'to_scope': {to_scope}.")

  # The export also returns the *source* variables. Only the MetaGraphDef is
  # needed, because the caller wants the variables created by the import.
  orig_meta_graph, _ = export_scoped_meta_graph(
      export_scope=from_scope, graph=from_graph)
  # NOTE(review): no input_map is passed here, so a subgraph that reads
  # tensors from outside `from_scope` (unbound inputs) presumably makes the
  # import raise ValueError -- confirm.
  return import_scoped_meta_graph(orig_meta_graph,
                                  graph=to_graph,
                                  import_scope=to_scope)
1113