# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for deserializing `Function`s."""

import collections
import pprint
import re
from absl import logging

from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as function_lib
from tensorflow.python.eager import function_spec as function_spec_lib
from tensorflow.python.framework import func_graph as func_graph_lib
from tensorflow.python.framework import function_def_to_graph as function_def_lib
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


def _is_tensor(t):
  """Reports whether `t` is treated as a tensor value by this module.

  Args:
    t: Any Python object.

  Returns:
    True if `t` is an `ops.Tensor` or a
    `resource_variable_ops.BaseResourceVariable`, False otherwise.
  """
  tensor_like_types = (ops.Tensor, resource_variable_ops.BaseResourceVariable)
  return isinstance(t, tensor_like_types)


# TODO(b/205016027):
# Update this to just use ConcreteFunction.__call__ with the structured
# signature.
def _call_concrete_function(function, inputs):
  """Calls a restored Function with structured inputs.

  This differs from `function.__call__` in that inputs and outputs are
  structured and that it casts inputs to tensors if needed.

  Note: this does not check that non-tensor inputs match. That should be
  done beforehand via `_concrete_function_callable_with`.

  Args:
    function: ConcreteFunction to call.
    inputs: Structured inputs compatible with
      `function.graph.structured_input_signature`.

  Returns:
    The structured function output, or None if `_call_flat` returned an
    `ops.Operation`.
  """
  expected_structure = function.graph.structured_input_signature
  flatten_inputs = nest.flatten_up_to(
      expected_structure, inputs, expand_composites=True)
  flatten_expected = nest.flatten(expected_structure, expand_composites=True)
  tensor_inputs = []
  for arg, expected in zip(flatten_inputs, flatten_expected):
    if isinstance(expected, tensor_spec.TensorSpec):
      # `arg` may not be a tensor yet; convert it, using the signature's dtype
      # as a hint for the conversion.
      tensor_inputs.append(
          ops.convert_to_tensor(arg, dtype_hint=expected.dtype))
    elif isinstance(expected, resource_variable_ops.VariableSpec):
      # Variables are forwarded unchanged.
      tensor_inputs.append(arg)
    # Leaves matching any other kind of spec/value are not forwarded to
    # `_call_flat`.
  result = function._call_flat(tensor_inputs, function.captured_inputs)  # pylint: disable=protected-access
  if isinstance(result, ops.Operation):
    return None
  return result


def _try_convert_to_tensor_spec(arg, dtype_hint):
  """Returns the TensorSpec `arg` would have if converted to a tensor.

  Args:
    arg: Value to try converting.
    dtype_hint: dtype hint forwarded to `ops.convert_to_tensor`.

  Returns:
    A `tensor_spec.TensorSpec` with the converted tensor's shape and dtype, or
    None if the conversion raised `TypeError` or `ValueError`.
  """
  try:
    # Note: try conversion in a FuncGraph to avoid polluting current context.
    with func_graph_lib.FuncGraph(name="guess_conversion").as_default():
      result = ops.convert_to_tensor(arg, dtype_hint=dtype_hint)
      return tensor_spec.TensorSpec(shape=result.shape, dtype=result.dtype)
  except (TypeError, ValueError):
    return None


