xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/importer.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A utility function for importing TensorFlow graphs."""
16import contextlib
17
18from tensorflow.core.framework import graph_pb2
19from tensorflow.python import tf2
20from tensorflow.python.client import pywrap_tf_session as c_api
21from tensorflow.python.framework import c_api_util
22from tensorflow.python.framework import device as pydev
23from tensorflow.python.framework import errors
24from tensorflow.python.framework import function
25from tensorflow.python.framework import op_def_registry
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import control_flow_util
28from tensorflow.python.util import compat
29from tensorflow.python.util.deprecation import deprecated_args
30from tensorflow.python.util.tf_export import tf_export
31
32
33def _IsControlInput(input_name):
34  # Expected format: '^operation_name' (control input).
35  return input_name.startswith('^')
36
37
38def _ParseTensorName(tensor_name):
39  """Parses a tensor name into an operation name and output index.
40
41  This function will canonicalize tensor names as follows:
42
43  * "foo:0"       -> ("foo", 0)
44  * "foo:7"       -> ("foo", 7)
45  * "foo"         -> ("foo", 0)
46  * "foo:bar:baz" -> ValueError
47
48  Args:
49    tensor_name: The name of a tensor.
50
51  Returns:
52    A tuple containing the operation name, and the output index.
53
54  Raises:
55    ValueError: If `tensor_name' cannot be interpreted as the name of a tensor.
56  """
57  components = tensor_name.split(':')
58  if len(components) == 2:
59    # Expected format: 'operation_name:output_index'.
60    try:
61      output_index = int(components[1])
62    except ValueError:
63      raise ValueError(f'Cannot convert {tensor_name!r} to a tensor name. '
64                       'Second component of the name following the `:` should '
65                       f'be an int. Got {components[1]}.')
66    return components[0], output_index
67  elif len(components) == 1:
68    # Expected format: 'operation_name' (implicit 0th output).
69    return components[0], 0
70  else:
71    raise ValueError(f"Cannot convert '{tensor_name}' to a tensor name. Tensor "
72                     'names should not contain more than 1 `:`. Obtained '
73                     f'{len(components) - 1}')
74
75
76@contextlib.contextmanager
77def _MaybeDevice(device):
78  """Applies the given device only if device is not None or empty."""
79  if device:
80    with ops.device(device):
81      yield
82  else:
83    yield
84
85
def _ProcessGraphDefParam(graph_def):
  """Type-checks `graph_def` and returns the `GraphDef` to import.

  Args:
    graph_def: A `graph_pb2.GraphDef`, or a message that can be merged into
      one.

  Returns:
    `graph_def` itself if it is already a `graph_pb2.GraphDef`; in that case
    its nodes have been updated in place with default attr values. Otherwise
    a new `graph_pb2.GraphDef` holding the merged contents; no defaults are
    filled in on this path.

  Raises:
    TypeError: If `graph_def` cannot be merged into a `GraphDef`.
  """
  if isinstance(graph_def, graph_pb2.GraphDef):
    # The caller's proto is used as-is and modified in place to add attr
    # defaults to its NodeDefs (this is visible to the caller).
    # NOTE(skyewm): this is undocumented behavior that at least meta_graph.py
    # depends on. It might make sense to move this to meta_graph.py and have
    # import_graph_def not modify the graph_def argument (we'd have to make sure
    # this doesn't break anything else.)
    for node in graph_def.node:
      registered_op_def = op_def_registry.get(node.op)
      # Unrecognized ops are assumed to be functions for now;
      # TF_ImportGraphDef will report an error if the op is actually missing.
      if registered_op_def is not None:
        _SetDefaultAttrValues(node, registered_op_def)
    return graph_def

  # Not a GraphDef instance. It could be a dynamically-created message, so
  # duck-type it by merging it into a fresh GraphDef.
  try:
    converted = graph_pb2.GraphDef()
    converted.MergeFrom(graph_def)
  except TypeError:
    raise TypeError('Argument `graph_def` must be a GraphDef proto.')
  return converted
114
115
116def _ProcessInputMapParam(input_map):
117  """Type-checks and possibly canonicalizes `input_map`."""
118  if input_map is None:
119    input_map = {}
120  else:
121    if not isinstance(input_map, dict):
122      raise TypeError('Argument `input_map` must be a dictionary. Obtained '
123                      f'{type(input_map).__name__}')
124    if not all(
125        isinstance(k, compat.bytes_or_text_types) for k in input_map.keys()):
126      raise TypeError('All keys for argument `input_map` must be strings. '
127                      f'Obtained keys: {list(input_map.keys())}')
128  return input_map
129
130
131def _ProcessReturnElementsParam(return_elements):
132  """Type-checks and possibly canonicalizes `return_elements`."""
133  if return_elements is None:
134    return None
135  if not all(
136      isinstance(x, compat.bytes_or_text_types) for x in return_elements):
137    raise TypeError('Argument `return_elements` must be a list of strings. '
138                    f'Obtained {return_elements}.')
139  return tuple(compat.as_str(x) for x in return_elements)
140
141
142def _FindAttrInOpDef(attr_name, op_def):
143  for attr_def in op_def.attr:
144    if attr_name == attr_def.name:
145      return attr_def
146  return None
147
148
149def _RemoveDefaultAttrs(producer_op_list, graph_def):
150  """Removes unknown default attrs according to `producer_op_list`.
151
152  Removes any unknown attrs in `graph_def` (i.e. attrs that do not appear in
153  registered OpDefs) that have a default value in `producer_op_list`.
154
155  Args:
156    producer_op_list: OpList proto.
157    graph_def: GraphDef proto
158  """
159  producer_op_dict = {op.name: op for op in producer_op_list.op}
160  for node in graph_def.node:
161    # Remove any default attr values that aren't in op_def.
162    if node.op in producer_op_dict:
163      op_def = op_def_registry.get(node.op)
164      if op_def is None:
165        # Some custom op registrations won't show up here. That's OK, attribute
166        # stripping just won't be available.
167        continue
168      producer_op_def = producer_op_dict[node.op]
169      # We make a copy of node.attr to iterate through since we may modify
170      # node.attr inside the loop.
171      for key in list(node.attr):
172        if _FindAttrInOpDef(key, op_def) is None:
173          # No attr_def in consumer, look in producer.
174          attr_def = _FindAttrInOpDef(key, producer_op_def)
175          if (attr_def and attr_def.HasField('default_value') and
176              node.attr[key] == attr_def.default_value):
177            # Unknown attr had default value in producer, delete it so it can be
178            # understood by consumer.
179            del node.attr[key]
180
181
182def _ConvertInputMapValues(name, input_map):
183  """Ensures all input map values are tensors.
184
185  This should be called from inside the import name scope.
186
187  Args:
188    name: the `name` argument passed to import_graph_def
189    input_map: the `input_map` argument passed to import_graph_def.
190
191  Returns:
192    An possibly-updated version of `input_map`.
193
194  Raises:
195    ValueError: if input map values cannot be converted due to empty name scope.
196  """
197  if not all(isinstance(v, ops.Tensor) for v in input_map.values()):
198    if name == '':  # pylint: disable=g-explicit-bool-comparison
199      raise ValueError(
200          'tf.import_graph_def() requires a non-empty `name` if `input_map` '
201          'contains non-Tensor values. Try calling tf.convert_to_tensor() on '
202          '`input_map` values before calling tf.import_graph_def().')
203    with ops.name_scope('_inputs'):
204      input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}
205  return input_map
206
207
def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements,
                                     validate_colocation_constraints):
  """Populates the TF_ImportGraphDefOptions `options`.

  Args:
    options: The TF_ImportGraphDefOptions to fill in.
    prefix: Name prefix for the imported nodes (name uniquification is
      always enabled).
    input_map: Dict from input names to values exposing `_as_tf_output()`.
      Keys of the form '^op_name' remap control dependencies; all other keys
      are tensor names remapped to the given output.
    return_elements: `None` or an iterable of names. Names containing ':' are
      requested as return outputs, others as return operations.
    validate_colocation_constraints: Passed through to
      `TF_ImportGraphDefOptionsSetValidateColocationConstraints`.
  """
  c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix)
  c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True)

  for input_src, input_dst in input_map.items():
    input_src = compat.as_str(input_src)
    # `input_src` is already a str here, so the names derived from it below
    # need no further conversion.
    if _IsControlInput(input_src):
      src_name = input_src[1:]
      dst_op = input_dst._as_tf_output().oper  # pylint: disable=protected-access
      c_api.TF_ImportGraphDefOptionsRemapControlDependency(
          options, src_name, dst_op)
    else:
      src_name, src_idx = _ParseTensorName(input_src)
      dst_output = input_dst._as_tf_output()  # pylint: disable=protected-access
      c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, src_idx,
                                                    dst_output)

  # The ':' test must match the one in _GatherReturnElements, which relies on
  # the same split to read results back in order.
  for name in return_elements or []:
    if ':' in name:
      op_name, index = _ParseTensorName(name)
      op_name = compat.as_str(op_name)
      c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index)
    else:
      c_api.TF_ImportGraphDefOptionsAddReturnOperation(options,
                                                       compat.as_str(name))

  c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(
      options, validate_colocation_constraints)
239
240
def _ProcessNewOps(graph):
  """Processes the newly-added TF_Operations in `graph`.

  First pass: every new op has its device cleared. Ops without colocation
  constraints then get the graph's device functions applied inside their
  original device scope. Ops with colocation constraints are set aside.

  Second pass: each set-aside op takes the device of the first of its
  colocation targets that exists and has a device.

  Raises:
    ValueError: If a colocation target does not exist, unless TF2 behavior or
      control flow v2 is enabled (then the target is skipped).
  """
  # New op -> names of the ops its `_class` attr says to colocate with.
  colocated = {}

  for op in graph._add_new_tf_operations(compute_devices=False):  # pylint: disable=protected-access
    requested_device = op.device
    op._set_device('')  # pylint: disable=protected-access
    targets = _GetColocationNames(op)
    if not targets:
      with _MaybeDevice(requested_device):
        graph._apply_device_functions(op)  # pylint: disable=protected-access
      continue
    # Colocation constraints override both device functions and the original
    # device, so no device is set here; the pass below may still set one.
    # TODO(skyewm): why does it override the original device?
    colocated[op] = targets

  # Propagate the device implied by colocation into each op's device field
  # for completeness. If several targets have devices they are assumed to be
  # the same one; otherwise a runtime error occurs because colocation can't be
  # guaranteed (in TF2 colocation is only a hint, so there is no error).
  # A possible improvement would be to check device compatibility here at
  # import time, which would need a Python device-spec compatibility check.
  for op, targets in colocated.items():
    device_spec = None
    for target_name in targets:
      try:
        target = graph._get_operation_by_name_unsafe(target_name)  # pylint: disable=protected-access
      except KeyError:
        # A missing target is an error only outside TF2 / control flow v2.
        if not (tf2.enabled() or control_flow_util.EnableControlFlowV2(graph)):
          raise ValueError(f'Specified colocation to an op: {target_name} that '
                           f'does not exist during import for op: {op.name}')
        continue
      if target.device:
        device_spec = pydev.DeviceSpec.from_string(target.device)
        break
    if device_spec:
      op._set_device(device_spec)  # pylint: disable=protected-access
291
292
293def _GetColocationNames(op):
294  """Returns names of the ops that `op` should be colocated with."""
295  colocation_names = []
296  try:
297    class_values = op.get_attr('_class')
298  except ValueError:
299    # No _class attr
300    return
301  for val in class_values:
302    val = compat.as_str(val)
303    if val.startswith('loc:@'):
304      colocation_node_name = val[len('loc:@'):]
305      if colocation_node_name != op.name:
306        colocation_names.append(colocation_node_name)
307  return colocation_names
308
309
def _GatherReturnElements(requested_return_elements, graph, results):
  """Returns the requested return elements from results.

  The C API returns tensors and operations in two separate lists. They are
  interleaved back into request order: a requested name containing ':' takes
  the next tensor, and any other name takes the next operation.

  Args:
    requested_return_elements: list of strings of operation and tensor names
    graph: Graph
    results: wrapped TF_ImportGraphDefResults

  Returns:
    list of `Operation` and/or `Tensor` objects
  """
  tf_outputs = c_api.TF_ImportGraphDefResultsReturnOutputs(results)
  tf_operations = c_api.TF_ImportGraphDefResultsReturnOperations(results)

  gathered = []
  next_output = next_operation = 0
  for element_name in requested_return_elements:
    if ':' not in element_name:
      gathered.append(
          graph._get_operation_by_tf_operation(tf_operations[next_operation]))  # pylint: disable=protected-access
      next_operation += 1
    else:
      gathered.append(
          graph._get_tensor_by_tf_output(tf_outputs[next_output]))  # pylint: disable=protected-access
      next_output += 1
  return gathered
337
338
339def _SetDefaultAttrValues(node_def, op_def):
340  """Set any default attr values in `node_def` that aren't present."""
341  assert node_def.op == op_def.name
342  for attr_def in op_def.attr:
343    key = attr_def.name
344    if attr_def.HasField('default_value'):
345      value = node_def.attr[key]
346      if value is None or value.WhichOneof('value') is None:
347        node_def.attr[key].CopyFrom(attr_def.default_value)
348
349
@tf_export('graph_util.import_graph_def', 'import_graph_def')
@deprecated_args(None, 'Please file an issue at '
                 'https://github.com/tensorflow/tensorflow/issues if you depend'
                 ' on this feature.', 'op_dict')
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  `tf.Tensor` and `tf.Operation` objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  `tf.Graph.as_graph_def` for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use. This argument is ignored.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`,
    and None if `return_elements` is None.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # `op_dict` is accepted only for backwards compatibility and is unused.
  del op_dict
  return _import_graph_def_internal(
      graph_def,
      input_map=input_map,
      return_elements=return_elements,
      name=name,
      producer_op_list=producer_op_list)
