# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for deserializing `Function`s."""

import collections
import pprint
import re
from absl import logging

from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as function_lib
from tensorflow.python.eager import function_spec as function_spec_lib
from tensorflow.python.framework import func_graph as func_graph_lib
from tensorflow.python.framework import function_def_to_graph as function_def_lib
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


def _is_tensor(t):
  return isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))


# TODO(b/205016027): Update this to just use ConcreteFunction.__call__ with the
# structured signature.
def _call_concrete_function(function, inputs):
  """Calls a restored Function with structured inputs.

  This differs from `function.__call__` in that inputs and outputs are
  structured and that it casts inputs to tensors if needed.

  Note: this does not check that non-tensor inputs match. That should be
  done beforehand via `_concrete_function_callable_with`.

  Args:
    function: ConcreteFunction to call.
    inputs: Structured inputs compatible with
      `function.graph.structured_input_signature`.

  Returns:
    The structured function output.
  """
  expected_structure = function.graph.structured_input_signature
  flatten_inputs = nest.flatten_up_to(
      expected_structure, inputs, expand_composites=True)
  flatten_expected = nest.flatten(expected_structure, expand_composites=True)
  tensor_inputs = []
  for arg, expected in zip(flatten_inputs, flatten_expected):
    if isinstance(expected, tensor_spec.TensorSpec):
      tensor_inputs.append(
          ops.convert_to_tensor(arg, dtype_hint=expected.dtype))
    elif isinstance(expected, resource_variable_ops.VariableSpec):
      tensor_inputs.append(arg)
  result = function._call_flat(tensor_inputs, function.captured_inputs)  # pylint: disable=protected-access
  if isinstance(result, ops.Operation):
    return None
  return result
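
# A minimal usage sketch (hypothetical `cf` and values): for a restored
# concrete function whose `structured_input_signature` is
# ((TensorSpec([], tf.float32, name="x"),), {}), the loader passes the
# already-canonicalized `(args, kwargs)` pair and the Python float is
# converted to a tensor before the flat call:
#
#   result = _call_concrete_function(cf, ((3.0,), {}))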


def _try_convert_to_tensor_spec(arg, dtype_hint):
  """Returns None or TensorSpec obtained if `arg` is converted to tensor."""
  try:
    # Note: try conversion in a FuncGraph to avoid polluting current context.
    with func_graph_lib.FuncGraph(name="guess_conversion").as_default():
      result = ops.convert_to_tensor(arg, dtype_hint=dtype_hint)
      return tensor_spec.TensorSpec(shape=result.shape, dtype=result.dtype)
  except (TypeError, ValueError):
    return None


def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`."""
  expected_structure = function.graph.structured_input_signature
  try:
    flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
  except (TypeError, ValueError):
    return False

  for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
    if isinstance(expected, tensor_spec.TensorSpec):
      if allow_conversion:
        arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
      if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec):
        return False
      if arg.dtype != expected.dtype:
        return False
      if not expected.shape.is_compatible_with(arg.shape):
        return False
    elif isinstance(expected, type_spec.TypeSpec):
      if not expected.is_compatible_with(arg):
        return False
    elif _is_tensor(arg):
      if id(arg) != id(expected):
        return False
    else:
      if arg != expected:
        return False
  return True
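
# An illustrative sketch of the matching rules (hypothetical `cf` and specs):
# a trace expecting TensorSpec([None], tf.float32) accepts a float32 tensor of
# shape [3] but rejects an int32 tensor (dtype mismatch); with
# `allow_conversion=True` it also accepts a plain Python list that converts
# cleanly to float32:
#
#   _concrete_function_callable_with(cf, (([1.0, 2.0],), {}), True)  # -> True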


def _deserialize_function_spec_as_nonmethod(function_spec_proto):
  """Deserialize a FunctionSpec object from its proto representation."""
  typeless_fullargspec = nested_structure_coder.decode_proto(
      function_spec_proto.fullargspec)

  # Convert a method function into a non method.
  if function_spec_proto.is_method:
    if not typeless_fullargspec.args:
      raise NotImplementedError(
          "Cannot deserialize a method function without a named "
          "'self' argument.")
    args = typeless_fullargspec.args[1:]
  else:
    args = typeless_fullargspec.args

  fullargspec = tf_inspect.FullArgSpec(
      args=args,
      varargs=typeless_fullargspec.varargs,
      varkw=typeless_fullargspec.varkw,
      defaults=typeless_fullargspec.defaults,
      kwonlyargs=typeless_fullargspec.kwonlyargs,
      kwonlydefaults=typeless_fullargspec.kwonlydefaults,
      annotations=typeless_fullargspec.annotations)
  input_signature = nested_structure_coder.decode_proto(
      function_spec_proto.input_signature)

  # See `tf.function` and the JitCompile proto for details.
  jit_compile = {
      saved_object_graph_pb2.FunctionSpec.JitCompile.DEFAULT: None,
      saved_object_graph_pb2.FunctionSpec.JitCompile.ON: True,
      saved_object_graph_pb2.FunctionSpec.JitCompile.OFF: False,
  }.get(function_spec_proto.jit_compile)

  return function_spec_lib.FunctionSpec(
      fullargspec=fullargspec,
      is_method=False,
      input_signature=input_signature,
      jit_compile=jit_compile)
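
# A minimal sketch of the method-to-function conversion above (hypothetical
# proto contents): a saved method whose fullargspec args are ['self', 'x'] and
# whose `is_method` field is True deserializes to a FunctionSpec with args
# ['x'] and `is_method=False`, so the restored callable is invoked as `fn(x)`
# rather than `obj.fn(x)`.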


# TODO(b/205016761): The fact that we can't derive ConcreteFunction calling
# conventions from the serialized input spec right now is unfortunate. Merging
# these would be good, maybe by adding TensorSpec names to cache keys so renamed
# keyword arguments would yield different ConcreteFunctions.
def setup_bare_concrete_function(saved_bare_concrete_function,
                                 concrete_functions):
  """Makes a restored bare concrete function callable."""
  concrete_function = concrete_functions[
      saved_bare_concrete_function.concrete_function_name]
  # pylint: disable=protected-access
  concrete_function._arg_keywords = (
      saved_bare_concrete_function.argument_keywords)
  concrete_function._num_positional_args = (
      saved_bare_concrete_function.allowed_positional_arguments)
  if saved_bare_concrete_function.HasField("function_spec"):
    function_spec = _deserialize_function_spec_as_nonmethod(
        saved_bare_concrete_function.function_spec)
    concrete_function._set_function_spec(function_spec)
  # pylint: enable=protected-access
  concrete_function.add_to_graph()
  return concrete_function


class RestoredFunction(def_function.Function):
  """Wrapper class for a function that has been restored from saved state.

  See `def_function.Function`.
  """

  def __init__(self, python_function, name, function_spec, concrete_functions):
    # TODO(b/205016819): We may enable autograph once exceptions are supported.
    super(RestoredFunction, self).__init__(
        python_function,
        name,
        autograph=False,
        jit_compile=function_spec.jit_compile)
    self.concrete_functions = concrete_functions
    self._function_spec = function_spec

    # Prevent RestoredFunction from spamming users with frequent tracing
    # warnings.
    self._omit_frequent_tracing_warning = True

  @property
  def _run_functions_eagerly(self):
    # We do not have access to the original python function, and thus, we
    # cannot meaningfully do anything but call our concrete function graphs
    # under the hood.
    #
    # Attempting to call our bespoke python function (i.e.
    # `restored_function_body`) will work so long as the user passes in all
    # required and optional arguments. If an optional argument is missing,
    # however, the call will break. For this reason, we instead skip the
    # eager call path altogether if a user has enabled eager function execution
    # via `tf.config.run_functions_eagerly`.
    return False

  def _list_all_concrete_functions_for_serialization(self):
    return self.concrete_functions

  def _defun_with_scope(self, scope):
    func = super(RestoredFunction, self)._defun_with_scope(scope)
    func._function_spec = self._function_spec  # pylint: disable=protected-access
    return func


def recreate_function(saved_function, concrete_functions):
  """Creates a `Function` from a `SavedFunction`.

  Args:
    saved_function: `SavedFunction` proto.
    concrete_functions: map from function name to `ConcreteFunction`. As a side
      effect of this function, the `FunctionSpec` from `saved_function` is added
      to each `ConcreteFunction` in this map.

  Returns:
    A `Function`.
  """
  # TODO(b/205017389): Construct a `Function` with the cache populated
  # instead of creating a new `Function` backed by a Python layer to
  # glue things together. The current approach nests functions deeper for each
  # serialization cycle.

  # Note: handling method functions is tricky since make_decorator does not
  # allow control of "ismethod". Additionally, since restored functions do
  # not behave as methods, i.e. they always use the same captured tensors
  # independent of the object they are bound to, there is little value in
  # propagating that correctly.
  #
  # Ideally this conversion should happen at serialization time. But since
  # there are SavedModels which have "ismethod" populated and have an extra
  # argument that they expect to be ignored, we do it at deserialization.
  function_spec = _deserialize_function_spec_as_nonmethod(
      saved_function.function_spec)

  def restored_function_body(*args, **kwargs):
    """Calls a restored function or raises an error if no matching function."""
    if not saved_function.concrete_functions:
      raise ValueError("Found zero restored functions for caller function.")
    # This is the format of function.graph.structured_input_signature. At this
    # point, the args and kwargs have already been canonicalized.
    inputs = (args, kwargs)

    # First try to find a concrete function that can be called without input
    # conversions. This allows one to pick a more specific trace in case there
    # was also a more expensive one that supported tensors.
    for allow_conversion in [False, True]:
      for function_name in saved_function.concrete_functions:
        function = concrete_functions[function_name]
        if any([inp is None for inp in function.captured_inputs]):
          raise ValueError("Looks like you are trying to run a loaded "
                           "non-Keras model that was trained using "
                           "tf.distribute.experimental.ParameterServerStrategy "
                           "with variable partitioning, which is not currently "
                           "supported. Try using Keras to define your model "
                           "if possible.")
        if _concrete_function_callable_with(function, inputs, allow_conversion):
          return _call_concrete_function(function, inputs)

    signature_descriptions = []

    def _pretty_format_positional(positional):
      return "Positional arguments ({} total):\n    * {}".format(
          len(positional),
          "\n    * ".join(pprint.pformat(a) for a in positional))

    for index, function_name in enumerate(saved_function.concrete_functions):
      concrete_function = concrete_functions[function_name]
      positional, keyword = concrete_function.structured_input_signature
      signature_descriptions.append(
          "Option {}:\n  {}\n  Keyword arguments: {}".format(
              index + 1, _pretty_format_positional(positional), keyword))
    raise ValueError(
        "Could not find matching concrete function to call loaded from the "
        f"SavedModel. Got:\n  {_pretty_format_positional(args)}\n  Keyword "
        f"arguments: {kwargs}\n\n Expected these arguments to match one of the "
        f"following {len(saved_function.concrete_functions)} option(s):\n\n"
        f"{(chr(10)+chr(10)).join(signature_descriptions)}")

  concrete_function_objects = []
  for concrete_function_name in saved_function.concrete_functions:
    concrete_function_objects.append(concrete_functions[concrete_function_name])

  for cf in concrete_function_objects:
    cf._set_function_spec(function_spec)  # pylint: disable=protected-access

  restored_function = RestoredFunction(restored_function_body,
                                       restored_function_body.__name__,
                                       function_spec, concrete_function_objects)

  return tf_decorator.make_decorator(
      restored_function_body,
      restored_function,
      decorator_argspec=function_spec.fullargspec)
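
# A rough end-to-end sketch (hypothetical proto and variable names): given a
# `SavedFunction` proto listing several traced signatures and the
# name-to-ConcreteFunction map built by `load_function_def_library`, the
# restored polymorphic callable dispatches to whichever trace matches the call:
#
#   fn = recreate_function(saved_function_proto, concrete_functions)
#   fn(tf.constant([1.0, 2.0]))  # routed to the matching ConcreteFunction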


def load_function_def_library(library,
                              saved_object_graph=None,
                              load_shared_name_suffix=None,
                              wrapper_function=None):
  """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Gradients are re-registered under new names. Ops that reference the gradients
  are updated to reflect the new registered names.

  Args:
    library: FunctionDefLibrary proto message.
    saved_object_graph: SavedObjectGraph proto message. If not passed in,
      concrete function structured signatures and outputs will not be set.
    load_shared_name_suffix: If specified, used to uniquify shared names.
      Otherwise, a unique name is generated.
    wrapper_function: A callable used to wrap newly created functions.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
  library_function_names = set(fdef.signature.name for fdef in library.function)
  functions = {}
  renamed_functions = {}

  # Our graph building code currently requires functions to be registered with
  # some tf.Graph in order to import functions using the
  # op-name-is-function-name calling convention. To avoid leaking memory into
  # the global default graph when executing eagerly, we create a temporary
  # Graph.
  #
  # TODO(b/205023033): Make this Graph creation unnecessary when executing
  # eagerly by fixing function_def_to_graph_def.
  if ops.executing_eagerly_outside_functions():
    graph = ops.Graph()
  else:
    graph = ops.get_default_graph()

  if load_shared_name_suffix is None:
    load_shared_name_suffix = "_load_{}".format(ops.uid())

  # Custom gradient functions must be re-registered under new UIDs.
  library_gradient_names = {}  # Maps old op type to old function name
  new_gradient_op_types = {}  # Maps old gradient op type to new op type.
  gradients_to_register = {}  # Maps old function name to new op type
  for gdef in library.registered_gradients:
    if gdef.registered_op_type:
      new_op_type = custom_gradient.generate_name()
      old_op_type = compat.as_bytes(gdef.registered_op_type)

      library_gradient_names[old_op_type] = gdef.gradient_func
      new_gradient_op_types[old_op_type] = new_op_type
      gradients_to_register[gdef.gradient_func] = new_op_type

  function_deps = {}
  for fdef in library.function:
    function_deps[fdef.signature.name] = _list_function_deps(
        fdef, library_function_names, library_gradient_names)

  loaded_gradients = {}
  for fdef in _sort_function_defs(library, function_deps):
    orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix,
                                   new_gradient_op_types)

    # Setup function signatures and outputs
    #
    # When concrete functions are created normally (i.e. when they're originally
    # created and not loaded via saved model), the inputs and outputs are
    # calculated based on the values passed in by the user and returned from the
    # original function, respectively. We don't have access to those anymore at
    # restore time, so we must instead pass them to the FuncGraph explicitly.
    structured_input_signature = None
    structured_outputs = None
    if (saved_object_graph is not None and
        orig_name in saved_object_graph.concrete_functions):
      # TODO(b/204324043): Offload the deserialization of the protos to the
      # first class objects by passing the actual protos. This is blocked on
      # importing `nested_structure_coder` in function.py causing a circular
      # dependency.
      proto = saved_object_graph.concrete_functions[orig_name]
      structured_input_signature = nested_structure_coder.decode_proto(
          proto.canonicalized_input_signature)
      structured_outputs = nested_structure_coder.decode_proto(
          proto.output_signature)

    # There is no need to copy all functions into the function def graph. Doing
    # so leads to an O(n^2) increase in memory when importing functions, and
    # the extra function definitions are a no-op since they were already
    # imported as functions before and are passed in explicitly (due to the
    # topological sort import).
    with graph.as_default():
      func_graph = function_def_lib.function_def_to_graph(
          fdef,
          structured_input_signature=structured_input_signature,
          structured_outputs=structured_outputs)
    # Restores gradients for function-call ops (not the same as ops that use
    # custom gradients)
    _restore_gradient_functions(func_graph, renamed_functions, loaded_gradients)

    for dep in function_deps[orig_name]:
      functions[dep].add_to_graph(func_graph)

    # We do not initialize the new ConcreteFunction's function_spec and/or
    # arg_keywords here (which are used to parse the structured and flat
    # signatures, respectively). ConcreteFunctions that are part of a saved
    # function are set up later by recreate_function(); bare ConcreteFunctions
    # are set up by setup_bare_concrete_function().
    # However, we copy the FunctionDef attributes to the new ConcreteFunction,
    # excluding "_input_shapes", which may cause an error during input shape
    # initialization at a later stage.
    if "_input_shapes" in fdef.attr:
      del fdef.attr["_input_shapes"]
    func = function_lib.ConcreteFunction(func_graph, attrs=fdef.attr)
    if wrapper_function:
      func = wrapper_function(func)
    func.add_to_graph(graph)

    functions[orig_name] = func
    renamed_functions[func.name] = func
    if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
      # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
      # is fixed. Currently it's leaking memory to maintain bug compatibility
      # with previous behavior.
      func.add_to_graph(ops.get_default_graph())

    if orig_name in gradients_to_register:
      gradient_op_type = gradients_to_register[orig_name]
      loaded_gradients[compat.as_bytes(gradient_op_type)] = func
      ops.RegisterGradient(gradient_op_type)(_gen_gradient_func(func))

  return functions
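
# A minimal usage sketch (hypothetical variable names): the SavedModel loader
# typically feeds this the FunctionDefLibrary from a MetaGraphDef along with
# the saved object graph, then wires the resulting map into
# `recreate_function` / `setup_bare_concrete_function`:
#
#   concrete_fns = load_function_def_library(
#       meta_graph_def.graph_def.library,
#       saved_object_graph=object_graph_proto)
#   fn = concrete_fns["__inference_call_42"]  # keyed by original fdef name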


def _gen_gradient_func(func):
  """Wraps a deserialized function."""

  def gradient_func(unused_op, *result_grads):
    # Replace all `None` arguments, because the traced custom gradient function
    # expects tensors. Replacing with zeros is correct since the `None` values
    # occur when the gradient is unconnected, and thus the gradient is
    # "statically proven to be zero." See `tf.UnconnectedGradients` for details.
    result_grads = [
        x if x is not None else default_gradient.zeros_like(t)
        for (x, t) in zip(result_grads, func.graph.inputs)
    ]

    return func(*result_grads)

  return gradient_func
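
# An illustrative sketch (hypothetical values): if the forward op has two
# outputs and only the first is connected to the loss, `result_grads` arrives
# as [grad0, None]; the None is replaced with zeros shaped like the
# corresponding input of the traced gradient function before it is called.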


def _restore_gradient_functions(func_graph, renamed_functions,
                                loaded_gradients):
  """Populate function op's _gradient_function with default gradient."""
  for op in func_graph.get_operations():
    # TODO(b/205024208): This code assumes that the gradient registered for this
    # function call is the default gradient for the function and not a custom
    # one.
    if op.type in ["StatefulPartitionedCall", "PartitionedCall"]:
      function = renamed_functions[compat.as_bytes(
          op.node_def.attr["f"].func.name)]
      op._gradient_function = function._get_gradient_function()  # pylint: disable=protected-access
    try:
      gradient_op_type = op.get_attr("_gradient_op_type")
    except ValueError:
      pass
    else:
      if gradient_op_type in loaded_gradients:
        grad_fn = loaded_gradients[gradient_op_type]
        grad_fn._num_positional_args = len(op.inputs)  # pylint: disable=protected-access
        grad_fn._arg_keywords = [inp.name for inp in op.inputs]  # pylint: disable=protected-access


def _sort_function_defs(library, function_deps):
  """Return a topological sort of FunctionDefs in a library."""
  edges = collections.defaultdict(list)
  in_count = collections.defaultdict(lambda: 0)

  for fname, deps in function_deps.items():
    for dep in deps:
      edges[dep].append(fname)
      in_count[fname] += 1
  ready = [
      fdef.signature.name
      for fdef in library.function
      if in_count[fdef.signature.name] == 0
  ]
  output = []
  while ready:
    node = ready.pop()
    output.append(node)
    for dest in edges[node]:
      in_count[dest] -= 1
      if not in_count[dest]:
        ready.append(dest)

  if len(output) != len(library.function):
    failed_to_resolve = sorted(set(in_count.keys()) - set(output))
    raise ValueError(
        "There is a cyclic dependency between functions. "
        f"Could not resolve {failed_to_resolve}.")

  reverse = {fdef.signature.name: fdef for fdef in library.function}
  return [reverse[x] for x in output]
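
# A small worked example of the Kahn-style sort above (hypothetical function
# names): with function_deps = {"outer": {"inner"}, "inner": set()}, "inner"
# starts with in-degree zero and is emitted first, so it is imported before
# "outer", which calls it.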


def _get_gradient_op_type(node_def):
  """Returns the custom gradient op type."""
  if ("_gradient_op_type" in node_def.attr and
      node_def.op not in ["StatefulPartitionedCall", "PartitionedCall"]):
    return node_def.attr["_gradient_op_type"].s
  return None


def fix_node_def(node_def, functions, shared_name_suffix):
  """Replace function calls and shared names in `node_def`."""
  if node_def.op in functions:
    node_def.op = functions[node_def.op].name
  for _, attr_value in node_def.attr.items():
    if attr_value.WhichOneof("value") == "func":
      attr_value.func.name = functions[attr_value.func.name].name
    elif attr_value.WhichOneof("value") == "list":
      for fn in attr_value.list.func:
        fn.name = functions[fn.name].name

  # Fix old table creation bug.
  if node_def.op == "HashTableV2":
    if ("use_node_name_sharing" not in node_def.attr or
        not node_def.attr["use_node_name_sharing"].b):
      node_def.attr["use_node_name_sharing"].b = True
      # We are turning on node name sharing, so we have to make sure we don't
      # accidentally share a table resource.
      shared_name_suffix += "_{}".format(ops.uid())

  # TODO(b/124205571): Avoid accidental sharing and destruction of restored
  # resources. For now uniquify "shared_name" when loading functions to avoid
  # sharing.
  # TODO: Add regression test for b/150826922.
  op_def = op_def_registry.get(node_def.op)
  if op_def:
    attr = next((a for a in op_def.attr if a.name == "shared_name"), None)
    if attr:
      shared_name = None
      if "shared_name" in node_def.attr and node_def.attr["shared_name"].s:
        shared_name = node_def.attr["shared_name"].s
      elif attr.default_value.s:
        shared_name = compat.as_bytes(attr.default_value.s)
      if not shared_name:
        shared_name = compat.as_bytes(node_def.name)

      node_def.attr["shared_name"].s = (
          shared_name + compat.as_bytes(shared_name_suffix))
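
# An illustrative sketch of the shared_name uniquification above (hypothetical
# values): a VarHandleOp node with shared_name "embedding" loaded with the
# suffix "_load_7" ends up with shared_name "embedding_load_7", so two loads
# of the same SavedModel do not accidentally share each other's resources.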


def _fix_fdef_in_place(fdef, functions, shared_name_suffix,
                       new_gradient_op_types):
  """Fixes a FunctionDef proto to be loaded in the current context.

  In particular, when loading a function library into an eager context, one
  must rename the functions to avoid conflicts with existing functions.

  Args:
    fdef: FunctionDef proto to fix. It is mutated in-place.
    functions: map from function name to a ConcreteFunction instance.
    shared_name_suffix: A unique string for this load which helps to avoid
      `shared_name` collisions across loads. Two functions from the same load
      using the same `shared_name` still need to share, but functions from
      different loads with the same `shared_name` should not.
    new_gradient_op_types: map from old gradient op type to newly generated op
      type.

  Returns:
    orig_name: original value of fdef.signature.name
  """
  orig_name = fdef.signature.name
  contains_unsaved_custom_gradients = False

  for node_def in fdef.node_def:
    fix_node_def(node_def, functions, shared_name_suffix)
    op_type = _get_gradient_op_type(node_def)
    if op_type is not None:
      if op_type in new_gradient_op_types:
        node_def.attr["_gradient_op_type"].s = compat.as_bytes(
            new_gradient_op_types[op_type])
      else:
        contains_unsaved_custom_gradients = True
  if contains_unsaved_custom_gradients:
    logging.warning(
        "Importing a function (%s) with ops with unsaved custom gradients. Will"
        " likely fail if a gradient is requested.", fdef.signature.name)

  fdef.signature.name = _clean_function_name(fdef.signature.name)
  return orig_name


def _list_function_deps(fdef, library_function_names, library_gradient_names):
  """Find functions referenced in `fdef`."""
  # TODO(b/205023953): Recurse into list attributes and into NameAttrList attrs
  # both when listing deps and when fixing them. `function_def_to_graph` also
  # requires fixes.
  deps = set()
  for node_def in fdef.node_def:
    grad_op_type = _get_gradient_op_type(node_def)
    if node_def.op in library_function_names:
      deps.add(node_def.op)
    elif grad_op_type and grad_op_type in library_gradient_names:
      deps.add(library_gradient_names[grad_op_type])
    else:
      for _, attr_value in node_def.attr.items():
        if attr_value.WhichOneof("value") == "func":
          deps.add(attr_value.func.name)
        elif attr_value.WhichOneof("value") == "list":
          for fn in attr_value.list.func:
            deps.add(fn.name)

  return deps


_FUNCTION_WRAPPER_NAME_REGEX = r"^%s(.*)_\d+$" % (
    function_lib._INFERENCE_PREFIX)  # pylint:disable=protected-access


def _clean_function_name(name):
  """Vanity function to keep the function names comprehensible."""
  # Note: each time a function is wrapped into `function_lib.ConcreteFunction`
  # its name becomes "__inference_<orig>_xyz".
  match = re.search(_FUNCTION_WRAPPER_NAME_REGEX, name)
  if match:
    return match.group(1)
  else:
    return name

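# A quick sketch of the cleanup above (hypothetical name, assuming the usual
# "__inference_" prefix): a wrapped name such as
# "__inference_my_model_call_1234" cleans back to "my_model_call", while a
# name that does not match the wrapper pattern is returned unchanged.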