def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`.

  Args:
    function: ConcreteFunction whose `graph.structured_input_signature` is
      matched against `inputs`.
    inputs: Structured inputs, in the `(args, kwargs)` layout of
      `structured_input_signature`.
    allow_conversion: If True, leaves expected to be `TensorSpec`s may be
      non-tensor values that convert to a tensor of compatible dtype/shape.

  Returns:
    True if every flattened input is compatible with the corresponding
    expected leaf, False otherwise (including when the structures differ).
  """
  expected_structure = function.graph.structured_input_signature
  try:
    flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
  except (TypeError, ValueError):
    return False

  for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
    if isinstance(expected, tensor_spec.TensorSpec):
      if allow_conversion:
        arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
      if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec):
        return False
      if arg.dtype != expected.dtype:
        return False
      if not expected.shape.is_compatible_with(arg.shape):
        return False
    elif isinstance(expected, type_spec.TypeSpec):
      if not expected.is_compatible_with(arg):
        return False
    elif _is_tensor(arg):
      # A tensor in a non-spec position must be the very same object that was
      # recorded in the signature.
      if arg is not expected:
        return False
    elif arg != expected:
      # Plain Python values must compare equal to the recorded value.
      return False
  return True


def _deserialize_function_spec_as_nonmethod(function_spec_proto):
  """Deserializes a FunctionSpec proto into a non-method `FunctionSpec`.

  If the proto describes a method, its first positional argument (assumed to
  be `self`) is dropped so the resulting spec describes a plain function.

  Args:
    function_spec_proto: `saved_object_graph_pb2.FunctionSpec` proto.

  Returns:
    A `function_spec_lib.FunctionSpec` with `is_method=False`.

  Raises:
    NotImplementedError: If the proto is a method without any named argument.
  """
  typeless_fullargspec = nested_structure_coder.decode_proto(
      function_spec_proto.fullargspec)

  # Convert a method function into a non method.
  if function_spec_proto.is_method:
    if not typeless_fullargspec.args:
      raise NotImplementedError(
          "Cannot deserialize a method function without a named "
          "'self' argument.")
    args = typeless_fullargspec.args[1:]
  else:
    args = typeless_fullargspec.args

  fullargspec = tf_inspect.FullArgSpec(
      args=args,
      varargs=typeless_fullargspec.varargs,
      varkw=typeless_fullargspec.varkw,
      defaults=typeless_fullargspec.defaults,
      kwonlyargs=typeless_fullargspec.kwonlyargs,
      kwonlydefaults=typeless_fullargspec.kwonlydefaults,
      annotations=typeless_fullargspec.annotations)
  input_signature = nested_structure_coder.decode_proto(
      function_spec_proto.input_signature)

  # See `tf.function` and the JitCompile proto for details. Unknown enum values
  # map to None, same as DEFAULT.
  jit_compile = {
      saved_object_graph_pb2.FunctionSpec.JitCompile.DEFAULT: None,
      saved_object_graph_pb2.FunctionSpec.JitCompile.ON: True,
      saved_object_graph_pb2.FunctionSpec.JitCompile.OFF: False,
  }.get(function_spec_proto.jit_compile)

  return function_spec_lib.FunctionSpec(
      fullargspec=fullargspec,
      is_method=False,
      input_signature=input_signature,
      jit_compile=jit_compile)


# TODO(b/205016761): The fact that we can't derive ConcreteFunction calling
# conventions from the serialized input spec right now is unfortunate. Merging
# these would be good, maybe by adding TensorSpec names to cache keys so renamed
# keyword arguments would yield different ConcreteFunctions.
def setup_bare_concrete_function(saved_bare_concrete_function,
                                 concrete_functions):
  """Makes a restored bare concrete function callable.

  Args:
    saved_bare_concrete_function: `SavedBareConcreteFunction` proto naming the
      function and describing its calling convention.
    concrete_functions: Map from function name to `ConcreteFunction`.

  Returns:
    The `ConcreteFunction` from `concrete_functions`, configured in place with
    the saved argument keywords, positional-argument count and (if present)
    function spec, and added to the current graph/context.
  """
  concrete_function = concrete_functions[
      saved_bare_concrete_function.concrete_function_name]
  # pylint: disable=protected-access
  concrete_function._arg_keywords = (
      saved_bare_concrete_function.argument_keywords)
  concrete_function._num_positional_args = (
      saved_bare_concrete_function.allowed_positional_arguments)
  if saved_bare_concrete_function.HasField("function_spec"):
    function_spec = _deserialize_function_spec_as_nonmethod(
        saved_bare_concrete_function.function_spec)
    concrete_function._set_function_spec(function_spec)
  # pylint: enable=protected-access
  concrete_function.add_to_graph()
  return concrete_function


class RestoredFunction(def_function.Function):
  """Wrapper class for a function that has been restored from saved state.

  See `def_function.Function`.
  """

  def __init__(self, python_function, name, function_spec, concrete_functions):
    # TODO(b/205016819): We may enable autograph once exceptions are supported.
    super(RestoredFunction, self).__init__(
        python_function,
        name,
        autograph=False,
        jit_compile=function_spec.jit_compile)
    self.concrete_functions = concrete_functions
    self._function_spec = function_spec

    # Prevent RestoredFunction from spamming users with frequent tracing
    # warnings.
    self._omit_frequent_tracing_warning = True

  @property
  def _run_functions_eagerly(self):
    # We do not have access to the original python function, and thus, we
    # cannot meaningfully do anything but call our concrete function graphs
    # under the hood.
    #
    # Attempting to call our bespoke python function (i.e.
    # `restored_function_body`) will work so long as the user passes in all
    # required and optional arguments. If an optional argument is missing,
    # however, the call will break. For this reason, we instead skip the
    # eager call path altogether if a user has enabled eager function execution
    # via `tf.config.run_functions_eagerly`.
    return False

  def _list_all_concrete_functions_for_serialization(self):
    return self.concrete_functions

  def _defun_with_scope(self, scope):
    func = super(RestoredFunction, self)._defun_with_scope(scope)
    func._function_spec = self._function_spec  # pylint: disable=protected-access
    return func


def recreate_function(saved_function, concrete_functions):
  """Creates a `Function` from a `SavedFunction`.

  Args:
    saved_function: `SavedFunction` proto.
    concrete_functions: map from function name to `ConcreteFunction`. As a side
      effect of this function, the `FunctionSpec` from `saved_function` is added
      to each `ConcreteFunction` in this map.

  Returns:
    A `Function`.
  """
  # TODO(b/205017389): Construct a `Function` with the cache populated
  # instead of creating a new `Function` backed by a Python layer to
  # glue things together. Current approach is nesting functions deeper for each
  # serialization cycle.

  # Note: handling method functions is tricky since make_decorator does not
  # allow control over "ismethod". Additionally since restored functions do
  # not behave as methods i.e. they always use the same captured tensors
  # independent of the object they are bound to, there is little value in
  # propagating that correctly.
  #
  # Ideally this conversion should happen at serialization time. But since
  # there are SavedModels which have "ismethod" populated and have an extra
  # argument that they expect to be ignored, we do it at deserialization.
  function_spec = _deserialize_function_spec_as_nonmethod(
      saved_function.function_spec)

  def restored_function_body(*args, **kwargs):
    """Calls a restored function or raises an error if no matching function."""
    if not saved_function.concrete_functions:
      raise ValueError("Found zero restored functions for caller function.")
    # This is the format of function.graph.structured_input_signature. At this
    # point, the args and kwargs have already been canonicalized.
    inputs = (args, kwargs)

    # First try to find a concrete function that can be called without input
    # conversions. This allows one to pick a more specific trace in case there
    # was also a more expensive one that supported tensors.
    for allow_conversion in (False, True):
      for function_name in saved_function.concrete_functions:
        function = concrete_functions[function_name]
        if any(inp is None for inp in function.captured_inputs):
          raise ValueError("Looks like you are trying to run a loaded "
                           "non-Keras model that was trained using "
                           "tf.distribute.experimental.ParameterServerStrategy "
                           "with variable partitioning, which is not currently "
                           "supported. Try using Keras to define your model "
                           "if possible.")
        if _concrete_function_callable_with(function, inputs, allow_conversion):
          return _call_concrete_function(function, inputs)

    signature_descriptions = []

    def _pretty_format_positional(positional):
      return "Positional arguments ({} total):\n * {}".format(
          len(positional),
          "\n * ".join(pprint.pformat(a) for a in positional))

    for index, function_name in enumerate(saved_function.concrete_functions):
      concrete_function = concrete_functions[function_name]
      positional, keyword = concrete_function.structured_input_signature
      signature_descriptions.append(
          "Option {}:\n {}\n Keyword arguments: {}".format(
              index + 1, _pretty_format_positional(positional), keyword))
    raise ValueError(
        "Could not find matching concrete function to call loaded from the "
        f"SavedModel. Got:\n {_pretty_format_positional(args)}\n Keyword "
        f"arguments: {kwargs}\n\n Expected these arguments to match one of the "
        f"following {len(saved_function.concrete_functions)} option(s):\n\n"
        f"{(chr(10)+chr(10)).join(signature_descriptions)}")

  concrete_function_objects = [
      concrete_functions[concrete_function_name]
      for concrete_function_name in saved_function.concrete_functions
  ]

  for cf in concrete_function_objects:
    cf._set_function_spec(function_spec)  # pylint: disable=protected-access

  restored_function = RestoredFunction(restored_function_body,
                                       restored_function_body.__name__,
                                       function_spec, concrete_function_objects)

  return tf_decorator.make_decorator(
      restored_function_body,
      restored_function,
      decorator_argspec=function_spec.fullargspec)


def load_function_def_library(library,
                              saved_object_graph=None,
                              load_shared_name_suffix=None,
                              wrapper_function=None):
  """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Gradients are re-registered under new names. Ops that reference the gradients
  are updated to reflect the new registered names.

  Args:
    library: FunctionDefLibrary proto message.
    saved_object_graph: SavedObjectGraph proto message. If not passed in,
      concrete function structured signatures and outputs will not be set.
    load_shared_name_suffix: If specified, used to uniquify shared names.
      Otherwise, a unique name is generated.
    wrapper_function: An object that will be wrapped on newly created functions.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
  library_function_names = {fdef.signature.name for fdef in library.function}
  functions = {}
  renamed_functions = {}

  # Our graph building code currently requires functions to be registered with
  # some tf.Graph in order to import functions using the
  # op-name-is-function-name calling convention. To avoid leaking memory into
  # the global default graph when executing eagerly, we create a temporary
  # Graph.
  #
  # TODO(b/205023033): Make this Graph creation unnecessary when executing
  # eagerly by fixing function_def_to_graph_def.
  if ops.executing_eagerly_outside_functions():
    graph = ops.Graph()
  else:
    graph = ops.get_default_graph()

  if load_shared_name_suffix is None:
    load_shared_name_suffix = "_load_{}".format(ops.uid())

  # Custom gradient functions must be re-registered under new UIDs.
  library_gradient_names = {}  # Maps old op type to old function name
  new_gradient_op_types = {}  # Maps old gradient op type to new op type.
  gradients_to_register = {}  # Maps old function name to new op type
  for gdef in library.registered_gradients:
    if gdef.registered_op_type:
      new_op_type = custom_gradient.generate_name()
      old_op_type = compat.as_bytes(gdef.registered_op_type)

      library_gradient_names[old_op_type] = gdef.gradient_func
      new_gradient_op_types[old_op_type] = new_op_type
      gradients_to_register[gdef.gradient_func] = new_op_type

  function_deps = {}
  for fdef in library.function:
    function_deps[fdef.signature.name] = _list_function_deps(
        fdef, library_function_names, library_gradient_names)

  loaded_gradients = {}
  for fdef in _sort_function_defs(library, function_deps):
    orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix,
                                   new_gradient_op_types)

    # Setup function signatures and outputs
    #
    # When concrete functions are created normally (i.e. when they're originally
    # created and not loaded via saved model), the inputs and outputs are
    # calculated based on the values passed in by the user and returned from the
    # original function, respectively. We don't have access to those anymore at
    # restore time, so we must instead pass them to the FuncGraph explicitly.
    structured_input_signature = None
    structured_outputs = None
    if (saved_object_graph is not None and
        orig_name in saved_object_graph.concrete_functions):
      # TODO(b/204324043): Offload the deserialization of the protos to the
      # first class objects by passing the actual protos. This is blocked on
      # importing `nested_structure_coder` in function.py causing a circular
      # dependency.
      proto = saved_object_graph.concrete_functions[orig_name]
      structured_input_signature = nested_structure_coder.decode_proto(
          proto.canonicalized_input_signature)
      structured_outputs = nested_structure_coder.decode_proto(
          proto.output_signature)

    # There is no need to copy all functions into the function def graph. It
    # leads to a O(n^2) increase of memory when importing functions and the
    # extra function definitions are a no-op since they were already imported
    # as a function before and passed in explicitly (due to the topological
    # sort import).
    with graph.as_default():
      func_graph = function_def_lib.function_def_to_graph(
          fdef,
          structured_input_signature=structured_input_signature,
          structured_outputs=structured_outputs)
    # Restores gradients for function-call ops (not the same as ops that use
    # custom gradients)
    _restore_gradient_functions(func_graph, renamed_functions, loaded_gradients)

    for dep in function_deps[orig_name]:
      functions[dep].add_to_graph(func_graph)

    # We do not initialize the new ConcreteFunction's function_spec and/or
    # arg_keywords here (which are used to parse the structured and flat
    # signatures, respectively). ConcreteFunction that are part of a saved
    # function is set up later by recreate_function(); and bare ConcreteFunction
    # is set up by setup_bare_concrete_function().
    # However, we copy the FunctionDef attributes to the new ConcreteFunction,
    # excluding the "_input_shapes", which may cause an error during input shape
    # initialization at a later stage.
    if "_input_shapes" in fdef.attr:
      del fdef.attr["_input_shapes"]
    func = function_lib.ConcreteFunction(func_graph, attrs=fdef.attr)
    if wrapper_function:
      func = wrapper_function(func)
    func.add_to_graph(graph)

    functions[orig_name] = func
    renamed_functions[func.name] = func
    if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
      # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
      # is fixed. Currently it's leaking memory to maintain bug compatibility
      # with previous behavior.
      func.add_to_graph(ops.get_default_graph())

    if orig_name in gradients_to_register:
      gradient_op_type = gradients_to_register[orig_name]
      loaded_gradients[compat.as_bytes(gradient_op_type)] = func
      ops.RegisterGradient(gradient_op_type)(_gen_gradient_func(func))

  return functions


def _gen_gradient_func(func):
  """Wraps a deserialized function as a gradient function.

  Args:
    func: Loaded `ConcreteFunction` implementing a custom gradient.

  Returns:
    A callable `(op, *result_grads)` suitable for `ops.RegisterGradient`, which
    ignores `op` and calls `func` on the (None-filled) output gradients.
  """

  def gradient_func(unused_op, *result_grads):
    # Replace all `None` arguments, because the traced custom gradient function
    # expects tensors. Replacing with zeros is correct since the `None` values
    # occur when the gradient is unconnected, and thus the gradient is
    # "statically proven to be zero." See `tf.UnconnectedGradients` for details.
    result_grads = [
        x if x is not None else default_gradient.zeros_like(t)
        for (x, t) in zip(result_grads, func.graph.inputs)
    ]

    return func(*result_grads)

  return gradient_func


def _restore_gradient_functions(func_graph, renamed_functions,
                                loaded_gradients):
  """Populates gradient information on the ops of a freshly imported graph.

  For `StatefulPartitionedCall`/`PartitionedCall` ops, sets
  `op._gradient_function` from the called function. For ops whose
  `_gradient_op_type` attr names a loaded custom gradient, updates that
  gradient function's positional-argument count and argument keywords to
  match the op's inputs.

  Args:
    func_graph: FuncGraph whose operations are updated in place.
    renamed_functions: Map from new (renamed) function name to loaded
      `ConcreteFunction`.
    loaded_gradients: Map from new gradient op type (bytes) to the loaded
      gradient `ConcreteFunction`.
  """
  for op in func_graph.get_operations():
    # TODO(b/205024208): This code assumes that the gradient registered for this
    # function call is the default gradient for the function and not a custom
    # one.
    if op.type in ["StatefulPartitionedCall", "PartitionedCall"]:
      function = renamed_functions[compat.as_bytes(
          op.node_def.attr["f"].func.name)]
      op._gradient_function = function._get_gradient_function()  # pylint: disable=protected-access
    try:
      gradient_op_type = op.get_attr("_gradient_op_type")
    except ValueError:
      # The op has no `_gradient_op_type` attr; nothing more to restore.
      pass
    else:
      if gradient_op_type in loaded_gradients:
        grad_fn = loaded_gradients[gradient_op_type]
        grad_fn._num_positional_args = len(op.inputs)  # pylint: disable=protected-access
        grad_fn._arg_keywords = [inp.name for inp in op.inputs]  # pylint: disable=protected-access


def _sort_function_defs(library, function_deps):
  """Returns a topological sort of the FunctionDefs in a library.

  Args:
    library: FunctionDefLibrary proto message.
    function_deps: Map from function name to the set of function names it
      depends on.

  Returns:
    List of `FunctionDef`s such that every function appears after all of its
    dependencies.

  Raises:
    ValueError: If the dependencies contain a cycle.
  """
  edges = collections.defaultdict(list)
  in_count = collections.defaultdict(lambda: 0)

  for fname, deps in function_deps.items():
    for dep in deps:
      edges[dep].append(fname)
      in_count[fname] += 1
  ready = [
      fdef.signature.name
      for fdef in library.function
      if in_count[fdef.signature.name] == 0
  ]
  output = []
  while ready:
    node = ready.pop()
    output.append(node)
    for dest in edges[node]:
      in_count[dest] -= 1
      if not in_count[dest]:
        ready.append(dest)

  if len(output) != len(library.function):
    failed_to_resolve = sorted(set(in_count.keys()) - set(output))
    # A single message string; passing two positional arguments would make the
    # exception render as a tuple.
    raise ValueError("There is a cyclic dependency between functions. "
                     f"Could not resolve {failed_to_resolve}.")

  reverse = {fdef.signature.name: fdef for fdef in library.function}
  return [reverse[x] for x in output]


def _get_gradient_op_type(node_def):
  """Returns the custom gradient op type of `node_def`, or None.

  Function-call ops (`StatefulPartitionedCall`, `PartitionedCall`) are excluded
  and always yield None.
  """
  if ("_gradient_op_type" in node_def.attr and
      node_def.op not in ["StatefulPartitionedCall", "PartitionedCall"]):
    return node_def.attr["_gradient_op_type"].s
  return None


def fix_node_def(node_def, functions, shared_name_suffix):
  """Replaces function calls and shared names in `node_def` in place.

  Args:
    node_def: NodeDef proto to mutate.
    functions: Map from original function name to loaded `ConcreteFunction`;
      referenced functions are renamed to the loaded function's name.
    shared_name_suffix: Suffix appended to the node's `shared_name` (if its op
      has a `shared_name` attr) to avoid sharing resources across loads.
  """
  if node_def.op in functions:
    node_def.op = functions[node_def.op].name
  for attr_value in node_def.attr.values():
    if attr_value.WhichOneof("value") == "func":
      attr_value.func.name = functions[attr_value.func.name].name
    elif attr_value.WhichOneof("value") == "list":
      for fn in attr_value.list.func:
        fn.name = functions[fn.name].name

  # Fix old table creation bug.
  if node_def.op == "HashTableV2":
    if ("use_node_name_sharing" not in node_def.attr or
        not node_def.attr["use_node_name_sharing"].b):
      node_def.attr["use_node_name_sharing"].b = True
      # We are turning on node name sharing, so have to make sure we don't
      # accidentally share a table resource.
      shared_name_suffix += "_{}".format(ops.uid())

  # TODO(b/124205571): Avoid accidental sharing and destruction of restored
  # resources. For now uniquify "shared_name" when loading functions to avoid
  # sharing.
  # TODO: Add regression test for b/150826922.
  op_def = op_def_registry.get(node_def.op)
  if op_def:
    attr = next((a for a in op_def.attr if a.name == "shared_name"), None)
    if attr:
      # Precedence: explicit non-empty shared_name, then the op's default, then
      # the node's own name.
      shared_name = None
      if "shared_name" in node_def.attr and node_def.attr["shared_name"].s:
        shared_name = node_def.attr["shared_name"].s
      elif attr.default_value.s:
        shared_name = compat.as_bytes(attr.default_value.s)
      if not shared_name:
        shared_name = compat.as_bytes(node_def.name)

      node_def.attr["shared_name"].s = (
          shared_name + compat.as_bytes(shared_name_suffix))


def _fix_fdef_in_place(fdef, functions, shared_name_suffix,
                       new_gradient_op_types):
  """Fixes a FunctionDef proto to be loaded in current context.

  In particular, when loading a function library into an eager context, one
  must rename the functions to avoid conflicts with existent functions.

  Args:
    fdef: FunctionDef proto to fix. It is mutated in-place.
    functions: map from function name to a ConcreteFunction instance.
    shared_name_suffix: A unique string for this load which helps to avoid
      `shared_name` collisions across loads. Two functions from the same load
      using the same `shared_name` still need to share, but functions from
      different loads with the same `shared_name` should not.
    new_gradient_op_types: map from old gradient op type to newly generated op
      type.

  Returns:
    orig_name: original value of fdef.signature.name
  """
  orig_name = fdef.signature.name
  contains_unsaved_custom_gradients = False

  for node_def in fdef.node_def:
    fix_node_def(node_def, functions, shared_name_suffix)
    op_type = _get_gradient_op_type(node_def)
    if op_type is not None:
      if op_type in new_gradient_op_types:
        node_def.attr["_gradient_op_type"].s = compat.as_bytes(
            new_gradient_op_types[op_type])
      else:
        contains_unsaved_custom_gradients = True
  if contains_unsaved_custom_gradients:
    logging.warning(
        "Importing a function (%s) with ops with unsaved custom gradients. Will"
        " likely fail if a gradient is requested.", fdef.signature.name)

  fdef.signature.name = _clean_function_name(fdef.signature.name)
  return orig_name


def _list_function_deps(fdef, library_function_names, library_gradient_names):
  """Finds the functions referenced in `fdef`.

  Args:
    fdef: FunctionDef proto to scan.
    library_function_names: Set of function names defined in the library.
    library_gradient_names: Map from gradient op type to gradient function name.

  Returns:
    Set of referenced function names: direct call ops, registered custom
    gradient functions, and otherwise any `func`/`list.func` attr values.
  """
  # TODO(b/205023953): Recurse into list attributes and into NameAttrList attrs
  # both when listing deps and when fixing them. `function_def_to_graph` also
  # requires fixes.
  deps = set()
  for node_def in fdef.node_def:
    grad_op_type = _get_gradient_op_type(node_def)
    if node_def.op in library_function_names:
      deps.add(node_def.op)
    elif grad_op_type and grad_op_type in library_gradient_names:
      deps.add(library_gradient_names[grad_op_type])
    else:
      for attr_value in node_def.attr.values():
        if attr_value.WhichOneof("value") == "func":
          deps.add(attr_value.func.name)
        elif attr_value.WhichOneof("value") == "list":
          for fn in attr_value.list.func:
            deps.add(fn.name)

  return deps


# Matches "<inference prefix><orig>_<digits>" and captures "<orig>". The prefix
# is escaped so it is matched literally.
_FUNCTION_WRAPPER_NAME_REGEX = r"^%s(.*)_\d+$" % re.escape(
    function_lib._INFERENCE_PREFIX)  # pylint:disable=protected-access


def _clean_function_name(name):
  """Vanity function to keep the function names comprehensible.

  Returns the inner name if `name` matches `_FUNCTION_WRAPPER_NAME_REGEX`,
  otherwise `name` unchanged.
  """
  # Note: each time a function is wrapped into `function_lib.ConcreteFunction`
  # its name becomes "__inference_<orig>_xyz".
  match = re.search(_FUNCTION_WRAPPER_NAME_REGEX, name)
  if match:
    return match.group(1)
  else:
    return name