409
410
def import_graph_def_for_function(  # pylint: disable=invalid-name
    graph_def, name=None):
  """Like import_graph_def but does not validate colocation constraints.

  No `input_map`, `return_elements` or `producer_op_list` is passed, so the
  import requests no return elements.

  Args:
    graph_def: A `GraphDef` proto to import into the default graph.
    name: (Optional.) Prefix for the imported node names.

  Returns:
    Whatever `_import_graph_def_internal` returns for this call.
  """
  return _import_graph_def_internal(
      graph_def, name=name, validate_colocation_constraints=False)
416
417
def _import_graph_def_internal(  # pylint: disable=invalid-name
    graph_def,
    input_map=None,
    return_elements=None,
    validate_colocation_constraints=True,
    name=None,
    producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  `tf.Tensor` and `tf.Operation` objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  `tf.Graph.as_graph_def` for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into the
      default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in `graph_def`
      that will be returned as `Operation` objects; and/or tensor names in
      `graph_def` that will be returned as `Tensor` objects.
    validate_colocation_constraints: Whether to validate colocation constraints.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`,
    and None if `return_elements` is None.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  graph_def = _ProcessGraphDefParam(graph_def)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(producer_op_list, graph_def)

  graph = ops.get_default_graph()
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # Save unique prefix generated by name_scope
    if scope:
      assert scope.endswith('/')
      prefix = scope[:-1]
    else:
      prefix = ''

    # Generate any input map tensors inside name scope
    input_map = _ConvertInputMapValues(name, input_map)

  # Keep `scoped_options` bound for the rest of the function: it owns
  # `options`.
  scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
  options = scoped_options.options
  _PopulateTFImportGraphDefOptions(options, prefix, input_map, return_elements,
                                   validate_colocation_constraints)

  # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
  # Session.run call cannot occur between creating the TF_Operations in the
  # TF_GraphImportGraphDefWithResults call and mutating them in
  # _ProcessNewOps.
  with graph._mutation_lock():  # pylint: disable=protected-access
    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        with graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
          results = c_api.TF_GraphImportGraphDefWithResults(
              c_graph, serialized, options)
        results = c_api_util.ScopedTFImportGraphDefResults(results)
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility; chain the
        # original error so its context is kept in tracebacks.
        raise ValueError(str(e)) from e

    _ProcessNewOps(graph)

  # Create _DefinedFunctions for any imported functions.
  #
  # We do this by creating _DefinedFunctions directly from `graph_def`, and
  # adding them to `graph`. Adding an existing function to a TF_Graph is a
  # no-op, so this only has the effect of updating the Python state (usually
  # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
  #
  # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
  # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
  if graph_def.library and graph_def.library.function:
    functions = function.from_library(graph_def.library)
    for f in functions:
      f.add_to_graph(graph)

  # Treat input mappings that don't appear in the graph as an error, because
  # they are likely to be due to a typo.
  missing_unused_input_keys = (
      c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
          results.results))
  if missing_unused_input_keys:
    missing_unused_input_keys = [
        compat.as_str(s) for s in missing_unused_input_keys
    ]
    missing_keys = ', '.join(missing_unused_input_keys)
    raise ValueError(
        'Attempted to map inputs that were not found in graph_def: '
        f'[{missing_keys}]')

  if return_elements is None:
    return None
  else:
    return _GatherReturnElements(return_elements, graph, results.results)
542