1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=unidiomatic-typecheck 16"""Defun decorator for defining graph-mode functions.""" 17 18import collections 19import pprint 20import threading 21import types as types_lib 22from typing import List 23import weakref 24 25from tensorflow.core.framework import attr_value_pb2 26from tensorflow.core.framework import function_pb2 27from tensorflow.core.function.polymorphism import function_cache 28from tensorflow.python import pywrap_tfe 29from tensorflow.python.client import pywrap_tf_session 30from tensorflow.python.eager import backprop 31from tensorflow.python.eager import backprop_util 32from tensorflow.python.eager import context 33from tensorflow.python.eager import execute 34from tensorflow.python.eager import forwardprop_util 35from tensorflow.python.eager import function_context 36from tensorflow.python.eager import function_saved_model_utils 37from tensorflow.python.eager import function_spec 38from tensorflow.python.eager import monitoring 39from tensorflow.python.eager import tape 40from tensorflow.python.eager.graph_only_ops import graph_placeholder 41from tensorflow.python.framework import c_api_util 42from tensorflow.python.framework import composite_tensor 43from tensorflow.python.framework import dtypes 44from tensorflow.python.framework import error_interpolation 45from tensorflow.python.framework import errors 46from tensorflow.python.framework import func_graph as func_graph_module 47from tensorflow.python.framework import indexed_slices 48from tensorflow.python.framework import ops 49from tensorflow.python.framework import tensor_shape 50from tensorflow.python.framework import tensor_spec 51from tensorflow.python.framework import type_spec 52from tensorflow.python.ops import array_ops 53from tensorflow.python.ops import default_gradient 54from tensorflow.python.ops import functional_ops 55from tensorflow.python.ops import gradients_util 56from tensorflow.python.ops import handle_data_util 57from tensorflow.python.ops import resource_variable_ops 58from tensorflow.python.platform import tf_logging as logging 59from tensorflow.python.profiler import trace 60from tensorflow.python.trackable import base as trackable 61from tensorflow.python.types import core 62from tensorflow.python.util import _pywrap_utils 63from tensorflow.python.util import compat 64from tensorflow.python.util import function_utils 65from tensorflow.python.util import lazy_loader 66from tensorflow.python.util import memory 67from tensorflow.python.util import nest 68from tensorflow.python.util import object_identity 69from tensorflow.python.util import tf_decorator 70from tensorflow.python.util import tf_inspect 71from tensorflow.python.util.tf_export import tf_export 72 73# Loaded lazily due to a circular dependency (roughly 74# tf.function->autograph->->dataset->tf.function). 75# TODO(b/133251390): Use a regular import. 76ag_ctx = lazy_loader.LazyLoader( 77 "ag_ctx", globals(), 78 "tensorflow.python.autograph.core.ag_ctx") 79np_arrays = lazy_loader.LazyLoader( 80 "np_arrays", globals(), 81 "tensorflow.python.ops.numpy_ops.np_arrays") 82 83 84FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name" 85BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name" 86IMPLEMENTS_ATTRIBUTE_NAME = "_implements" 87SHARED_RENDEZVOUS_ATTRIBUTE_NAME = "shared_rendezvous" 88 89_graph_building_time_counter = monitoring.Counter( 90 "/tensorflow/core/tf_function/graph_building_time_usecs", 91 "Time for tf.function to build a graph (us).") 92 93 94def _type_spec_for(x): 95 """Returns a TypeSpec for `x`, or `x` if `x` doesn't have a TensorSpec.""" 96 if isinstance(x, ops.Tensor): 97 # We intentionally leave out the name of x from the TensorSpec here, 98 # because the name of a TensorSpec will override arg_name 99 # in the '_get_defun_inputs' method in func_graph.py. 100 return tensor_spec.TensorSpec(x.shape, x.dtype) 101 elif isinstance(x, type_spec.TypeSpec): 102 return x 103 elif isinstance(x, composite_tensor.CompositeTensor): 104 return x._type_spec # pylint: disable=protected-access 105 else: 106 return x 107 108 109def _is_type_subset(a, b): 110 """Returns true if `b` is a subset of type `a` (or if a is not a TypeSpec.)""" 111 if isinstance(a, type_spec.TypeSpec): 112 return a.most_specific_compatible_type(b) == a 113 return True 114 115 116def common_shape(x, y): 117 """Find a `TensorShape` that is compatible with both `x` and `y`.""" 118 if x is None != y is None: 119 raise RuntimeError( 120 "Cannot find a common shape when LHS shape is None but RHS shape " 121 f"is not (or vice versa): {x} vs. {y}.") 122 if x is None: 123 return None # The associated input was not a Tensor, no shape generated. 124 if not isinstance(x, tensor_shape.TensorShape): 125 raise TypeError(f"`x` must be a TensorShape, got type {type(x)}.") 126 if not isinstance(y, tensor_shape.TensorShape): 127 raise TypeError(f"`y` must be a TensorShape, got type {type(y)}.") 128 if x.rank != y.rank or x.rank is None: 129 return tensor_shape.TensorShape(None) 130 dims = [] 131 for dim_x, dim_y in zip(x.dims, y.dims): 132 if (dim_x != dim_y 133 or tensor_shape.dimension_value(dim_x) is None 134 or tensor_shape.dimension_value(dim_y) is None): 135 dims.append(None) 136 else: 137 dims.append(tensor_shape.dimension_value(dim_x)) 138 return tensor_shape.TensorShape(dims) 139 140 141def _parse_func_attrs(attributes): 142 """Convert the keyword arguments into function_def attributes. 143 144 Currently only support primitive types: bool, int, float and string. 145 146 Args: 147 attributes: the dictionary of attributes. 148 Returns: 149 A dict of attributes where the key is the name of attribute and the value 150 is the AttrValue proto. 151 Raises: 152 ValueError: If the kwargs contains unallowlisted name or unsupported value 153 types. 154 """ 155 attrs = {} 156 for key, value in attributes.items(): 157 if isinstance(value, attr_value_pb2.AttrValue): 158 attrs[key] = value 159 # bool type check has to happen before int since bool is a subclass of int. 160 elif isinstance(value, bool): 161 attrs[key] = attr_value_pb2.AttrValue(b=value) 162 elif isinstance(value, int): 163 attrs[key] = attr_value_pb2.AttrValue(i=value) 164 elif isinstance(value, float): 165 attrs[key] = attr_value_pb2.AttrValue(f=value) 166 elif isinstance(value, (str, bytes)): 167 attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value)) 168 else: 169 raise ValueError(f"Attribute {key} must be bool, int, float, string, or " 170 f"AttrValue. Got {type(value)}.") 171 return attrs 172 173 174class _InterpolateFunctionError(object): 175 """Context Manager that interpolates the exception from 'top_level_func'.""" 176 177 __slots__ = ["_func"] 178 179 def __init__(self, top_level_func): 180 self._func = top_level_func 181 182 def __enter__(self): 183 pass 184 185 def __exit__(self, typ, exc, tb): 186 if not exc or not isinstance(exc, errors.OpError): 187 return False 188 message = compat.as_text(exc.message) 189 _, func_tags, _ = error_interpolation.parse_message(message) 190 g = None 191 for func_tag in func_tags: 192 # TODO(mdan): Tests should cover this. 193 if func_tag.name == compat.as_str(self._func.name): 194 g = self._func.graph 195 elif g: 196 next_func = g._get_function(func_tag.name) # pylint: disable=protected-access 197 if next_func is not None and isinstance(next_func, 198 _EagerDefinedFunction): 199 g = next_func.graph 200 if g: 201 exc._message = error_interpolation.interpolate(message, g) # pylint: disable=protected-access 202 return False 203 204 205_function_callbacks = set() 206 207 208def add_function_callback(function_callback): 209 """Add a callback function for Function creation. 210 211 The callback function has the signature: 212 213 `def function_callback(function, name, graph, inputs, outputs):` 214 215 where: 216 - `function`: _EagerDefinedFunction being created before finalizing the graph. 217 Do not modify the function directly but instead modify the graph. 218 - `name`: name of the function. 219 - `graph`: Graph of the function. 220 - `inputs`: `tuple` of tensors used as inputs to the function. 221 - `outputs`: `tuple` of tensors used as outputs from the function. 222 223 The callback is at the top of the `_EagerDefinedFunction` construction, giving 224 callback an opportunity to make the last edits to the graph. Do not make 225 changes to `graph, inputs`, and `outputs` manually, but, instead, set the 226 `graph` as the default then define ops. 227 228 Repeated registration of the same callback function is idempotent. 229 After a callback is added, it can be removed with the 230 `remove_function_callback()` method. 231 232 Args: 233 function_callback: The callback to add. 234 """ 235 _function_callbacks.add(function_callback) 236 237 238def remove_function_callback(function_callback): 239 """Remove an already-added function callback. 240 241 See the doc string of `add_function_callback()` for more information. 242 243 Args: 244 function_callback: The callback to remove. 245 """ 246 _function_callbacks.remove(function_callback) 247 248 249def clear_function_callbacks(): 250 """Clear all function callbacks, if any have been regisered.""" 251 _function_callbacks.clear() 252 253 254_FORWARD_PREFIX = "__forward_" 255_BACKWARD_PREFIX = "__backward_" 256_INFERENCE_PREFIX = "__inference_" 257 258 259def _forward_name(n): 260 """The name of a generated forward defun named n.""" 261 return "%s%s_%s" % (_FORWARD_PREFIX, n, ops.uid()) 262 263 264def _backward_name(n): 265 """The name of a generated backward defun named n.""" 266 return "%s%s_%s" % (_BACKWARD_PREFIX, n, ops.uid()) 267 268 269def _inference_name(n): 270 """The name of a forward-but-no-gradient defun named n.""" 271 return "%s%s_%s" % (_INFERENCE_PREFIX, n, ops.uid()) 272 273 274class _EagerDefinedFunctionDeleter(object): 275 """Unregister function from eager context.""" 276 277 __slots__ = ["name"] 278 279 def __init__(self, name): 280 self.name = name 281 282 def __del__(self): 283 try: 284 context.remove_function(self.name) 285 except TypeError: 286 # Suppress some exceptions, mainly for the case when we're running on 287 # module deletion. Things that can go wrong include the context module 288 # already being unloaded, self._handle._handle_data no longer being 289 # valid, and so on. Printing warnings in these cases is silly 290 # (exceptions raised from __del__ are printed as warnings to stderr). 291 pass # 'NoneType' object is not callable when the handle has been 292 # partially unloaded. 293 except AttributeError: 294 pass # 'NoneType' object has no attribute 'eager_mode' when context has 295 # been unloaded. Will catch other module unloads as well. 296 297 298# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction 299# so it doesn't have the definition-generating logic and is just a container for 300# an already-defined function. 301class _EagerDefinedFunction(object): 302 """Callable with the interface of `framework.function._DefinedFunction`. 303 304 `_EagerDefinedFunction` encapsulates a function definition and its properties, 305 and it provides a method for calling the encapsulated function. Some Ops 306 take functions as attributes, which have type `func`; an instance of this 307 class may be provided as the value of these `func` attributes. 308 """ 309 310 def __init__(self, name, graph, inputs, outputs, attrs): 311 """Initializes an eager defined function. 312 313 Args: 314 name: str, the name for the created function. 315 graph: Graph, the graph containing the operations in the function 316 inputs: the tensors in the graph to be used as inputs to the function 317 outputs: the tensors in the graph which will be outputs from the function 318 attrs: dict mapping names of attributes to their AttrValue values 319 """ 320 for function_callback in _function_callbacks: 321 function_callback(self, name, graph, tuple(inputs), tuple(outputs)) 322 323 input_ops = set(arg.op for arg in inputs) 324 operations = [op for op in graph.get_operations() if op not in input_ops] 325 326 graph_output_names = graph._output_names # pylint: disable=protected-access 327 if (graph_output_names is not None and 328 all(ops.tensor_id(t) in graph_output_names for t in outputs)): 329 output_names = [ 330 compat.as_bytes(graph_output_names[ops.tensor_id(t)]) for t in outputs 331 ] 332 if len(set(output_names)) != len(output_names): 333 # There are duplicate names for some reason, probably an invalid 334 # signature. Revert to auto-naming. 335 output_names = [] 336 else: 337 output_names = [] 338 with graph._c_graph.get() as c_graph: # pylint: disable=protected-access 339 fn = pywrap_tf_session.TF_GraphToFunction_wrapper( 340 c_graph, 341 compat.as_str(name), 342 False, 343 [o._c_op for o in operations], # pylint: disable=protected-access 344 [t._as_tf_output() for t in inputs], # pylint: disable=protected-access 345 [t._as_tf_output() for t in outputs], # pylint: disable=protected-access 346 output_names, 347 [o._c_op for o in graph.control_outputs], # pylint: disable=protected-access 348 [], # control_output_names 349 None, 350 compat.as_str("")) 351 352 self._c_func = c_api_util.ScopedTFFunction(fn, name) 353 354 for name, attr_value in attrs.items(): 355 serialized = attr_value.SerializeToString() 356 # TODO(iga): this creates and deletes a new TF_Status for every attr. 357 # It might be worth creating a convenient way to re-use status. 358 pywrap_tf_session.TF_FunctionSetAttrValueProto(fn, compat.as_str(name), 359 serialized) 360 361 # NOTE(feyu): Do not cache signature and definition at initialization to 362 # save memory usage of concrete functions never called through Python. We 363 # cache them on the first call of .definition and .signature. 364 signature = self._get_definition().signature 365 366 self._name = compat.as_bytes(signature.name) 367 with ops.init_scope(): 368 if context.executing_eagerly(): 369 context.ensure_initialized() 370 context.add_function(fn) 371 self._function_deleter = _EagerDefinedFunctionDeleter(self.name) 372 self._registered_on_context = True 373 374 self._num_outputs = len(signature.output_arg) 375 self._output_types = [o.type for o in signature.output_arg] 376 self._output_shapes = [o.shape for o in outputs] 377 self._control_captures = graph.control_captures 378 # Shallow copy outputs since ConcreteFunction may mutate it. 379 self._func_graph_outputs = list(outputs) 380 self.grad_func_name = None 381 self.python_grad_func = None 382 self._grad_func = None 383 self.graph = graph 384 self._stateful_ops = tuple(op for op in operations if op._is_stateful) # pylint: disable=protected-access 385 386 @property 387 def signature(self): 388 try: 389 return self._signature 390 except AttributeError: 391 self._signature = self.definition.signature 392 return self._signature 393 394 @property 395 def definition(self): 396 try: 397 return self._definition 398 except AttributeError: 399 self._definition = self._get_definition() 400 return self._definition 401 402 def _get_definition(self): 403 # TODO(apassos) avoid creating a FunctionDef (specially to grab the 404 # signature, but also in general it's nice not to depend on it. 405 with c_api_util.tf_buffer() as buffer_: 406 with self._c_func.get() as func: 407 pywrap_tf_session.TF_FunctionToFunctionDef(func, buffer_) 408 proto_data = pywrap_tf_session.TF_GetBuffer(buffer_) 409 function_def = function_pb2.FunctionDef() 410 function_def.ParseFromString(compat.as_bytes(proto_data)) 411 return function_def 412 413 def add_to_graph(self, g=None): 414 """Add the function to the current context or a graph, if supplied. 415 416 Args: 417 g: the graph to add the function to. If not supplied, the function will 418 be added to the current context. 419 """ 420 # pylint: disable=protected-access 421 if not g and context.executing_eagerly(): 422 ctx = context.context() 423 if not ctx.has_function(self.name): 424 ctx.add_function_def(self.definition) 425 else: 426 if not g._is_function(self.name): 427 g._add_function(self) 428 for f in self.graph._functions.values(): 429 if not g._is_function(f.name): 430 g._add_function(f) 431 # pylint: enable=protected-access 432 433 @property 434 def name(self): 435 return self._name 436 437 @property 438 def stateful_ops(self): 439 return self._stateful_ops 440 441 def call(self, ctx, args, cancellation_manager=None): 442 """Calls this function with `args` as inputs. 443 444 `ConcreteFunction` execution respects device annotations only if the 445 function won't be compiled with xla. 446 447 Args: 448 ctx: a Context object 449 args: a list of arguments to supply this function with. 450 cancellation_manager: a `CancellationManager` object that can be used to 451 cancel function execution. 452 453 Returns: 454 The outputs of the function call. 455 456 Raises: 457 ValueError: if the number of arguments is incorrect. 458 FunctionAlreadyGarbageCollectedError: if the function is no longer 459 available to be called because it has been garbage collected. 460 """ 461 if len(args) != len(self.signature.input_arg): 462 raise ValueError( 463 f"Signature specifies {len(list(self.signature.input_arg))} " 464 f"arguments, got: {len(args)}.") 465 466 function_call_options = ctx.function_call_options 467 if function_call_options.config_proto_serialized is None: 468 config = function_utils.get_disabled_rewriter_config() 469 else: 470 config = function_call_options.config_proto_serialized 471 executor_type = function_call_options.executor_type or "" 472 473 executing_eagerly = ctx.executing_eagerly() 474 attrs = ("executor_type", executor_type, "config_proto", config) 475 if executing_eagerly: 476 with _InterpolateFunctionError(self): 477 if cancellation_manager is None: 478 outputs = execute.execute( 479 str(self.signature.name), 480 num_outputs=self._num_outputs, 481 inputs=args, 482 attrs=attrs, 483 ctx=ctx) 484 else: 485 outputs = execute.execute_with_cancellation( 486 str(self.signature.name), 487 num_outputs=self._num_outputs, 488 inputs=args, 489 attrs=attrs, 490 ctx=ctx, 491 cancellation_manager=cancellation_manager) 492 # Replace empty list with None 493 outputs = outputs or None 494 else: 495 # TODO(akshayka): Either remove this if the FunctionLibraryRuntime 496 # creates `PartitionedCallOp` kernels by default, or remove the previous 497 # branch if a TPU kernel is registered for `PartitionedCall`. 498 with _InterpolateFunctionError(self): 499 with ops.control_dependencies(self._control_captures): 500 # The caller must use record_operation to record this operation in the 501 # eager case, so we enforce the same requirement for the non-eager 502 # case by explicitly pausing recording. We don't have a gradient 503 # registered for PartitionedCall, so recording this operation confuses 504 # forwardprop code (GradientTape manages to ignore it). 505 with tape.stop_recording(): 506 outputs = functional_ops.partitioned_call( 507 args=args, 508 f=self, 509 tout=self._output_types, 510 executing_eagerly=executing_eagerly, 511 config=config, 512 executor_type=executor_type) 513 514 for i, func_graph_output in enumerate(self._func_graph_outputs): 515 handle_data_util.copy_handle_data(func_graph_output, outputs[i]) 516 if executing_eagerly: 517 return outputs 518 else: 519 # TODO(b/128924522): This additional set_shape should not be 520 # necessary. ShapeRefiner likely needs to inspect handle_data. Remove this 521 # once that's done. 522 for i, shape in enumerate(self._output_shapes): 523 outputs[i].set_shape(shape) 524 return outputs 525 526 527def _create_forward_backward_with_graph(attrs, forward_graph, backwards_graph): 528 """Creates forward and backward functions from the function graphs.""" 529 forward_function_name = _forward_name(forward_graph.name) 530 common_attributes = dict(attrs) 531 # NB: forward and backward function need to drop "_implements". 532 # attribute, because their signature contains all the intermediate tensors 533 # that they compute. Thus they don't have a stable signature which can 534 # be directly optimized downstream. 535 # See for more details: 536 # https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md#appendix-future-support-for-optimizing-gradient-functions 537 common_attributes.pop(IMPLEMENTS_ATTRIBUTE_NAME, None) 538 backward_function_attr = _parse_func_attrs( 539 {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name}) 540 backward_function_attr.update(common_attributes) 541 backward_function = ConcreteFunction( 542 backwards_graph, attrs=backward_function_attr) 543 forward_function_attr = _parse_func_attrs({ 544 BACKWARD_FUNCTION_ATTRIBUTE_NAME: 545 backward_function.name}) 546 forward_function_attr.update(common_attributes) 547 forward_function = _EagerDefinedFunction( 548 forward_function_name, forward_graph, forward_graph.inputs, 549 forward_graph.outputs, forward_function_attr) 550 return forward_function, backward_function 551 552 553class _DelayedRewriteGradientFunctions(object): 554 """Caches forward/backward functions with a delayed forward rewrite.""" 555 556 def __init__(self, func_graph, attrs, func_graph_deleter): 557 """Construct an inference function and initialize caches.""" 558 # A map from the number of forward function outputs with accepted gradients 559 # to forward and backward functions, used to cache non-tape backward 560 # function generation. 561 self._cached_function_pairs = {} 562 self._func_graph = func_graph 563 self._inference_function = _EagerDefinedFunction( 564 _inference_name(self._func_graph.name), self._func_graph, 565 self._func_graph.inputs, self._func_graph.outputs, attrs) 566 self._attrs = attrs 567 self._gradient_name = None 568 # Note that the FuncGraph is mutated later, so we need to inspect it now to 569 # figure out the user-specified outputs of the inference function. 570 self._num_inference_outputs = len(self._func_graph.outputs) 571 self._func_graph_deleter = func_graph_deleter 572 573 def forward_backward(self, num_doutputs=None): 574 """A possibly-cached pair of forward and backward functions.""" 575 if num_doutputs is None: 576 num_doutputs = self._num_inference_outputs 577 forward_backward = self._cached_function_pairs.get(num_doutputs) 578 if forward_backward is not None: 579 return forward_backward 580 forward, backward = self._construct_forward_backward(num_doutputs) 581 self._cached_function_pairs[num_doutputs] = (forward, backward) 582 return forward, backward 583 584 def _construct_forward_backward(self, num_doutputs): 585 """Constructs a pair of forward and backward functions. 586 587 Args: 588 num_doutputs: The constructed backprop function will take output gradients 589 for the first `num_doutputs` outputs of the forward function. Defaults 590 to the number of outputs for the inference function, but when 591 higher-order gradients are computed this will increase to include side 592 outputs. 593 594 Returns: 595 A pair of (forward_function, backward_function): 596 forward_function: A re-generated inference function (an 597 _EagerDefinedFunction) to account for new side outputs, if any extra 598 were required when building the backward pass. 599 backward_function: A ConcreteFunction that Takes `num_doutputs` 600 arguments and returns gradients with respect to inputs of the forward 601 function. 602 """ 603 trainable_outputs = [ 604 output for output in self._func_graph.outputs[:num_doutputs] 605 if backprop_util.IsTrainable(output)] 606 607 signature = [] 608 for t in trainable_outputs: 609 signature.append( 610 tensor_spec.TensorSpec(*default_gradient.shape_and_dtype(t))) 611 612 def _backprop_function(*grad_ys): 613 with ops.device(None): 614 return gradients_util._GradientsHelper( # pylint: disable=protected-access 615 trainable_outputs, 616 self._func_graph.inputs, 617 grad_ys=grad_ys, 618 src_graph=self._func_graph) 619 620 with self._func_graph.as_default(): 621 backwards_graph = func_graph_module.FuncGraph( 622 _backward_name(self._func_graph.name)) 623 func_graph_module.func_graph_from_py_func( 624 name=backwards_graph.name, 625 python_func=_backprop_function, 626 args=[], kwargs={}, 627 signature=signature, 628 func_graph=backwards_graph) 629 backwards_graph_captures = backwards_graph.external_captures 630 captures_from_forward = [ 631 c for c in backwards_graph_captures if 632 not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph] 633 634 existing_outputs = object_identity.ObjectIdentitySet( 635 self._func_graph.outputs) 636 for capture in captures_from_forward: 637 if capture not in existing_outputs: 638 existing_outputs.add(capture) 639 self._func_graph.outputs.append(capture) 640 641 forward_function, backward_function = _create_forward_backward_with_graph( 642 self._attrs, self._func_graph, backwards_graph) 643 return forward_function, backward_function 644 645 def _rewrite_forward_and_call_backward(self, op, *doutputs): 646 """Add outputs to the forward call and feed them to the grad function.""" 647 forward_function, backwards_function = self.forward_backward(len(doutputs)) 648 if not backwards_function.outputs: 649 return backwards_function.structured_outputs 650 forward_function.add_to_graph(op.graph) 651 652 # pylint: disable=protected-access 653 # Rewrite an inference call op to be a forward call op 654 op._set_func_attr("f", forward_function.name) 655 op._set_type_list_attr("Tout", forward_function._output_types) 656 op._add_outputs( 657 forward_function._output_types[len(op.outputs):], 658 forward_function._output_shapes[len(op.outputs):]) 659 for i in range(len(op.outputs)): 660 func_graph_output = forward_function._func_graph_outputs[i] 661 handle_data_util.copy_handle_data(func_graph_output, op.outputs[i]) 662 # pylint: enable=protected-access 663 664 capture_mapping = dict( 665 zip((ops.tensor_id(t) for t in self._func_graph.outputs), op.outputs)) 666 remapped_captures = [ 667 capture_mapping.get(ops.tensor_id(capture), capture) 668 for capture in backwards_function.captured_inputs 669 ] 670 671 # Replace Nones with zeros since we're calling a graph function which 672 # expects numeric inputs. 673 cleaned_doutputs = [] 674 for doutput, placeholder in zip(doutputs, self._func_graph.outputs): 675 if backprop_util.IsTrainable(placeholder): 676 if isinstance(doutput, indexed_slices.IndexedSlices): 677 # Gradient passed to a backward ConcreteFunction must be tf.Tensor, 678 # so we convert tf.IndexedSlices to tf.Tensor. 679 cleaned_doutputs.append(ops.convert_to_tensor(doutput)) 680 elif doutput is not None: 681 cleaned_doutputs.append(doutput) 682 else: 683 cleaned_doutputs.append(default_gradient.zeros_like(placeholder)) 684 685 # Compute the gradients using the side outputs 686 return backwards_function._call_flat( # pylint: disable=protected-access 687 cleaned_doutputs, remapped_captures) 688 689 def get_gradient_function(self): 690 """Returns gradient function. 691 692 The gradient rewrites an inference call op to a forward call op, but does 693 not modify a pre-existing forward call op. It then computes the gradient 694 from the output's gradients and the side outputs of the forward op. 695 """ 696 return self._rewrite_forward_and_call_backward 697 698 def forward(self, inference_args=None, input_tangents=None): 699 """A forward function with only user-specified outputs. 700 701 The call operation for the returned inference function can be rewritten into 702 a forward function. This only happens if the backward function (from the 703 `backward` method) ends up being used to compute gradients. 704 705 This approach avoids constructing unnecessary graphs, but it only works if 706 we are calling this function when not executing eagerly. 707 708 Args: 709 inference_args: A flat list of Tensors, arguments to the inference 710 function. Unused, but taken for compatibility with 711 _TapeGradientFunctions. 712 input_tangents: A flat list of Tensors, jvps associated with 713 `inference_args`. Unused; if required, tape functions must be used 714 instead. 715 716 Returns: 717 An _EagerDefinedFunction. 718 """ 719 del inference_args # unused 720 if input_tangents: 721 # This class does not support special-cased forwardprop. The arguments are 722 # here for compatibility with _TapeGradientFunctions. 723 raise errors.InternalError("unexpectedly got forwardprop information in " 724 "a class that does not support forwardprop.") 725 return self._inference_function 726 727 def _backward(self, outputs): 728 """Fetch a backward function for `outputs` from the forward function.""" 729 def _backward_function(*args): 730 call_op = outputs[0].op 731 return self._rewrite_forward_and_call_backward(call_op, *args) 732 return _backward_function, outputs 733 734 def record(self, flat_outputs, inference_args, input_tangents): 735 """Record the function call operation. 736 737 _DelayedRewriteGradientFunctions supports only first-order backprop tape 738 gradients (and then only when graph building). It does not work with 739 higher-order tape gradients or forward autodiff, but does work with 740 higher-order symbolic gradients (tf.gradients). 741 742 Args: 743 flat_outputs: The result of running `forward`. 744 inference_args: A flat list of Tensors with inference inputs to the 745 operation. 746 input_tangents: A flat list of Tensors with input tangents consumed by the 747 operation. 748 """ 749 backward_function, to_record = self._backward(flat_outputs) 750 tape.record_operation(self._inference_function.signature.name, 751 to_record, inference_args + input_tangents, 752 backward_function) 753 754 755# Contains information about a forward function wrapped to compute jvps. 756_ForwardWrapper = collections.namedtuple( 757 "_ForwardWrapper", ( 758 # The wrapper Graph. 759 "graph", 760 # A flat list of non-tangent Tensor outputs from the wrapped forward 761 # function. 762 "outputs", 763 # Indices for output tangents, same format as 764 # forwardprop_util.pack_tangents. 765 "output_indices", 766 # A flat list of tangents for `outputs`. 767 "output_tangents")) 768 769 770class _TapeGradientFunctions(object): 771 """Caches forward and backward functions compatible with eager gradients. 772 773 In contrast to the delayed-rewrite approach in 774 `_DelayedRewriteGradientFunctions` which only works with delayed execution, 775 the forward function generated by this class has a fixed set of outputs which 776 may be preserved by a tape in order to compute gradients later. 777 778 This class is abstract; its child classes differ in how many side outputs of 779 the forward function their backward function accepts gradients for, which 780 determines whether higher-order tape gradients are possible. 781 """ 782 783 def __init__(self, func_graph, attrs, func_graph_deleter, 784 forwardprop_input_indices, delayed_rewrite_functions, 785 need_gradients_for_jvps): 786 self._func_graph = func_graph 787 self._forward_graph = None 788 self._attrs = attrs 789 self._forward = None 790 self._backward = None 791 self._num_outputs = len(func_graph.outputs) 792 self._func_graph_deleter = func_graph_deleter 793 self._forwardprop_input_indices = forwardprop_input_indices 794 self._forwardprop_output_indices = None 795 self._num_forwardprop_outputs = 0 796 self._num_inference_outputs = len(func_graph.outputs) 797 self._num_trainable_inference_outputs = len( 798 [t for t in func_graph.outputs if backprop_util.IsTrainable(t)]) 799 self._delayed_rewrite_functions = delayed_rewrite_functions 800 self._need_gradients_for_jvps = need_gradients_for_jvps 801 802 def _build_functions_for_outputs( 803 self, outputs, inference_args, input_tangents): 804 """Forward+backward functions where the backward function sees `outputs`.""" 805 # First figure out which of `outputs` are trainable. We'll accept gradients 806 # for each of these in the backward function. 807 handles_to_variables = self._func_graph.variable_captures 808 trainable_outputs = [] 809 trainable_indices = [] 810 for index, output in enumerate(outputs): 811 812 if backprop_util.IsTrainable(output): 813 # Swap in the Variable object for resource handles if we can so 814 # sparse gradients work. 815 output = handles_to_variables.get(id(output), output) 816 trainable_outputs.append(output) 817 trainable_indices.append(index) 818 819 backwards_graph = func_graph_module.FuncGraph( 820 _backward_name(self._func_graph.name)) 821 with backwards_graph.as_default(): 822 gradients_wrt_outputs = [] 823 for output in trainable_outputs: 824 gradient_shape, gradient_dtype = default_gradient.shape_and_dtype( 825 output) 826 gradient_placeholder = graph_placeholder(gradient_dtype, gradient_shape) 827 handle_data_util.copy_handle_data(output, gradient_placeholder) 828 gradients_wrt_outputs.append(gradient_placeholder) 829 with ops.device(None): 830 gradients_wrt_inputs = gradients_util._GradientsHelper( # pylint: disable=protected-access 831 trainable_outputs, 832 self._func_graph.inputs, 833 grad_ys=gradients_wrt_outputs, 834 src_graph=self._func_graph) 835 836 if input_tangents: 837 # Convert IndexedSlices to dense tensors (as we do elsewhere for 838 # function gradients). Our C++ bindings don't know how to handle them 839 # currently. 840 gradients_wrt_inputs = nest.map_structure( 841 lambda x: ops.convert_to_tensor(x) if x is not None else None, 842 gradients_wrt_inputs) 843 captures_from_forward = [ 844 c for c in backwards_graph.external_captures 845 if not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph 846 ] 847 existing_outputs = object_identity.ObjectIdentitySet( 848 self._func_graph.outputs) 849 for capture in captures_from_forward: 850 if capture not in existing_outputs: 851 existing_outputs.add(capture) 852 self._func_graph.outputs.append(capture) 853 854 # The ordering of `backwards_graph.inputs` is important: inputs of 855 # `backward_function` correspond to outputs (including 856 # side outputs) of `self._tape_forward_function`. 857 backwards_graph.inputs = ( 858 gradients_wrt_outputs + backwards_graph.internal_captures) 859 backwards_graph.outputs.extend( 860 grad 861 for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True) 862 if grad is not None) 863 backwards_graph.structured_outputs = gradients_wrt_inputs 864 865 forward_function, backward_function = _create_forward_backward_with_graph( 866 self._attrs, self._func_graph, backwards_graph) 867 868 if not input_tangents: 869 # There is no need to special-case forwardprop, so we can return the 870 # forward+backward pair we've created without further wrapping. 871 return (forward_function, self._func_graph, backward_function, 872 # No forwardprop outputs. 873 None, 0) 874 forward_wrapper = self._wrap_forward_function_with_jvps( 875 forward_function, backward_function, inference_args, input_tangents) 876 (wrapped_backwards_graph, 877 forward_wrapper) = self._wrap_backward_function_with_jvp_backprop( 878 backward_function, gradients_wrt_outputs, forward_wrapper) 879 # Now that we've added new captures, we need to make sure forward outputs 880 # are in the same order the backward function expects them to be in: 881 # [inference outputs] + [jvps] + [side outputs] + [captures]. 882 forward_wrapper = self._shuffle_forward_outputs(forward_wrapper) 883 (wrapped_forward_function, 884 wrapped_backward_function) = _create_forward_backward_with_graph( 885 self._attrs, forward_wrapper.graph, wrapped_backwards_graph) 886 if (len(inference_args) + len(input_tangents) 887 != len(forward_wrapper.graph.inputs)): 888 raise errors.InternalError( 889 f"The forward graph had {len(forward_wrapper.graph.inputs)} inputs, " 890 f"but we expected {len(inference_args) + len(input_tangents)} " 891 f"({len(inference_args)} inference inputs and " 892 f"{len(input_tangents)} input tangents).") 893 return (wrapped_forward_function, forward_wrapper.graph, 894 wrapped_backward_function, forward_wrapper.output_indices, 895 len(forward_wrapper.output_tangents)) 896 897 def _wrap_forward_function_with_jvps( 898 self, forward_function, backward_function, 899 inference_args, input_tangents): 900 """Adds inline JVP computation to a forward function.""" 901 forward_wrapper_graph = func_graph_module.FuncGraph( 902 _forward_name(self._func_graph.name)) 903 with forward_wrapper_graph.as_default(): 904 # Tell forward accumulators to free up space for new JVP computations, 905 # since one may be in the process of computing a JVP (if that computation 906 # triggered this function building). 907 # 908 # We'll make symbolic versions of input JVPs, run the forward function 909 # under forward accumulators to get symbolic output JVPs, then set those 910 # as outputs of the new wrapped forward function. 911 with forwardprop_util.push_forwardprop_state(): 912 forward_captures = { 913 ops.tensor_id(internal): external 914 for external, internal in self._func_graph.captures} 915 for input_index, real_input in enumerate(self._func_graph.inputs): 916 # This loop is more or less equivalent to running tf.identity on each 917 # of self._func_graph.inputs. However, doing that also captures jvps 918 # for resource handles, which confuses the jvp capturing code below 919 # (since primal inputs are interwoven with jvp inputs). 920 input_placeholder = array_ops.placeholder( 921 dtype=real_input.dtype, 922 shape=real_input.shape) 923 capture = forward_captures.get(ops.tensor_id(real_input)) 924 if capture is not None: 925 forward_wrapper_graph.add_capture(capture, input_placeholder) 926 if capture.dtype == dtypes.resource: 927 handle_data_util.copy_handle_data(capture, input_placeholder) 928 else: 929 forward_wrapper_graph.inputs.append(input_placeholder) 930 for inp, arg in zip(forward_wrapper_graph.inputs, inference_args): 931 tape.record_operation( 932 "captured_value", [inp], [arg], 933 backward_function=lambda x: [x], 934 forward_function=lambda x: [x]) 935 num_inference_inputs = len(inference_args) 936 for tape_indices in self._forwardprop_input_indices: 937 for input_index, jvp_index in tape_indices: 938 input_placeholder = forward_wrapper_graph.inputs[input_index] 939 if len(forward_wrapper_graph.inputs) != jvp_index: 940 raise errors.InternalError( 941 f"Expected {jvp_index} forward graph inputs, " 942 f"got {len(forward_wrapper_graph.inputs)}.") 943 gradient_shape, gradient_dtype = default_gradient.shape_and_dtype( 944 input_placeholder) 945 jvp_placeholder = graph_placeholder(gradient_dtype, gradient_shape) 946 external_jvp = input_tangents[jvp_index - num_inference_inputs] 947 forward_wrapper_graph.add_capture(external_jvp, jvp_placeholder) 948 tensor_shape.TensorShape( 949 external_jvp.shape).assert_is_compatible_with( 950 jvp_placeholder.shape) 951 tape.record_operation( 952 "captured_value", 953 [jvp_placeholder], 954 [external_jvp], 955 backward_function=lambda x: [x], 956 forward_function=lambda x: [x]) 957 forward_inputs = forward_wrapper_graph.inputs[:num_inference_inputs] 958 gradient_function = ( 959 self._delayed_rewrite_functions._rewrite_forward_and_call_backward) # pylint: disable=protected-access 960 with ops.get_default_graph()._override_gradient_function( # pylint: disable=protected-access 961 {"PartitionedCall": gradient_function, 962 "StatefulPartitionedCall": gradient_function}): 963 forward_outputs = forward_function.call(context.context(), 964 forward_inputs) 965 if isinstance(forward_outputs, ops.Operation): 966 # _wrapped_backward_function expects a list, but if the function has 967 # no outputs its call() returns an Operation. We need to undo that 968 # so we don't cause problems later. 969 forward_outputs = [] 970 py_backward, _ = self._wrap_backward_function( 971 self._func_graph, backward_function, forward_outputs) 972 # We will never request backward tape gradients for this operation 973 # directly since we're wrapping the call; forwardprop will call the 974 # backward function (and nested forward accumulators may build 975 # higher-order gradients), but any watching GradientTapes should ignore 976 # it. 977 # 978 # TODO(allenl): It might be better to explicitly stop backward recording 979 # so we don't use the second-order tape cases unnecessarily. 980 tape.record_operation_forwardprop_only( 981 forward_function.signature.name, 982 forward_outputs, forward_inputs, py_backward, None) 983 output_indices, output_tangents = ( 984 pywrap_tfe.TFE_Py_PackJVPs(forward_outputs)) 985 output_tangents = [forward_wrapper_graph.capture(t) 986 for t in output_tangents] 987 return _ForwardWrapper( 988 graph=forward_wrapper_graph, outputs=forward_outputs, 989 output_indices=output_indices, output_tangents=output_tangents) 990 991 def _wrap_backward_function_with_jvp_backprop( 992 self, backward_function, gradients_wrt_outputs, forward_wrapper): 993 """Wraps `backward_function` to include gradients for JVPs.""" 994 wrapped_backwards_graph = func_graph_module.FuncGraph( 995 _backward_name(self._func_graph.name)) 996 with wrapped_backwards_graph.as_default(): 997 py_backward, recorded_outputs = self._wrap_backward_function( 998 self._func_graph, backward_function, forward_wrapper.outputs) 999 trainable_index = 0 1000 forward_doutputs = [] 1001 doutput_args = [] 1002 for output in recorded_outputs: 1003 if backprop_util.IsTrainable(output): 1004 doutput = gradients_wrt_outputs[trainable_index] 1005 doutput_placeholder = graph_placeholder(doutput.dtype, doutput.shape) 1006 doutput_args.append(doutput_placeholder) 1007 forward_doutputs.append(doutput_placeholder) 1008 trainable_index += 1 1009 else: 1010 doutput_args.append(None) 1011 1012 dinputs = py_backward(*doutput_args) 1013 existing_outputs = object_identity.ObjectIdentitySet( 1014 forward_wrapper.outputs + forward_wrapper.output_tangents) 1015 num_processed_output_tangents = 0 1016 gradients_wrt_output_tangents = [] 1017 tangent_doutputs = [] 1018 output_tangents = forward_wrapper.output_tangents 1019 output_indices = forward_wrapper.output_indices 1020 if self._need_gradients_for_jvps: 1021 # TODO(allenl): Consider using a throwaway graph to avoid extra gradient 1022 # evaluations; gradients for jvps may have common subgraphs. 1023 while num_processed_output_tangents != len(output_tangents): 1024 for output in output_tangents[num_processed_output_tangents:]: 1025 gradient_shape, gradient_dtype = default_gradient.shape_and_dtype( 1026 output) 1027 placeholder = graph_placeholder(gradient_dtype, gradient_shape) 1028 gradients_wrt_output_tangents.append(placeholder) 1029 tangent_doutputs.append(placeholder) 1030 num_processed_output_tangents = len(output_tangents) 1031 with ops.device(None): 1032 gradients_wrt_inputs = gradients_util._GradientsHelper( # pylint: disable=protected-access 1033 output_tangents, 1034 forward_wrapper.graph.inputs, 1035 grad_ys=gradients_wrt_output_tangents, 1036 src_graph=forward_wrapper.graph) 1037 dinputs = [ 1038 backprop.aggregate_indexed_slices_gradients((existing, new)) 1039 for existing, new in zip(dinputs, gradients_wrt_inputs) 1040 if existing is not None or new is not None] 1041 dinputs.extend(gradients_wrt_inputs[len(dinputs):]) 1042 captures_from_forward = [ 1043 c for c in wrapped_backwards_graph.external_captures 1044 if (not isinstance(c, ops.EagerTensor) 1045 and c.graph is forward_wrapper.graph)] 1046 for capture in captures_from_forward: 1047 if capture not in existing_outputs: 1048 existing_outputs.add(capture) 1049 forward_wrapper.outputs.append(capture) 1050 output_indices, output_tangents = ( 1051 forwardprop_util.pack_tangents(forward_wrapper.outputs)) 1052 output_tangents = [forward_wrapper.graph.capture(t) 1053 for t in output_tangents] 1054 for t in output_tangents: 1055 existing_outputs.add(t) 1056 wrapped_backwards_graph.inputs = ( 1057 forward_doutputs[:self._num_trainable_inference_outputs] 1058 + tangent_doutputs 1059 + forward_doutputs[self._num_trainable_inference_outputs:] 1060 + wrapped_backwards_graph.internal_captures) 1061 wrapped_backwards_graph.structured_outputs = dinputs 1062 wrapped_backwards_graph.outputs = [t for t in dinputs if t is not None] 1063 return (wrapped_backwards_graph, 1064 forward_wrapper._replace(output_indices=output_indices, 1065 output_tangents=output_tangents)) 1066 1067 def _shuffle_forward_outputs(self, forward_wrapper): 1068 """Reorders function outputs so captures are last.""" 1069 def _index_map(original): 1070 if original < self._num_inference_outputs: 1071 return original 1072 if original >= len(forward_wrapper.outputs): 1073 return (original - len(forward_wrapper.outputs) 1074 + self._num_inference_outputs) 1075 return original + len(forward_wrapper.output_tangents) 1076 output_indices = nest.map_structure( 1077 _index_map, forward_wrapper.output_indices) 1078 forward_wrapper.graph.outputs = ( 1079 forward_wrapper.outputs[:self._num_inference_outputs] 1080 + forward_wrapper.output_tangents 1081 + forward_wrapper.outputs[self._num_inference_outputs:]) 1082 return forward_wrapper._replace(output_indices=output_indices) 1083 1084 def forward(self, inference_args, input_tangents): 1085 """Construct or fetch a forward function with side-outputs. 1086 1087 When graph building without a tape active, symbolic gradients rely on 1088 regenerating the backward function for higher-order gradients (to account 1089 for new side outputs of the rewritten forward function call). Thus there is 1090 no fixed backward function for this case. However, when a tape is active 1091 (eager or graph building), we generate fixed backward and forward functions 1092 at forward function call time. 1093 1094 This difference between the tape and non-tape cases is to avoid building 1095 unneeded backward functions while graph building (where we may or may not 1096 eventually need gradients). 1097 1098 Args: 1099 inference_args: A flat list of Tensors, arguments to the inference 1100 function. 1101 input_tangents: A flat list of Tensors, jvps associated with 1102 `inference_args`. 1103 1104 Returns: 1105 A forward _EagerDefinedFunction. 1106 """ 1107 if self._forward is None: 1108 (self._forward, self._forward_graph, self._backward, 1109 self._forwardprop_output_indices, self._num_forwardprop_outputs) = ( 1110 self._forward_and_backward_functions(inference_args, input_tangents)) 1111 return self._forward 1112 1113 def _wrap_backward_function(self, forward_graph, backward, outputs): 1114 """Create a backward function given `outputs` from the forward function.""" 1115 capture_mapping = dict( 1116 zip((ops.tensor_id(t) for t in forward_graph.outputs), outputs)) 1117 captured_inputs = backward.captured_inputs 1118 remapped_captures = [ 1119 capture_mapping.get(ops.tensor_id(capture), capture) 1120 for capture in captured_inputs 1121 ] 1122 if any(t.graph is forward_graph for t in remapped_captures 1123 if not isinstance(t, ops.EagerTensor)): 1124 incorrect_mapping = [t for t in remapped_captures 1125 if (not isinstance(t, ops.EagerTensor) and 1126 t.graph is not forward_graph)] 1127 raise errors.InternalError("Failed to map all backward graph captures to " 1128 "the forward graph. Incorrectly mapped: " 1129 f"{incorrect_mapping}.") 1130 # We may need to use zeros_like to get a zero for variant Tensors with 1131 # unconnected gradients. We do that in advance so we don't have to hold on 1132 # to the outputs themselves, which may not be needed otherwise. 1133 variant_zeros_like = {} 1134 backward_function_inputs = (len(backward.inputs) - len(captured_inputs)) 1135 recorded_outputs = [] 1136 trainable_recorded_outputs = 0 1137 skip_positions = [] 1138 if self._num_forwardprop_outputs and not self._need_gradients_for_jvps: 1139 relevant_outputs = ( 1140 outputs[:self._num_inference_outputs] 1141 + outputs[self._num_inference_outputs 1142 + self._num_forwardprop_outputs:]) 1143 else: 1144 relevant_outputs = outputs 1145 for output_index, output in enumerate(relevant_outputs): 1146 if trainable_recorded_outputs < backward_function_inputs: 1147 recorded_outputs.append(output) 1148 if backprop_util.IsTrainable(output): 1149 trainable_recorded_outputs += 1 1150 else: 1151 skip_positions.append(output_index) 1152 if output.dtype == dtypes.variant: 1153 variant_zeros_like[output_index] = default_gradient.zeros_like(output) 1154 1155 def _backward_function_wrapper(*args): 1156 """Process output gradients and call the backward function.""" 1157 if not backward.outputs: 1158 return backward.structured_outputs 1159 1160 processed_args = [] 1161 input_index = 0 1162 for output_index, arg in enumerate(args): 1163 # Convert IndexedSlices to dense tensors. The IndexedSlices optimization 1164 # is only really effective when doing tf.gather(variable) as the 1165 # adjoint functions for most operations are unlikely to preserve the 1166 # sparsity in IndexedSlices. 1167 if isinstance(arg, indexed_slices.IndexedSlices): 1168 arg = ops.convert_to_tensor(arg) 1169 if output_index in skip_positions: 1170 continue 1171 if arg is None: 1172 # We're calling a (non-polymorphic) ConcreteFunction, so we need to 1173 # have a Tensor value for each Tensor we thought would be trainable 1174 # based on its dtype, even if it ended up being unconnected. 1175 input_placeholder = backward.inputs[ 1176 input_index] 1177 if input_placeholder.dtype == dtypes.variant: 1178 arg = variant_zeros_like[output_index] 1179 else: 1180 arg = array_ops.zeros( 1181 *default_gradient.shape_and_dtype(input_placeholder)) 1182 processed_args.append(arg) 1183 input_index += 1 1184 if input_index >= backward_function_inputs: 1185 break 1186 return backward._call_flat( # pylint: disable=protected-access 1187 processed_args, remapped_captures) 1188 1189 return _backward_function_wrapper, recorded_outputs 1190 1191 def record(self, flat_outputs, inference_args, input_tangents): 1192 """Record the function call operation. 1193 1194 For backprop, indicates the backward function to use and which new Tensors 1195 must be watched. For forwardprop from eager, the function call itself will 1196 have produced tangents which need to be recorded. 1197 1198 Args: 1199 flat_outputs: The result of running `forward`. 1200 inference_args: A flat list of Tensors with inference inputs to the 1201 operation. 1202 input_tangents: A flat list of Tensors with input tangents consumed by the 1203 operation. 1204 """ 1205 backward_function, to_record = self._wrap_backward_function( 1206 self._forward_graph, self._backward, flat_outputs) 1207 if self._forwardprop_output_indices: 1208 tape.record_operation_backprop_only( 1209 self._forward.signature.name, 1210 to_record, inference_args, 1211 backward_function) 1212 tape.record_operation_forwardprop_only( 1213 self._forward.signature.name, 1214 flat_outputs, inference_args + input_tangents, 1215 backward_function, 1216 self._forwardprop_output_indices) 1217 else: 1218 tape.record_operation(self._forward.signature.name, 1219 to_record, inference_args + input_tangents, 1220 backward_function) 1221 1222 1223class _FirstOrderTapeGradientFunctions(_TapeGradientFunctions): 1224 """Caches tape-friendly functions for first-order gradients.""" 1225 1226 def __init__(self, func_graph, attrs, func_graph_deleter, 1227 forwardprop_input_indices, delayed_rewrite_functions, 1228 need_gradients_for_jvps): 1229 super().__init__(func_graph, attrs, func_graph_deleter, 1230 forwardprop_input_indices, delayed_rewrite_functions, 1231 need_gradients_for_jvps) 1232 self._func_graph_deleter = func_graph_deleter 1233 self._forwardprop_input_indices = forwardprop_input_indices 1234 1235 def _forward_and_backward_functions(self, inference_args, input_tangents): 1236 """Shortcut for when only first-order gradients are required. 1237 1238 The returned backward function does not accept gradients with respect to 1239 side output of forward_function. This is fine as long as the user can't 1240 possibly request second order tape gradients, as when they've used a single 1241 non-persistent GradientTape. Since we don't need the backward function to 1242 take gradients with respect to side outputs, we can skip some potentially 1243 slow graph building. 1244 1245 Args: 1246 inference_args: A flat list of Tensors, arguments to the inference 1247 function. 1248 input_tangents: A flat list of Tensors, jvps associated with 1249 `inference_args`. 1250 1251 Returns: 1252 A tuple of (forward_function, backward_function): 1253 forward_function: Takes the same inputs as the inference function, but 1254 returns side outputs used by backward_function in addition to the 1255 inference function's outputs. 1256 backward_function: Takes side outputs from forward_function and 1257 gradients with respect to the "real" outputs of forward_function and 1258 returns gradients with respect to the inputs. 1259 """ 1260 outputs = self._func_graph.outputs[:self._num_inference_outputs] 1261 return self._build_functions_for_outputs( 1262 outputs, inference_args, input_tangents) 1263 1264 1265class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions): 1266 """Caches tape-friendly functions for higher-order gradients.""" 1267 1268 # TODO(b/136189779): Cond/while under a tape may need similar logic. Consider 1269 # generalizing if so. 1270 def _forward_and_backward_functions(self, inference_args, input_tangents): 1271 """Forward and backward functions suitable for higher-order gradients. 1272 1273 Unlike in `_FirstOrderTapeGradientFunctions`, the backward function built by 1274 this method accepts gradients for all of the outputs of the returned forward 1275 function, including side outputs. 1276 1277 Args: 1278 inference_args: A flat list of Tensors, arguments to the inference 1279 function. 1280 input_tangents: A flat list of Tensors, jvps associated with 1281 `inference_args`. 1282 1283 Returns: 1284 A tuple of (forward_function, backward_function): 1285 forward_function: Takes the same inputs as the inference function, but 1286 returns side outputs used by backward_function in addition to the 1287 inference function's outputs. 1288 backward_function: Takes side outputs from forward_function and 1289 gradients with respect to all of its outputs, real and side. Returns 1290 gradients with respect to the inputs. 1291 """ 1292 outputs = [] 1293 iteration_count = 0 1294 # First we need to figure out how many side outputs from the forward pass 1295 # will be required. We do this in a temporary graph to avoid actually 1296 # running multiple copies of the backward pass (one per _GradientsHelper 1297 # call). 1298 # 1299 # While computing gradients, the backward function captures Tensors from 1300 # the forward function. We add these as side outputs of the original 1301 # function. However, we then need to accept output gradients with respect 1302 # to these side outputs for higher order gradients to work. Thus we loop 1303 # until the number of outputs of the function stabilizes. Note that this 1304 # is only required for tape gradients, where we need to declare in advance 1305 # all of the forward op's outputs: symbolic gradients with tf.gradients 1306 # instead rely on regenerating backward functions when higher-order 1307 # gradients are requested. 1308 while (len(outputs) < len(self._func_graph.outputs) 1309 # It's possible for gradient generation to add new ops to the forward 1310 # pass. If all of the new outputs are non-trainable, there's no 1311 # reason to continue. 1312 and any(backprop_util.IsTrainable(output) 1313 for output in self._func_graph.outputs[len(outputs):])): 1314 iteration_count += 1 1315 if iteration_count >= 20 and iteration_count % 5 == 0: 1316 new_op_with_trainable_output = None 1317 num_new_trainable_outputs = 0 1318 for output in self._func_graph.outputs[len(outputs):]: 1319 if backprop_util.IsTrainable(output): 1320 num_new_trainable_outputs += 1 1321 new_op_with_trainable_output = output.op 1322 logging.warning( 1323 ("Determining side outputs for the function '{}' is taking longer " 1324 "than expected ({} iterations, typically this converges in 5 or " 1325 "so). This could indicate that a gradient registration is adding " 1326 "new ops to the forward pass every time gradients are generated. " 1327 "{} new trainable output(s) were added this iteration, one from " 1328 "the following op:\n {}\nThis may indicate a TensorFlow bug, or " 1329 "an issue in a tf.custom_gradient.") 1330 .format( 1331 self._func_graph.name, iteration_count, 1332 num_new_trainable_outputs, new_op_with_trainable_output)) 1333 outputs = list(self._func_graph.outputs) 1334 self._build_functions_for_outputs( 1335 outputs, inference_args, input_tangents) 1336 1337 (forward_function, forward_graph, 1338 backward_function, output_indices, num_output_tangents) = ( 1339 self._build_functions_for_outputs( 1340 outputs, inference_args, input_tangents)) 1341 if (len(self._func_graph.outputs) > len(outputs) 1342 and any(backprop_util.IsTrainable(output) 1343 for output in self._func_graph.outputs[len(outputs):])): 1344 raise errors.InternalError( 1345 "Unexpectedly added new outputs to the forward function when " 1346 "building the backward function: " 1347 f"{self._func_graph.outputs[len(outputs):]}.") 1348 return (forward_function, forward_graph, backward_function, output_indices, 1349 num_output_tangents) 1350 1351 1352class _ForwardBackwardCall(object): 1353 """Holds the state of a function call between execution and recording.""" 1354 1355 __slots__ = [ 1356 "_functions", "_inference_args", "_input_tangents", "_tape_watching" 1357 ] 1358 1359 def __init__(self, functions, inference_args, input_tangents, tape_watching): 1360 """Collects information about the function call. 1361 1362 Args: 1363 functions: An object which produces forward and backward functions, either 1364 a _DelayedRewriteGradientFunctions or a _TapeGradientFunctions object. 1365 inference_args: A flat list of Tensors, arguments to the inference 1366 function. 1367 input_tangents: A flat list of Tensors, jvps associated with 1368 `inference_args`. 1369 tape_watching: Boolean, with True indicating that recording is necessary. 1370 """ 1371 self._functions = functions 1372 self._inference_args = inference_args 1373 self._input_tangents = input_tangents 1374 self._tape_watching = tape_watching 1375 1376 def forward(self): 1377 """Builds or retrieves a forward function for this call.""" 1378 forward_function = self._functions.forward( 1379 self._inference_args, self._input_tangents) 1380 return forward_function, self._inference_args + self._input_tangents 1381 1382 def record(self, flat_outputs): 1383 """Given outputs from the execution of `forward`, records the operation.""" 1384 if (self._tape_watching 1385 and not isinstance(flat_outputs, ops.Operation) 1386 and flat_outputs is not None): 1387 # We only record function calls which have outputs, and then only when a 1388 # tape is watching. 1389 self._functions.record( 1390 flat_outputs, self._inference_args, self._input_tangents) 1391 1392 1393class ConcreteFunction(core.ConcreteFunction, trackable.Trackable): 1394 """A `tf.types.experimental.ConcreteFunction` created from `tf.function`.""" 1395 1396 def __init__(self, func_graph, attrs=None, shared_func_graph=True, spec=None): 1397 """Initialize a `ConcreteFunction`. 1398 1399 Args: 1400 func_graph: An instance of FuncGraph: the function body to wrap. 1401 attrs: (optional) dict mapping names of attributes to their AttrValue 1402 values. Attributes in `attrs` will be included in this function's 1403 definition. 1404 shared_func_graph: If False, the ConcreteFunction takes ownership of 1405 `func_graph` and will break reference cycles when it is deleted. This 1406 makes the FuncGraph inoperable. 1407 spec: FunctionSpec for the original function. If not specified, then this 1408 ConcreteFunction may only be called using the flat signature. 1409 1410 Raises: 1411 ValueError: If number of input_placeholders is not equal to the number 1412 of function inputs. 1413 """ 1414 # _arg_keywords and _num_positional_args define the flat signature. They 1415 # are assigned after construction. 1416 self._arg_keywords = None 1417 self._num_positional_args = None 1418 1419 self._func_graph = func_graph 1420 self._captured_inputs = self._func_graph.external_captures + self._func_graph.deferred_external_captures 1421 1422 # spec defines the structured signature. 1423 self._set_function_spec(spec) 1424 1425 if attrs and IMPLEMENTS_ATTRIBUTE_NAME in attrs: 1426 # The alternative is to silently drop "implements" tag 1427 # but it seems likely it would lead to hard to catch bugs. 1428 # Another alternative is to make func_body to preserve the order 1429 # of arguments if variables are present. Yet another option 1430 # is to automatically replace variables as arguments to functions 1431 # to v.read_value() whenever "implements" tag is present 1432 # Anytime we annotate existing function we probably want to wrap 1433 # it with safe read_value for backward compatibility. 1434 has_resource_vars = any( 1435 inp.dtype == dtypes.resource for inp in self.inputs) 1436 1437 assert not any((has_resource_vars, self._captured_inputs)), ( 1438 'Function {name} has "{attr}={value}" attribute and thus can not ' 1439 "depend on any tensors outside of its signature or modify variables. " 1440 "\n\nNote: variables are always captured and cause function " 1441 "re-tracing for every variable called.\n" 1442 " inputs: {inputs}\n captures: {captured}\n\n" 1443 "To pass a variable to such function use " 1444 "use variable.read_value().".format( 1445 name=func_graph.name, 1446 attr=IMPLEMENTS_ATTRIBUTE_NAME, 1447 value=attrs[IMPLEMENTS_ATTRIBUTE_NAME], 1448 inputs=self.inputs, 1449 captured=self._captured_inputs)) 1450 self._output_shapes = tuple( 1451 output.shape for output in self._func_graph.outputs) 1452 self._attrs = _parse_func_attrs(attrs or {}) 1453 1454 if shared_func_graph: 1455 self._garbage_collector = None 1456 else: 1457 self._garbage_collector = ConcreteFunctionGarbageCollector(func_graph) 1458 1459 # Pairs of forward and backward functions used for computing gradients. 1460 # 1461 # These each get a reference to the FuncGraph deleter since they use the 1462 # FuncGraph directly. 1463 self._delayed_rewrite_functions = _DelayedRewriteGradientFunctions( 1464 func_graph, self._attrs, self._garbage_collector) 1465 self._first_order_tape_functions = {} 1466 self._higher_order_tape_functions = {} 1467 # Cache the inference function to avoid a (Python) function call when not 1468 # building gradients. 1469 self._inference_function = self._delayed_rewrite_functions.forward() 1470 1471 def _set_function_spec(self, spec): 1472 """Enables the structured signature by supplying a spec.""" 1473 self._function_spec = None 1474 self._pre_initialized_function_spec = spec 1475 self._initialize_function_spec() 1476 1477 def _initialize_function_spec(self): 1478 """Updates `self._function_spec` to include varargs and bound variables. 1479 1480 Adds new positional arguments for any varargs (i.e., for args that are 1481 in `structured_input_signature`, but not in the original fullargspec.args). 1482 1483 Replaces `defaults` and `kwonlydefaults` with the `BOUND_VALUE`, for 1484 all args and kwargs in `structured_input_signature`. 1485 1486 Sets `varkw` and `varargs` to None. 1487 """ 1488 if self._pre_initialized_function_spec is None: 1489 return # e.g., SavedBareConcreteFunction doesn't have function_spec yet. 1490 assert not self._function_spec, "already initialized" 1491 spec = self._pre_initialized_function_spec 1492 args = spec.fullargspec.args 1493 arg_specs, kwarg_specs = self.structured_input_signature 1494 vararg_indices = range(len(spec.arg_names), len(arg_specs)) 1495 fullargspec = tf_inspect.FullArgSpec( 1496 args=list(args) + ["<arg{}>".format(i + 1) for i in vararg_indices], 1497 varargs=None, 1498 varkw=None, 1499 defaults=[function_spec.BOUND_VALUE] * len(arg_specs), 1500 kwonlyargs=list(sorted(kwarg_specs)), 1501 kwonlydefaults=dict( 1502 (k, function_spec.BOUND_VALUE) for k in kwarg_specs), 1503 annotations=spec.fullargspec.annotations) 1504 self._function_spec = function_spec.FunctionSpec( 1505 fullargspec, 1506 spec.is_method, 1507 spec.input_signature, 1508 spec.is_pure, 1509 name=self._func_graph.name) 1510 1511 @property 1512 def variables(self): 1513 """Sequence of variables for this function.""" 1514 return tuple(self._func_graph.variables) 1515 1516 def set_variables(self, variables): 1517 self._func_graph.variables = variables 1518 1519 @property 1520 def trainable_variables(self): 1521 """Sequence of trainable variables for this function.""" 1522 return tuple(self._func_graph.trainable_variables) 1523 1524 def __call__(self, *args, **kwargs): 1525 """Executes the wrapped function. 1526 1527 ConcreteFunctions have two signatures: 1528 1529 * The signature of the original function wrapped by this ConcreteFunction. 1530 * A flat signature, where each argument accepts a single Tensor. 1531 1532 The original function signature is generally preferred, but the flat input 1533 signature is supported for backward compatibility. 1534 1535 ### Original Function Signature 1536 1537 When calling a ConcreteFunction with the signature of the original function, 1538 each argument must match the type or value that was used when the 1539 ConcreteFunction's graph was traced. In particular: 1540 1541 * Tensor arguments (including CompositeTensors, such as RaggedTensor) must 1542 have matching `TypeSpec`s. 1543 * Non-Tensor arguments (such as booleans or ints) must have equal values. 1544 * Nested arguments (such as lists, tuples, or dictionaries) must have the 1545 same nesting structure; and each nested value must have a matching type 1546 or value. 1547 1548 The default value for any arguments that were traced with non-Tensor values 1549 is the value that was used in the trace. Arguments that were traced with 1550 tensor arguments do not have a default value (even if the original function 1551 had a default value for that argument). 1552 1553 ### Flat Signature 1554 1555 When calling a ConcreteFunction with the flat signature, the arguments 1556 correspond to the flattened component tensors of the arguments that were 1557 used to construct the ConcreteFunction. Parameter names are assigned based 1558 on `TensorSpec.name` (when specified) or the original argument names (with 1559 suffixes automatically added for nested arguments or composite tensors with 1560 multiple components). 1561 1562 Args: 1563 *args: Positional arguments to the concrete function. 1564 **kwargs: Keyword arguments to the concrete function. 1565 1566 Returns: 1567 The result of applying the TF function on the given Tensors. 1568 1569 Raises: 1570 AssertionError: If this `ConcreteFunction` was not created through 1571 `get_concrete_function`. 1572 TypeError: If the arguments do not match the function's signature. 1573 """ 1574 return self._call_impl(args, kwargs) 1575 1576 def _call_impl(self, args, kwargs, cancellation_manager=None): 1577 """See `__call__` for details.""" 1578 with trace.Trace(self._func_graph.name, tf_function_call="concrete"): 1579 # Construct the list of input tensors: check if the structured signature 1580 # applies first; and if not, then use the flat signature. 1581 if self._function_spec is not None: 1582 try: 1583 return self._call_with_structured_signature(args, kwargs, 1584 cancellation_manager) 1585 except TypeError as structured_err: 1586 try: 1587 return self._call_with_flat_signature(args, kwargs, 1588 cancellation_manager) 1589 except TypeError: 1590 raise structured_err 1591 1592 return self._call_with_flat_signature(args, kwargs, cancellation_manager) 1593 1594 def _call_with_flat_signature(self, args, kwargs, cancellation_manager): 1595 """Executes the wrapped function with the flat signature. 1596 1597 Args: 1598 args: Positional arguments to the concrete function. 1599 kwargs: Keyword arguments to the concrete function. 1600 cancellation_manager: A `CancellationManager` that can be used to cancel 1601 function invocation. 1602 1603 Returns: 1604 The result of applying the function on the Tensors/Variables contained in 1605 `args` and `kwargs`. 1606 Raises: 1607 TypeError: if `args` and `kwargs` do not match the flat signature of this 1608 `ConcreteFunction`. 1609 """ 1610 if len(args) > self._num_positional_args: 1611 raise TypeError( 1612 f"{self._flat_signature_summary()} takes {self._num_positional_args} " 1613 f"positional arguments, got {len(args)}.") 1614 args = list(args) 1615 kwargs = dict(kwargs) 1616 for keyword in self._arg_keywords[len(args):]: 1617 try: 1618 args.append(kwargs.pop(compat.as_str(keyword))) 1619 except KeyError: 1620 specified_keywords = ( 1621 list(self._arg_keywords[:len(args)]) + list(kwargs.keys())) 1622 missing_required_args = sorted( 1623 set(self._arg_keywords) - set(specified_keywords)) 1624 raise TypeError(f"{self._flat_signature_summary()} missing required " 1625 f"arguments: {', '.join(missing_required_args)}.") 1626 if kwargs: 1627 positional_arg_keywords = set(self._arg_keywords[:len(args)]) 1628 for unused_key in kwargs: 1629 if unused_key in positional_arg_keywords: 1630 raise TypeError(f"{self._flat_signature_summary()} got two values " 1631 f"for '{unused_key}'.") 1632 raise TypeError(f"{self._flat_signature_summary()} got unexpected " 1633 f"keyword arguments: {', '.join(sorted(kwargs))}.") 1634 1635 for i, arg in enumerate(args): 1636 if not isinstance( 1637 arg, (ops.Tensor, resource_variable_ops.BaseResourceVariable)): 1638 raise TypeError(f"{self._flat_signature_summary()}: expected argument " 1639 f"#{i}(zero-based) to be a Tensor; " 1640 f"got {type(arg).__name__} ({arg}).") 1641 return self._call_flat(args, self.captured_inputs, cancellation_manager) 1642 1643 def _call_with_structured_signature(self, args, kwargs, cancellation_manager): 1644 """Executes the wrapped function with the structured signature. 1645 1646 Args: 1647 args: Positional arguments to the concrete function. 1648 kwargs: Keyword arguments to the concrete function. 1649 cancellation_manager: A `CancellationManager` that can be used to cancel 1650 function invocation. 1651 1652 Returns: 1653 The result of applying the function on the Tensors/Variables contained in 1654 `args` and `kwargs`. 1655 Raises: 1656 TypeError: if `args` and `kwargs` do not match the structured signature 1657 of this `ConcreteFunction`. 1658 """ 1659 args, kwargs, filtered_flat_args = ( 1660 self._function_spec.canonicalize_function_inputs(args, kwargs)) 1661 self._structured_signature_check_missing_args(args, kwargs) 1662 self._structured_signature_check_unexpected_args(args, kwargs) 1663 self._structured_signature_check_arg_types(args, kwargs) 1664 return self._call_flat( 1665 filtered_flat_args, 1666 captured_inputs=self.captured_inputs, 1667 cancellation_manager=cancellation_manager) 1668 1669 def _structured_signature_check_missing_args(self, args, kwargs): 1670 """Raises a TypeError if any args are missing.""" 1671 arg_specs, kwarg_specs = self.structured_input_signature 1672 missing_arguments = [] 1673 for i, (arg, spec) in enumerate(zip(args, arg_specs)): 1674 if arg is function_spec.BOUND_VALUE and _contains_type_spec(spec): 1675 missing_arguments.append(self._function_spec.arg_names[i]) 1676 for (name, arg) in kwargs.items(): 1677 if arg is function_spec.BOUND_VALUE and _contains_type_spec( 1678 kwarg_specs[name]): 1679 missing_arguments.append(name) 1680 if missing_arguments: 1681 raise TypeError(f"{self._structured_signature_summary()} missing " 1682 "required arguments: " 1683 f"{', '.join(sorted(missing_arguments))}.") 1684 1685 def _structured_signature_check_unexpected_args(self, args, kwargs): 1686 """Raises a TypeError if there are any extra args.""" 1687 arg_specs, kwarg_specs = self.structured_input_signature 1688 if len(args) > len(arg_specs): 1689 raise TypeError( 1690 f"{self._structured_signature_summary()} takes " 1691 f"{len(self._function_spec.arg_names)} positional arguments but got " 1692 f"{len(args)}.") 1693 if len(kwargs) > len(kwarg_specs): 1694 extra_args = set(kwargs) - set(kwarg_specs) 1695 raise TypeError(f"{self._structured_signature_summary()} got unexpected " 1696 f"keyword arguments: {', '.join(extra_args)}.") 1697 1698 def _structured_signature_check_arg_types(self, args, kwargs): 1699 """Raises a TypeError if any args have the wrong type.""" 1700 # Check argument types 1701 arg_specs, kwarg_specs = self.structured_input_signature 1702 for i, (arg, spec) in enumerate(zip(args, arg_specs)): 1703 name = self._function_spec.arg_names[i] 1704 self._structured_signature_check_arg_type(arg, spec, name) 1705 for (name, arg) in kwargs.items(): 1706 self._structured_signature_check_arg_type(arg, kwarg_specs[name], name) 1707 1708 def _structured_signature_check_arg_type(self, arg, spec, name): 1709 """Raise TypeError if `arg`'s type doesn't match `spec`.""" 1710 if arg is function_spec.BOUND_VALUE: 1711 return 1712 1713 # Check the overall nested structure of the argument. 1714 try: 1715 nest.assert_same_structure(arg, spec, expand_composites=True) 1716 except (ValueError, TypeError): 1717 try: 1718 nest.assert_same_structure(arg, spec, expand_composites=False) 1719 expected, got = spec, arg 1720 except (ValueError, TypeError): 1721 expected, got = _structure_summary(spec), _structure_summary(arg) 1722 raise TypeError(f"{self._structured_signature_summary()}: argument " 1723 f"{name} had incorrect type\n" 1724 f" expected: {expected}\n" 1725 f" got: {got}") 1726 1727 # Check the type for each leaf in the nested structure. 1728 arg_pieces = nest.flatten(arg, expand_composites=True) 1729 spec_pieces = nest.flatten(spec, expand_composites=True) 1730 for (arg_piece, spec_piece) in zip(arg_pieces, spec_pieces): 1731 # TODO(mdan): Use consistent error messages. 1732 if isinstance(spec_piece, tensor_spec.DenseSpec): 1733 # TODO(edloper): Consider calling convert_to_tensor on non-tensor 1734 # values here. That would match the behavior of 1735 # _call_concrete_function() in function_deserialization.py. If 1736 # we do, then we need to change the nest assert_same_structure and 1737 # flatten calls above to use shallow variants. 1738 tensor_types = (ops.Tensor, resource_variable_ops.BaseResourceVariable) 1739 if not isinstance(arg_piece, tensor_types): 1740 raise TypeError(f"{self._structured_signature_summary()} expected a " 1741 f"Tensor in {name}, but got " 1742 f"{type(arg_piece).__name__} value {arg_piece}.") 1743 elif arg_piece is not function_spec.BOUND_VALUE: 1744 try: 1745 arg_matches_spec = bool(arg_piece == spec_piece) 1746 except (ValueError, TypeError): 1747 logging.vlog(1, "Error matching value with spec", exc_info=True) 1748 arg_matches_spec = False 1749 if not arg_matches_spec: 1750 raise TypeError( 1751 f"ConcreteFunction {self._structured_signature_summary()} was " 1752 f"constructed with {type(spec_piece).__name__} value " 1753 f"{spec_piece} in {name}, but was called with " 1754 f"{type(arg_piece).__name__} value {arg_piece}.") 1755 1756 def _call_flat(self, args, captured_inputs, cancellation_manager=None): 1757 """Executes the wrapped function. 1758 1759 Args: 1760 args: a list of Tensors or Variables. Arguments from the Python function 1761 should be filtered before calling this method: objects aside from 1762 Tensors, CompositeTensors, and Variables are ignored. Any 1763 CompositeTensors should be expanded before calling this method. 1764 captured_inputs: the captured inputs that are also part of the input args 1765 to the actual execution. By default, it should be self._captured_inputs. 1766 cancellation_manager: (Optional.) A `CancellationManager` that can be 1767 used to cancel function invocation. 1768 1769 Returns: 1770 The result of applying the TF function to `args`. 1771 1772 Raises: 1773 ValueError: If `args` contains anything other than Tensors or Variables. 1774 """ 1775 ctx = context.context() 1776 executing_eagerly = ctx.executing_eagerly() 1777 1778 # Copy saveable status of function's graph to current FuncGraph. 1779 default_graph = ops.get_default_graph() 1780 if default_graph.building_function and not self._func_graph.saveable: 1781 default_graph.mark_as_unsaveable(self._func_graph.saving_errors) 1782 1783 if (tape.could_possibly_record() or 1784 hasattr(default_graph, "watch_variable")): 1785 for v in self._func_graph.variables: 1786 resource_variable_ops.variable_accessed(v) 1787 1788 tensor_inputs = [] 1789 variables_used = set([]) 1790 for i, arg in enumerate(args): 1791 if isinstance(arg, resource_variable_ops.BaseResourceVariable): 1792 # We can pass a variable more than once, and in this case we need to 1793 # pass its handle only once. 1794 if id(arg.handle) in variables_used: 1795 continue 1796 resource_variable_ops.variable_accessed(arg) 1797 tensor_inputs.append(arg.handle) 1798 variables_used.add(id(arg.handle)) 1799 elif isinstance(arg, ops.Tensor): 1800 tensor_inputs.append(arg) 1801 else: 1802 raise ValueError(f"{i:d}-th input {arg} must be a Tensor, got " 1803 f"{type(arg)} when calling {self._func_graph.name}.") 1804 1805 if not executing_eagerly: 1806 for i, tensor_input in enumerate(tensor_inputs): 1807 # Can not compare shapes in these cases 1808 # TODO(b/216506654): Consider moving this check elsewhere and making it 1809 # work for all types (e.g. by including shape for Variables). 1810 if (tensor_input.dtype == dtypes.resource or 1811 tensor_input.dtype == dtypes.variant): 1812 continue 1813 1814 # If we're graph building, shape inference is on. We check for input 1815 # compatibility up front to avoid hard to debug incompatibilities 1816 # later. 1817 graph_input_shape = tensor_shape.TensorShape( 1818 self._func_graph.inputs[i].shape) 1819 if not graph_input_shape.is_compatible_with(tensor_input.shape): 1820 raise ValueError( 1821 f"Tensor {tensor_input} is not compatible with the shape this " 1822 f"function was traced with. Expected shape " 1823 f"{self._func_graph.inputs[i].shape}, but got shape " 1824 f"{tensor_input.shape}.\n\nIf you called get_concrete_function, " 1825 f"you may need to pass a tf.TensorSpec(..., shape=...) with a " 1826 f"less specific shape, having None on axes which can vary.") 1827 1828 args = tensor_inputs + captured_inputs 1829 possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args) 1830 if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE 1831 and executing_eagerly): 1832 # No tape is watching; skip to running the function. 1833 return self._build_call_outputs(self._inference_function.call( 1834 ctx, args, cancellation_manager=cancellation_manager)) 1835 forward_backward = self._select_forward_and_backward_functions( 1836 args, 1837 possible_gradient_type, 1838 executing_eagerly) 1839 forward_function, args_with_tangents = forward_backward.forward() 1840 if executing_eagerly: 1841 flat_outputs = forward_function.call( 1842 ctx, args_with_tangents, cancellation_manager=cancellation_manager) 1843 else: 1844 with default_graph._override_gradient_function( # pylint: disable=protected-access 1845 {"PartitionedCall": self._get_gradient_function(), 1846 "StatefulPartitionedCall": self._get_gradient_function()}): 1847 flat_outputs = forward_function.call(ctx, args_with_tangents) 1848 forward_backward.record(flat_outputs) 1849 return self._build_call_outputs(flat_outputs) 1850 1851 def _experimental_with_cancellation_manager(self, cancellation_manager): 1852 """Returns a callable that invokes a cancellable version of this function. 1853 1854 Args: 1855 cancellation_manager: A `CancellationManager` object that can be used to 1856 cancel function invocation. 1857 1858 Returns: 1859 A callable with the same signature as this concrete function. 1860 """ 1861 1862 def cancellable_call(*args, **kwargs): 1863 return self._call_impl( 1864 args, kwargs, cancellation_manager=cancellation_manager) 1865 1866 return cancellable_call 1867 1868 @property 1869 def name(self): 1870 """`ConcreteFunction` name.""" 1871 return self._delayed_rewrite_functions.forward().name 1872 1873 @property 1874 def graph(self): 1875 """Returns the graph from which this function was constructed.""" 1876 return self._func_graph 1877 1878 @property 1879 def inputs(self): 1880 """Returns tensors in `self.graph` corresponding to arguments.""" 1881 return self._func_graph.inputs 1882 1883 @property 1884 def structured_input_signature(self): 1885 """Returns structured signature for this concrete function. 1886 1887 Returns: 1888 A tuple `(args, kwargs)`, where: 1889 1890 * `args` is a tuple that specifies the expected type or value each for 1891 positional argument. 1892 * `kwargs` is a dictionary that specifies the expected type or value 1893 for each keyword-only argument. 1894 1895 The type or value for each argument is specified using one of the 1896 following: 1897 1898 * A `tf.TypeSpec`, indicating that a Tensor or other TensorFlow-native 1899 value is expected. 1900 * A Python value, such as an integer, indicating that an equal value 1901 is expected. 1902 * A nested structure of `tf.TypeSpec`s and Python values, indicating 1903 that a corresponding nested structure is expected. 1904 """ 1905 return self._func_graph.structured_input_signature 1906 1907 @property 1908 def outputs(self): 1909 """Returns tensors in `self.graph` corresponding to returned tensors.""" 1910 return self._func_graph.outputs 1911 1912 @property 1913 def structured_outputs(self): 1914 """Returns outputs in `self.graph` as returned by the original function.""" 1915 return self._func_graph.structured_outputs 1916 1917 def set_external_captures(self, captures): 1918 """Updates the function capture values. 1919 1920 The new values must have tensor types and shapes consistent with the 1921 original captures of the concrete function, but it is allowed to change a 1922 value captured with a deferred one and vice-versa. 1923 1924 Args: 1925 captures: A list of tensors or closures. Tensors are value captures, and 1926 closures are call-time (deferred captures). 1927 """ 1928 # TODO(wxinyi): 1. verify that the new captures' type spec is compatible 1929 # with the original's. However, doing so requires MirroredVariable captures 1930 # initialized. 2. replace the original/new captures/deferred 1931 # captures in the wrapped graph. Doing such for a capture-to-deferred 1932 # capture replacement requires more arguments than the deferred capture 1933 # itself, e.g. default value, spec. 1934 self._captured_inputs = captures 1935 1936 def replace_capture_with_deferred_capture(self, 1937 tensor, 1938 closure, 1939 spec, 1940 placeholder=None, 1941 default_value=None): 1942 """Replaces existing capture `tensor` with a deferred capture `closure`. 1943 1944 This API replaces the capture `tensor` from the concrete function's captured 1945 inputs list, and places the deferred capture `closure` in 1946 its spot so the order of captured inputs is preserved. This is important 1947 because the old `tensor` and the new `closure` will have the same internal 1948 placeholder, which can be passed through the `placeholder` argument, or 1949 skipped, in which case we find the placeholder from internal inputs by 1950 indexing `tensor` in the external captured inputs list. Thus, it is 1951 important that the new deferred capture has output spec (specified by the 1952 `spec` argument) compatible with the internal placeholder (`placeholder`) 1953 and the original capture (`tensor`). 1954 1955 For example, 1956 1957 ```python 1958 bool_captured_tensor = tf.constant(True) 1959 float_captured_tensor = tf.constant([3.], dtype=tf.float32) 1960 value = tf.constant([2.], dtype=tf.float32) 1961 1962 @tf.function 1963 def fn(): 1964 deferred_tensor = ops.get_default_graph().capture_call_time_value( 1965 lambda: value, 1966 tf.TensorSpec(shape=(1,), dtype=tf.float32)) 1967 if bool_captured_tensor: 1968 return deferred_tensor 1969 else: 1970 return deferred_tensor + float_captured_tensor 1971 1972 concrete_fn = fn.get_concrete_function() 1973 print(concrete_fn()) # tf.Tensor([2.], shape=(1,), dtype=float32) 1974 1975 new_bool_captured_tensor = constant_op.constant(False) 1976 def bool_closure(): 1977 return new_bool_captured_tensor 1978 1979 concrete_fn.replace_capture_with_deferred_capture( 1980 bool_captured_tensor, 1981 bool_closure, 1982 spec=tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)) 1983 1984 print(concrete_fn()) # tf.Tensor([5.], shape=(1,), dtype=float32) 1985 ``` 1986 1987 Args: 1988 tensor: Tensor already captured. This `tensor` should be listed in 1989 concrete_function.captured_inputs except when it's empty such as when 1990 the concrete function is restored from SavedModel. 1991 closure: function which takes no arguments, to be evaluated at function 1992 call time, returning a nest of tensors compatible with `spec`. 1993 spec: nest of TypeSpec for the value to capture. 1994 placeholder: optional. The internal placeholder corresponding to the 1995 captured `tensor` and the new `closure`. 1996 default_value: optional value to use in environments that cannot safely 1997 evaluate closure. 1998 """ 1999 capture_index = None 2000 for i, capture in enumerate(self._captured_inputs): 2001 if id(tensor) == id(capture): 2002 capture_index = i 2003 break 2004 2005 if placeholder is None: 2006 if capture_index is None: 2007 raise ValueError( 2008 f"Did not find `tensor` argument {tensor} in the ConcreteFunction's" 2009 " captured inputs list, and did not receive a placeholder argument." 2010 " Thus we're unable to infer the internal placeholder. ") 2011 2012 placeholder = self.inputs[-len(self._captured_inputs) + capture_index] 2013 2014 if not (spec.is_compatible_with(tensor) or 2015 spec.is_compatible_with(placeholder)): 2016 raise ValueError( 2017 f"Attempting to substitute closure with spec {spec} that's " 2018 f"incompatible with the original capture {tensor} or the internal " 2019 f"placeholder {placeholder}.") 2020 2021 self._func_graph.replace_capture_with_deferred_capture( 2022 tensor=tensor, 2023 closure=closure, 2024 spec=spec, 2025 placeholder=placeholder, 2026 default_value=default_value) 2027 2028 if capture_index is not None: 2029 self._captured_inputs[capture_index] = closure 2030 2031 @property 2032 def captured_inputs(self): 2033 """Returns external Tensors captured by this function. 2034 2035 self.__call__(*args) passes `args + self.captured_inputs` to the function. 2036 """ 2037 return nest.flatten( 2038 [x() if callable(x) else x for x in self._captured_inputs], 2039 expand_composites=True) 2040 2041 @property 2042 def function_def(self): 2043 """Returns a `FunctionDef` object representing this function.""" 2044 return self._delayed_rewrite_functions.forward().definition 2045 2046 @property 2047 def output_shapes(self): 2048 """The function's output shapes.""" 2049 return nest.map_structure( 2050 lambda x: getattr(x, "shape", tensor_shape.TensorShape(None)), 2051 composite_tensor.replace_composites_with_components( 2052 self._func_graph.structured_outputs), 2053 expand_composites=False) 2054 2055 @property 2056 def output_dtypes(self): 2057 # TODO(akshayka): Consider removing this. 2058 return nest.map_structure( 2059 lambda x: x.dtype if x is not None else None, 2060 composite_tensor.replace_composites_with_components( 2061 self._func_graph.structured_outputs), 2062 expand_composites=False) 2063 2064 def add_to_graph(self, g=None): 2065 """Registers the function, adds it to the graph g or default graph. 2066 2067 Args: 2068 g: If specified, registers the function with this graph. Defaults to the 2069 current context (either the default graph or the eager context). 2070 """ 2071 # If we are not executing eagerly, adds the function to default graph if no 2072 # graph is specified. 2073 # In case of eager execution, function definition gets added to context 2074 # during construction itself. 2075 2076 if not context.executing_eagerly() and not g: 2077 g = ops.get_default_graph() 2078 self._delayed_rewrite_functions.forward().add_to_graph(g) 2079 2080 def add_gradient_functions_to_graph(self, g=None): 2081 """Add forward/backward functions to graph `g` or the current context.""" 2082 if not context.executing_eagerly() and not g: 2083 g = ops.get_default_graph() 2084 self._delayed_rewrite_functions.forward().add_to_graph(g) 2085 forward_function, backward_function = ( 2086 self._delayed_rewrite_functions.forward_backward()) 2087 forward_function.add_to_graph(g) 2088 backward_function.add_to_graph(g) 2089 2090 def _get_gradient_function(self): 2091 """Returns gradient function. It will be lazily created at first call.""" 2092 return self._delayed_rewrite_functions._rewrite_forward_and_call_backward # pylint: disable=protected-access 2093 2094 def _select_forward_and_backward_functions( 2095 self, args, possible_gradient_type, executing_eagerly): 2096 """Selects forward and backward functions based on the calling context. 2097 2098 The forward function computes the "real" function outputs, `self._outputs`, 2099 and any extra values needed by the corresponding backward function. 2100 2101 Args: 2102 args: A flat list of Tensors with all of the inputs to the forward 2103 function (including user-specified and captured inputs). 2104 possible_gradient_type: One of gradients_util.POSSIBLE_GRADIENT_TYPES_*. 2105 executing_eagerly: Boolean, the value of context.executing_eagerly(). 2106 2107 Returns: 2108 An object with a `forward` method returning a tuple of (forward_function : 2109 _EagerDefinedFunction, augmented_arguments : List), and a corresponding 2110 `record` method which takes outputs from the forward function and records 2111 the operation. forward_function should be called with augmented_arguments. 2112 """ 2113 if executing_eagerly: 2114 input_tangents = forwardprop_util.pack_tangents(args) 2115 else: 2116 input_tangents = forwardprop_util.TangentInfo() 2117 need_gradients_for_jvps = tape.should_record_backprop( 2118 input_tangents.tangents) 2119 # Allows re-use of forward and backward function pairs depending on the 2120 # tapes and forward accumulators watching its inputs. 2121 cache_key = (need_gradients_for_jvps, input_tangents.indices) 2122 if (possible_gradient_type 2123 == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER): 2124 if input_tangents.indices or executing_eagerly: 2125 # There is a single non-persistent tape active, so the user can only 2126 # request first-order gradients from a tape. We can spend less time 2127 # graph building since we know this. 2128 # 2129 # We may still end up computing higher-order gradients, but that'd be 2130 # through `tf.gradients`, which can re-write the forward pass and so 2131 # needs no preparation here. 2132 functions = self._first_order_tape_functions.get(cache_key, None) 2133 if functions is None: 2134 functions = _FirstOrderTapeGradientFunctions( 2135 self._func_graph, self._attrs, self._garbage_collector, 2136 forwardprop_input_indices=input_tangents.indices, 2137 delayed_rewrite_functions=self._delayed_rewrite_functions, 2138 need_gradients_for_jvps=need_gradients_for_jvps) 2139 self._first_order_tape_functions[cache_key] = functions 2140 return _ForwardBackwardCall( 2141 functions, args, input_tangents.tangents, tape_watching=True) 2142 else: 2143 # We can avoid computing second-order gradients in some cases by doing a 2144 # delayed rewrite when graph building. Since we know we'll only compute 2145 # first-order tape gradients, the delayed rewrite is safe: we won't need 2146 # to tell the tape about side outputs. 2147 # 2148 # TODO(allenl): This case is really dirty. It would be better if we 2149 # could temporarily pop all of the current tapes to avoid 2150 # accidentally taking second-order gradients. 2151 return _ForwardBackwardCall( 2152 self._delayed_rewrite_functions, args, input_tangents.tangents, 2153 tape_watching=True) 2154 elif (possible_gradient_type 2155 == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER): 2156 # Either there's a persistent tape watching, or there are multiple nested 2157 # tapes. Either way, the user may request higher-order gradients. We'll 2158 # spend a bit more time and make sure higher-order gradients are correct. 2159 functions = self._higher_order_tape_functions.get( 2160 cache_key, None) 2161 if functions is None: 2162 functions = _HigherOrderTapeGradientFunctions( 2163 self._func_graph, self._attrs, self._garbage_collector, 2164 forwardprop_input_indices=input_tangents.indices, 2165 delayed_rewrite_functions=self._delayed_rewrite_functions, 2166 need_gradients_for_jvps=need_gradients_for_jvps) 2167 self._higher_order_tape_functions[cache_key] = functions 2168 return _ForwardBackwardCall(functions, args, input_tangents.tangents, 2169 tape_watching=True) 2170 # else possible_gradient_type == POSSIBLE_GRADIENT_TYPES_NONE, meaning no 2171 # tape is recording. 2172 return _ForwardBackwardCall( 2173 self._delayed_rewrite_functions, args, input_tangents.tangents, 2174 tape_watching=False) 2175 2176 def _build_call_outputs(self, result): 2177 """Maps the fdef output list to actual output structure. 2178 2179 Args: 2180 result: Output lists defined by FunctionDef. 2181 Returns: 2182 The actual call output. 2183 """ 2184 # TODO(jlchu): call C++ version in function.cc when speed is improved 2185 if self._func_graph.structured_outputs is None: 2186 return result 2187 2188 # Replace outputs with results, skipping over any 'None' values. 2189 outputs_list = nest.flatten( 2190 self._func_graph.structured_outputs, expand_composites=True) 2191 j = 0 2192 for i, o in enumerate(outputs_list): 2193 if o is not None: 2194 handle_data_util.copy_handle_data(self.outputs[j], result[j]) 2195 outputs_list[i] = result[j] 2196 j += 1 2197 ret = nest.pack_sequence_as(self._func_graph.structured_outputs, 2198 outputs_list, expand_composites=True) 2199 return ret 2200 2201 @property 2202 def _as_name_attr_list(self): 2203 """Returns a `NameAttrList` representing this function.""" 2204 ret = attr_value_pb2.NameAttrList(name=self.name) 2205 for name, value in self._attrs.items(): 2206 ret.attr[name].CopyFrom(value) 2207 return ret 2208 2209 def _structured_signature_summary(self, default_values=False): 2210 """Returns a string summarizing this function's structured signature. 2211 2212 Args: 2213 default_values: If true, then include default values in the signature. 2214 2215 Returns: 2216 A `string`. 2217 """ 2218 # Note: we can't just use self._funcion_spec.signature_summary(), because 2219 # that would show "BOUND_VALUE" as the default value for all arguments. 2220 assert self._function_spec is not None 2221 arg_specs, kwarg_specs = self.structured_input_signature 2222 arg_names = list(self._function_spec.arg_names) 2223 2224 # If an explicit input_signature is provided to @tf.function, then any 2225 # arguments with defaults that are not covered by that explicit signature 2226 # are simply dropped from the signature. 2227 # TODO(b/159639913) Look into whether dropping arguments with default values 2228 # from the signature is the right thing to do. 2229 arg_names = arg_names[:len(arg_specs)] 2230 2231 if default_values: 2232 for i in range(len(arg_names)): 2233 if not _contains_type_spec(arg_specs[i]): 2234 arg_names[i] += "={}".format(arg_specs[i]) 2235 if kwarg_specs: 2236 arg_names.append("*") 2237 for name, spec in kwarg_specs.items(): 2238 arg_names.append(name) 2239 if default_values and not _contains_type_spec(spec): 2240 arg_names[-1] += "={}".format(spec) 2241 signature = f"{self._func_graph.name}({', '.join(arg_names)})" 2242 2243 return signature 2244 2245 def _flat_signature_summary(self): 2246 """Returns a string summarizing this function's flat signature.""" 2247 assert self._arg_keywords is not None 2248 assert self._num_positional_args is not None 2249 arg_names = self._arg_keywords 2250 if self._num_positional_args > len(arg_names): 2251 arg_names.extend( 2252 "<arg{}>".format(i + 1) 2253 for i in range(len(arg_names), self._num_positional_args)) 2254 return f"{self._func_graph.name}({', '.join(arg_names)})" 2255 2256 def pretty_printed_signature(self, verbose=True): 2257 """Returns a string summarizing the signature of this concrete function.""" 2258 if not verbose: 2259 return self._structured_signature_summary(default_values=True) 2260 2261 def pretty_print_spec(spec): 2262 """Returns a string describing the spec for a single argument.""" 2263 if isinstance(spec, tensor_spec.TensorSpec): 2264 return "{} Tensor, shape={}".format(spec.dtype.name, spec.shape) 2265 elif nest.is_nested(spec): 2266 pieces = nest.flatten(spec, expand_composites=False) 2267 markers = [_Marker("<{}>".format(i + 1)) for i in range(len(pieces))] 2268 structure = nest.pack_sequence_as(spec, markers) 2269 # Ensure dictionaries are sorted by key (for determinism) 2270 result = pprint.pformat(structure, width=10000) 2271 for (marker, piece) in zip(markers, pieces): 2272 result += "\n {}: {}".format(marker, pretty_print_spec(piece)) 2273 return result 2274 else: 2275 return repr(spec) 2276 2277 lines = [self._structured_signature_summary(default_values=True)] 2278 arg_specs, kwarg_specs = self.structured_input_signature 2279 names = list(self._function_spec.arg_names) 2280 2281 # If an explicit input_signature is provided to @tf.function, then any 2282 # arguments with defaults that are not covered by that explicit signature 2283 # are simply dropped from the signature. 2284 # TODO(b/159639913) Look into whether dropping arguments with default values 2285 # from the signature is the right thing to do. 2286 2287 # Note: we can skip bound args, since we already displayed their bound 2288 # value in the signature summary. 2289 arg_details = [] 2290 for (name, spec) in zip(names[:len(arg_specs)], list(arg_specs)): 2291 if _contains_type_spec(spec): 2292 arg_details.append(" {}: {}".format(name, pretty_print_spec(spec))) 2293 2294 if kwarg_specs: 2295 for kwarg in sorted(kwarg_specs): 2296 spec = kwarg_specs[kwarg] 2297 if _contains_type_spec(spec): 2298 arg_details.append(" {}: {}".format( 2299 kwarg, pretty_print_spec(spec))) 2300 2301 if arg_details: 2302 lines.append(" Args:") 2303 lines.extend(arg_details) 2304 lines.append(" Returns:") 2305 2306 def spec_from_value(value): 2307 # For loaded function, structured_outputs are already specs. 2308 if isinstance(value, type_spec.TypeSpec): 2309 return value 2310 return type_spec.type_spec_from_value(value) 2311 2312 lines.append(" {}".format( 2313 pretty_print_spec( 2314 nest.map_structure(spec_from_value, self.structured_outputs)))) 2315 2316 return "\n".join(lines) 2317 2318 def __repr__(self): 2319 if self._function_spec is not None: 2320 return "<ConcreteFunction {} at 0x{:X}>".format( 2321 self.pretty_printed_signature(verbose=False), id(self)) 2322 elif not (self._num_positional_args is None or self._arg_keywords is None): 2323 return "<ConcreteFunction {} at 0x{:X}>".format( 2324 self._flat_signature_summary(), id(self)) 2325 else: 2326 return object.__repr__(self) 2327 2328 def __str__(self): 2329 if self._function_spec is not None: 2330 return "ConcreteFunction {}".format(self.pretty_printed_signature()) 2331 else: 2332 return self.__repr__() 2333 2334 def _trackable_children(self, save_type="checkpoint", **kwargs): 2335 """Implements `Trackable`.""" 2336 if save_type == "checkpoint": 2337 # Checkpoint dependencies do not include functions at all. Users 2338 # expect the checkpointed variables to be saved using the model 2339 # architecture, e.g. `model.layers[1].kernel` or `model.variables`. 2340 return {} 2341 2342 captured_trackables = {} 2343 for n, (capture, _) in enumerate(self.graph.captures): 2344 if (capture.dtype not in (dtypes.variant, dtypes.resource) and 2345 not resource_variable_ops.is_resource_variable(capture)): 2346 # Variant/resource type tensors are skipped since we have no way of 2347 # getting the `Trackable` wrapper for these tensors. The wrappers are 2348 # expected to be elsewhere in the saved object graph. 2349 # TODO(b/223866972): Directly encode/decode tensor captures. 2350 2351 # Resource variable captures are also skipped at this time, to maintain 2352 # existing behavior. 2353 # TODO(b/217979389): Return the non-constant captures as children. 2354 2355 captured_trackables[f"capture_{n}"] = capture 2356 2357 return captured_trackables 2358 2359 def _deserialization_dependencies(self, children): 2360 return children 2361 2362 def _export_to_saved_model_graph(self, object_map, tensor_map, 2363 **unused_kwargs): 2364 if not self.graph.saveable: 2365 raise ValueError( 2366 (f"Unable to save function {self.name} for the following reason(s):\n" 2367 + "\n".join(self.graph.saving_errors))) 2368 self.add_to_graph() 2369 object_map[self] = function_saved_model_utils.ExportedConcreteFunction( 2370 self, tensor_map) 2371 return [] 2372 2373 2374_pywrap_utils.RegisterType("Tensor", ops.Tensor) 2375_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor) 2376_pywrap_utils.RegisterType("IndexedSlices", indexed_slices.IndexedSlices) 2377 2378 2379# TODO(mdan): Refactor this and clarify relationship with def_function.Function. 2380# Right now, def_function.Function is the higher level implementation. 2381class Function: 2382 """Wrapper class for the graph functions defined for a Python function. 2383 2384 See the documentation for `defun` for more information on the semantics of 2385 defined functions. 2386 2387 `Function` class is thread-compatible meaning that minimal usage of defuns 2388 (defining and calling) is thread-safe, but if users call other methods or 2389 invoke the base `python_function` themselves, external synchronization is 2390 necessary. 2391 In addition, Function is not reentrant, so recursive functions need to call 2392 the wrapped function, not the wrapper. 2393 """ 2394 2395 def __init__(self, 2396 python_function, 2397 name, 2398 input_signature=None, 2399 attributes=None, 2400 autograph=True, 2401 autograph_options=None, 2402 reduce_retracing=False, 2403 capture_by_value=None, 2404 jit_compile=None, 2405 experimental_follow_type_hints=False): 2406 """Initializes a `Function`. 2407 2408 Args: 2409 python_function: the function to be wrapped. 2410 name: the name given to it. 2411 input_signature: a possibly nested sequence of `TensorSpec` objects 2412 specifying the input signature of this function. If `None`, a separate 2413 function is instantiated for each inferred input signature. 2414 attributes: dict, extra keyword arguments that will be added as attribute 2415 of the function. 2416 autograph: whether to use autograph to compile 2417 `python_function`. See https://www.tensorflow.org/guide/autograph for 2418 more information. 2419 autograph_options: Experimental knobs to control behavior 2420 `when autograph=True`. See https://www.tensorflow.org/guide/autograph 2421 for more information. 2422 reduce_retracing: When True, `tf.function` uses 2423 `tf.types.experimental.TraceType` to trace supertypes of arguments to 2424 reduce the number of traces. 2425 capture_by_value: Experimental. Whether to capture resource variables by 2426 value or reference. If None, will inherit from a parent context or 2427 default to False. 2428 jit_compile: Force-compile the function with XLA, cf. 2429 def_function.Function doc on jit_compile. 2430 experimental_follow_type_hints: See the documentation for `tf.function`. 2431 2432 Raises: 2433 ValueError: if `input_signature` is not None and the `python_function`'s 2434 argspec has keyword arguments. 2435 """ 2436 self._python_function = python_function 2437 pure_function = attributes and IMPLEMENTS_ATTRIBUTE_NAME in attributes 2438 self._function_spec = function_spec.FunctionSpec.from_function_and_signature( 2439 python_function, 2440 input_signature, 2441 is_pure=pure_function, 2442 experimental_follow_type_hints=experimental_follow_type_hints) 2443 self._name = name 2444 self._autograph = autograph 2445 self._autograph_options = autograph_options 2446 self._reduce_retracing = reduce_retracing 2447 self._function_cache = function_cache.FunctionCache() 2448 self._function_attributes = attributes or {} 2449 self._capture_by_value = capture_by_value 2450 self.tracing_count = 0 2451 # Maintein a dict of all captures: identifier -> lambda function. It's used 2452 # to get runtime values for all captures during ConcreteFunction dispatch, 2453 self._captures_container = func_graph_module.CapturesContainer() 2454 self._lock = threading.RLock() 2455 # _descriptor_cache is a of instance of a class to an instance-specific 2456 # `Function`, used to make sure defun-decorated methods create different 2457 # functions for each instance. 2458 self._descriptor_cache = weakref.WeakKeyDictionary() 2459 self._jit_compile = jit_compile 2460 self._experimental_follow_type_hints = experimental_follow_type_hints 2461 2462 def __call__(self, *args, **kwargs): 2463 """Calls a graph function specialized to the inputs.""" 2464 with self._lock: 2465 (graph_function, 2466 filtered_flat_args) = self._maybe_define_function(args, kwargs) 2467 return graph_function._call_flat( 2468 filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access 2469 2470 @property 2471 def python_function(self): 2472 """Returns the wrapped Python function.""" 2473 return self._python_function # pylint: disable=protected-access 2474 2475 @property 2476 def function_spec(self): 2477 return self._function_spec 2478 2479 @property 2480 def input_signature(self): 2481 """Returns the input signature.""" 2482 return self._function_spec.input_signature 2483 2484 def _maybe_define_concrete_function(self, args, kwargs): 2485 if self.input_signature and not args and not kwargs: 2486 # TODO(b/215596825): Throw error here if multiple entries are defined. 2487 args = self.input_signature 2488 kwargs = {} 2489 2490 return self._maybe_define_function(args, kwargs) 2491 2492 def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs): 2493 """Returns a concrete function which cleans up its graph function.""" 2494 with self._lock: 2495 graph_function, _ = self._maybe_define_concrete_function(args, kwargs) 2496 return graph_function 2497 2498 def _get_concrete_function_internal(self, *args, **kwargs): 2499 """Bypasses error checking when getting a graph function.""" 2500 graph_function = self._get_concrete_function_internal_garbage_collected( 2501 *args, **kwargs) 2502 # We're returning this concrete function to someone, and they may keep a 2503 # reference to the FuncGraph without keeping a reference to the 2504 # ConcreteFunction object. So we won't clean up the reference cycles 2505 # manually and instead will leave them to Python's garbage collector. 2506 graph_function._garbage_collector.release() # pylint: disable=protected-access 2507 return graph_function 2508 2509 def _get_concrete_function_garbage_collected(self, *args, **kwargs): 2510 """Returns a `ConcreteFunction` specialized to inputs and execution context. 2511 2512 Unlike `get_concrete_function(...)`, the graph will be deleted when the 2513 returned function is deleted. It's useful to avoid creating a reference 2514 cycle when you know for sure that the graph will be no longer used without 2515 the returned function. 2516 2517 Args: 2518 *args: inputs to specialize on. 2519 **kwargs: inputs to specialize on. 2520 """ 2521 if self.input_signature: 2522 self._function_spec.validate_inputs_with_signature(args, kwargs) 2523 2524 with self._lock: 2525 graph_function, _ = self._maybe_define_concrete_function(args, kwargs) 2526 seen_names = set() 2527 captured = object_identity.ObjectIdentitySet( 2528 graph_function.graph.internal_captures) 2529 # pylint: disable=protected-access 2530 graph_function._arg_keywords = [] 2531 prefix_counts = {} 2532 # pylint: enable=protected-access 2533 num_positional = 0 2534 for arg in graph_function.graph.inputs: 2535 if arg in captured: 2536 break 2537 num_positional += 1 2538 user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name")) 2539 proposal = user_arg_name 2540 while proposal in seen_names: 2541 index = prefix_counts.get(user_arg_name, 1) 2542 proposal = "{}_{}".format(user_arg_name, index) 2543 prefix_counts[user_arg_name] = index + 1 2544 seen_names.add(proposal) 2545 graph_function._arg_keywords.append(proposal) # pylint: disable=protected-access 2546 # Anything can be a positional argument, in the same order as .inputs 2547 graph_function._num_positional_args = num_positional # pylint: disable=protected-access 2548 return graph_function 2549 2550 def get_concrete_function(self, *args, **kwargs): 2551 """Returns a `ConcreteFunction` specialized to inputs and execution context. 2552 2553 Args: 2554 *args: inputs to specialize on. Can be concrete values (e.g. 1) or 2555 `tf.Tensor` or `tf.TensorSpec`. 2556 **kwargs: keyword inputs to specialize on. Concrete values (e.g. 1) or 2557 `tf.Tensor` or `tf.TensorSpec`. 2558 """ 2559 graph_function = self._get_concrete_function_garbage_collected( 2560 *args, **kwargs) 2561 graph_function._garbage_collector.release() # pylint: disable=protected-access 2562 return graph_function 2563 2564 def _list_all_concrete_functions(self) -> List[ConcreteFunction]: 2565 return self._function_cache.values() 2566 2567 def __get__(self, instance, owner): 2568 """Makes it possible to defun instance methods.""" 2569 del owner 2570 # `instance` here is the instance that this `Function` was accessed through 2571 # e.g., for 2572 # 2573 # class Foo: 2574 # 2575 # @function.defun 2576 # def bar(self): 2577 # ... 2578 # 2579 # foo = Foo() 2580 # foo.bar() # `foo.bar` is a `Function` instance 2581 # 2582 # then `instance` will be `foo` (and `owner` will be `Foo`). We create a 2583 # new instance of `Function` here to allow different instances each 2584 # to create variables once, thereby allowing methods to be decorated with 2585 # defun. Keeps a cache to avoid retracing the function every time the 2586 # descriptor is accessed. 2587 if instance not in self._descriptor_cache: 2588 if instance is None: 2589 return self 2590 # If there is no instance-specific `Function` in the cache, we construct 2591 # an instance-specific `Function` that uses a weak reference to the 2592 # instance (so that the instance will be correctly gc'd). 2593 2594 # And finally add the wrapped function to the description cache 2595 self._descriptor_cache[instance] = class_method_to_instance_method( 2596 self, instance) 2597 2598 # Return the cached `Function` for the instance 2599 return self._descriptor_cache[instance] 2600 2601 def _create_graph_function(self, args, kwargs): 2602 """Create a `ConcreteFunction` from `args` and `kwargs`.""" 2603 self.tracing_count += 1 2604 2605 arglen = len(args) 2606 base_arg_names = self._function_spec.arg_names[:arglen] 2607 num_missing_args = arglen - len(self._function_spec.arg_names) 2608 missing_arg_names = [self._function_spec.vararg_name] * num_missing_args 2609 # Produce a list of missing args of the form ["arg_0", "arg_1", ...], 2610 # where arg is based on the self._function_spec.vararg_name. 2611 missing_arg_names = [ 2612 "%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names) 2613 ] 2614 arg_names = base_arg_names + missing_arg_names 2615 graph_function = ConcreteFunction( 2616 func_graph_module.func_graph_from_py_func( 2617 self._name, 2618 self._python_function, 2619 args, 2620 kwargs, 2621 None, 2622 autograph=self._autograph, 2623 autograph_options=self._autograph_options, 2624 arg_names=arg_names, 2625 capture_by_value=self._capture_by_value), 2626 self._function_attributes, 2627 spec=self.function_spec, 2628 # Tell the ConcreteFunction to clean up its graph once it goes out of 2629 # scope. This is not the default behavior since it gets used in some 2630 # places (like Keras) where the FuncGraph lives longer than the 2631 # ConcreteFunction. 2632 shared_func_graph=False) 2633 return graph_function 2634 2635 def _maybe_define_function(self, args, kwargs): 2636 """Gets a function for these inputs, defining it if necessary. 2637 2638 Caller must hold self._lock. 2639 2640 Args: 2641 args: The varargs for the Python function. 2642 kwargs: The keyword args for the Python function. 2643 2644 Returns: 2645 A graph function corresponding to the input signature implied by args and 2646 kwargs, as well as filtered flattened inputs (only Tensors and Variables) 2647 that the object should be called with. 2648 2649 Raises: 2650 ValueError: If inputs are incompatible with the input signature. 2651 TypeError: If the function inputs include non-hashable objects 2652 RuntimeError: If there's an internal bug (inconsistency) in handling 2653 shape relaxation retracing. 2654 """ 2655 args, kwargs, filtered_flat_args = ( 2656 self._function_spec.canonicalize_function_inputs(args, kwargs)) 2657 2658 if self.input_signature is not None: 2659 args = self.input_signature 2660 kwargs = {} 2661 2662 # Get runtime values of captures 2663 captures = self._captures_container.get_snapshot() 2664 2665 # cache_key_deletion_observer is useless here. It's based on all captures. 2666 # A new cache key will be built later when saving ConcreteFunction because 2667 # only active captures should be saved. 2668 lookup_func_key, _ = function_context.make_cache_key((args, kwargs), 2669 captures) 2670 graph_function = self._function_cache.lookup(lookup_func_key, True) 2671 if graph_function is not None: 2672 return graph_function, filtered_flat_args 2673 2674 with monitoring.MonitoredTimer(_graph_building_time_counter.get_cell()): 2675 with trace.Trace("tf.function-graph_building"): 2676 logging.vlog(1, 2677 "Creating new FuncGraph for Python function %r (key: %r)", 2678 self._python_function, lookup_func_key) 2679 logging.vlog(2, "Python function signature [args: %s] [kwargs: %s]", 2680 args, kwargs) 2681 ag_status = ( 2682 ag_ctx.Status.ENABLED 2683 if self._autograph else ag_ctx.Status.DISABLED) 2684 with ag_ctx.ControlStatusCtx( 2685 status=ag_status, options=self._autograph_options): 2686 if self.input_signature is None and self._reduce_retracing: 2687 generalized_func_key = self._function_cache.generalize( 2688 lookup_func_key) 2689 # Only get placeholders for arguments, not captures 2690 args, kwargs = generalized_func_key._placeholder_value()["args"] # pylint: disable=protected-access 2691 2692 graph_function = self._create_graph_function(args, kwargs) 2693 2694 graph_capture_container = graph_function.graph._capture_func_lib # pylint: disable=protected-access 2695 # Maintain the list of all captures 2696 self._captures_container.update(graph_capture_container) 2697 # Get current active captures snapshot 2698 captures = graph_capture_container.get_snapshot() 2699 2700 # Create a cache_key with args and captures 2701 traced_func_key, traced_func_deletion_observer = ( 2702 function_context.make_cache_key((args, kwargs), captures)) 2703 2704 self._function_cache.add(traced_func_key, 2705 traced_func_deletion_observer, 2706 graph_function) 2707 2708 return graph_function, filtered_flat_args 2709 2710 2711def register(func, *args, **kwargs): 2712 """Register a specialization of a `Function` into the graph. 2713 2714 This won't actually call the function with the inputs, and only put the 2715 function definition into graph. Register function with different input param 2716 will result into multiple version of functions registered in graph. 2717 2718 Args: 2719 func: the `Function` instance that generated by a @defun 2720 *args: input arguments for the Python function. 2721 **kwargs: input keyword arguments for the Python function. 2722 2723 Returns: 2724 a `ConcreteFunction` object specialized to inputs and execution context. 2725 2726 Raises: 2727 ValueError: When the input function is not a defun wrapped python function. 2728 """ 2729 if not isinstance(func, Function): 2730 raise ValueError("Only defun function is allowed to be registered. " 2731 f"Got {func} with type {type(func)}.") 2732 concrete_func = func.get_concrete_function(*args, **kwargs) 2733 concrete_func.add_to_graph() 2734 concrete_func.add_gradient_functions_to_graph() 2735 return concrete_func 2736 2737 2738def defun(func=None, 2739 input_signature=None, 2740 autograph=True, 2741 experimental_autograph_options=None, 2742 reduce_retracing=False): 2743 """Compiles a Python function into a callable TensorFlow graph. 2744 2745 `defun` (short for "define function") compiles a Python function 2746 composed of TensorFlow operations into a callable that executes a `tf.Graph` 2747 containing those operations. The callable produced by `defun` contains only 2748 the subgraph of TensorFlow operations that were executed when the Python 2749 function was called with a particular input signature, defined as a list 2750 of the shapes and dtypes of the Python function's Tensor-valued arguments and 2751 the values of its non-Tensor Python objects. 2752 2753 When eager execution is enabled, the ability to create graphs from Python 2754 functions makes it possible to incrementally trade off debuggability and 2755 interactivity for performance. Functions compiled with `defun` cannot be 2756 inspected with `pdb`; however, executing a graph 2757 generated by `defun` sometimes takes less time and memory than eagerly 2758 executing the corresponding Python function, since specifying computations as 2759 graphs allows for optimizations like automatic buffer reuse and 2760 parallelization among ops. Note that executing a `defun`-compiled function 2761 incurs a small constant overhead, so eagerly executing sufficiently small 2762 Python functions might take less time than executing their corresponding 2763 `defun`-generated graphs. 2764 2765 For a Python function to be compatible with `defun`, all of its arguments must 2766 be hashable Python objects or lists thereof. The function itself may not 2767 modify the list/map structure of its arguments. Additionally, it must return 2768 zero or more `tf.Tensor` objects. If the Python function returns 2769 a `tf.Variable`, its compiled version will return the value of that variable 2770 as a `tf.Tensor`. 2771 2772 Executing a graph generated by `defun` respects device annotations (i.e., 2773 all `with tf.device` directives present in a Python function will also be 2774 present in its corresponding graph), but it is not yet possible to execute the 2775 generated graphs across multiple machines. 2776 2777 _Example Usage_ 2778 2779 ```python 2780 import tensorflow as tf 2781 2782 tf.compat.v1.enable_eager_execution() 2783 2784 # A simple example. 2785 def f(x, y): 2786 return tf.reduce_mean(tf.multiply(x ** 2, 3) + y) 2787 2788 g = tf.contrib.eager.defun(f) 2789 2790 x = tf.constant([[2.0, 3.0]]) 2791 y = tf.constant([[3.0, -2.0]]) 2792 2793 # `f` and `g` will return the same value, but `g` will be executed as a 2794 # TensorFlow graph. 2795 assert f(x, y).numpy() == g(x, y).numpy() 2796 2797 # `defun` is capable of compiling Python functions that close over Python 2798 # objects, including Tensors and Variables. 2799 @tf.contrib.eager.defun 2800 def h(): 2801 return f(x, y) 2802 2803 assert (h().numpy() == f(x, y).numpy()).all() 2804 2805 # `defun` automatically lifts variables out of the graphs it creates, 2806 # allowing you to compile the `call` methods of `tf.keras.layers.Layer` and 2807 # `tf.keras.Model` objects. 2808 class MyModel(tf.keras.Model): 2809 2810 def __init__(self, keep_probability=0.2): 2811 super().__init__() 2812 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 2813 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 2814 self.keep_probability = keep_probability 2815 2816 @tf.contrib.eager.defun 2817 def call(self, inputs, training=True): 2818 x = self.dense2(self.dense1(inputs)) 2819 if training: 2820 return tf.nn.dropout(x, self.keep_probability) 2821 else: 2822 return x 2823 2824 model = MyModel() 2825 model(x, training=True) # executes a graph, with dropout 2826 model(x, training=False) # executes a graph, without dropout 2827 2828 # `defun`-compiled functions are differentiable. 2829 optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.01) 2830 with tf.GradientTape() as tape: 2831 outputs = model(x) 2832 gradient = tape.gradient(outputs, model.trainable_variables) 2833 optimizer.apply_gradients((grad, var) for grad, var in zip(gradient, 2834 model.trainable_variables)) 2835 ``` 2836 2837 When using `defun`, there are subtleties regarding inputs, Python control 2838 flow, and variable creation that one should be aware of. For concreteness, let 2839 `f` be a Python function that returns zero or more `tf.Tensor` objects and 2840 let `F = defun(f)`. `F` builds a graph for each unique input signature it 2841 sees, Python control flow is baked into graphs, and operations related to 2842 variable initialization are automatically lifted out of the graphs that `F` 2843 generates and placed in the eager context if executing eagerly or into an 2844 outer graph otherwise. 2845 2846 _Input Signatures_ 2847 2848 By default, `F = tf.contrib.eager.defun(f)` instantiates a separate graph 2849 for every unique sequence of the shapes and dtypes of Tensor arguments and 2850 the values of Python objects it is invoked with. For example, calling 2851 `F(tf.random.uniform([2])` will execute a different graph than 2852 `F(tf.random.uniform([3])` because the two inputs have different shapes. 2853 The first time that `F(*args, **kwargs)` is called with a particular sequence 2854 of Tensor shapes and dtypes and Python values, it constructs a graph by 2855 tracing the execution of `f(*args, **kwargs)`; this graph is bound to an 2856 input signature inferred from `(*args, **kwargs)` and cached for future reuse. 2857 2858 NumPy arrays passed as inputs to `F` are converted to `tf.Tensor` objects 2859 before being passed to `f`, and are treated as Tensors for caching. This 2860 allows a function to be called multiple times with NumPy arrays having 2861 different values but the same shape and dtype without re-tracing each time. 2862 2863 `tf.contrib.eager.defun` caches graphs for your convenience, letting you 2864 define TensorFlow functions without explicitly specifying their signatures. 2865 However, this policy is conservative and potentially expensive; for example, 2866 when different invocations of your function have differently-shaped Tensor 2867 inputs, this policy might generate more graph functions than necessary. To 2868 eliminate such costs, `tf.contrib.eager.defun` allows you to supply an 2869 optional `input_signature` argument specifying the shapes and dtypes of the 2870 inputs. In particular, the shapes may be partially unspecified, with `None`s 2871 in the unknown dimensions. When an input signature is provided, 2872 `tf.contrib.eager.defun` will only instantiate a single graph for the 2873 decorated Python function. The following is an example: 2874 2875 ```python 2876 import tensorflow as tf 2877 2878 # The first `TensorSpec` below describes the shape and dtype of `words`, 2879 # and the second describes the shape and dtype of `another_tensor`. Note that 2880 # the last dimension of the `words` `TensorSpec` is left unspecified. 2881 @tf.contrib.eager.defun(input_signature=[ 2882 tf.contrib.eager.TensorSpec(shape=[50, 300, None], dtype=tf.float32), 2883 tf.contrib.eager.TensorSpec(shape=[300, 100], dtype=tf.float32) 2884 ]) 2885 def my_sequence_model(words, another_tensor): 2886 ... 2887 2888 # Note how the third dimension of the first input can vary freely. 2889 words = tf.random.uniform(([50, 300, 10]) 2890 second_input = tf.random.uniform([300, 100]) 2891 my_sequence_model(words, second_input) 2892 2893 words = tf.random.uniform(([50, 300, 20]) 2894 my_sequence_model(words, second_input) 2895 2896 # Passing an input with an incompatible shape will raise an error. 2897 words = tf.random.uniform(([50, 100, 20]) 2898 my_sequence_model(words, second_input) # <---- This will raise an error. 2899 2900 ``` 2901 2902 Python functions that are compiled with an `input_signature` must only accept 2903 Tensors as arguments and must not take unnamed keyword arguments (**kwargs). 2904 2905 _Tracing_ 2906 2907 Be aware that because `F` only logs TensorFlow operations, all the other 2908 Python code that `f` executes will only shape the _construction_ of the graphs 2909 that `F` executes: the Python code won't be executed when the graphs 2910 themselves are executed, though it will be executed every time the Python 2911 function is traced (and a given Python function might be traced multiple 2912 times, once for each input signature it is invoked with). For example, whereas 2913 the Python function 2914 2915 ```python 2916 import tensorflow as tf 2917 import numpy as np 2918 2919 tf.compat.v1.enable_eager_execution() 2920 2921 def add_noise(): 2922 return tf.eye(5) + np.random.randn(5, 5) 2923 ``` 2924 2925 will return a different output everytime it is invoked, the compiled function 2926 `compiled = tf.contrib.eager.defun(add_noise)` will return the same value 2927 every time it is called, since a particular random offset generated by NumPy 2928 will be inserted into the graph as a TensorFlow constant. The solution is to 2929 replace the call to `np.random.randn` with `tf.random.normal((5, 5))`. 2930 2931 _Python Side-Effects_ 2932 2933 A corollary of the previous discussion on tracing is the following: If a 2934 Python function `f` has Python side-effects, then executing `f` multiple times 2935 will not necessarily be semantically equivalent to executing `F = 2936 tf.contrib.eager.defun(f)` multiple times; this difference is due to the fact 2937 that `defun` only captures the subgraph of TensorFlow operations that is 2938 constructed when `f` is called in a graph-building context. 2939 2940 _Python Control Flow_ 2941 2942 The structure of many machine learning computations depend upon whether one is 2943 training or validating, and it is common to nest specialized logic under `if 2944 training:` blocks. By mapping each input signature to a unique graph, `defun` 2945 lets users transparently compile such code, as the following code snippet 2946 demonstrates: 2947 2948 ```python 2949 import tensorflow as tf 2950 2951 tf.compat.v1.enable_eager_execution() 2952 2953 @tf.contrib.eager.defun 2954 def lossy_matmul(W, x, training=True): 2955 outputs = tf.matmul(W, x) 2956 if training: 2957 outputs = tf.nn.dropout(outputs, keep_probability=0.2) 2958 return outputs 2959 2960 W = tf.random.normal((3, 5)) 2961 x = tf.random.normal((5, 1)) 2962 2963 # Executes a graph that applies dropout. 2964 lossy_outputs = lossy_matmul(W, x, training=True) 2965 2966 # Executes a graph that does not apply dropout. 2967 exact_outputs = lossy_matmul(W, x, training=False) 2968 ``` 2969 2970 _TensorFlow Control Flow_ 2971 2972 When `autograph` is `True`, data-dependent control flow is allowed as well. 2973 Control flow statements that depend on `Tensor` values are staged into 2974 corresponding TensorFlow ops. For example, the following code will work as 2975 expected: 2976 2977 ```python 2978 @tf.contrib.eager.defun 2979 def dynamic_rnn_loop(cell, seq): 2980 state, output = cell.zero_state() 2981 for input in seq: 2982 state, output = cell(input, state) 2983 return output 2984 ``` 2985 2986 For more information see `tf.autograph`. 2987 2988 _Variables_ 2989 2990 TensorFlow operations related to variable creation and initialization are 2991 automatically lifted out of the graphs generated by `defun`. In practice, this 2992 implies that variable creation and initialization only happen the first time 2993 `F` is called, and that variables are reused every time thereafter. Many 2994 TensorFlow APIs, like `tf.keras.layers.Layer` objects, create variables the 2995 first time they are called and reuse them thereafter. Automatic variable 2996 lifting makes it possible to compile these APIs without extra effort, at the 2997 cost of introducing a discrepancy between the semantics of executing Python 2998 functions and their corresponding compiled functions. For example: 2999 3000 ```python 3001 import tensorflow as tf 3002 3003 tf.compat.v1.enable_eager_execution() 3004 3005 def fn(): 3006 x = tf.Variable(0.0) 3007 x.assign_add(1.0) 3008 return x.read_value() 3009 3010 # `fn` is a Python function, so x is created, initialized, and destroyed upon 3011 # every invocation 3012 assert fn().numpy() == fn().numpy() == 1.0 3013 3014 compiled = tf.contrib.eager.defun(fn) 3015 3016 # Compiling `fn` with `defun` hoists all variables outside of the generated 3017 # graph, so initialization happens exactly once. 3018 assert compiled().numpy() == 1.0 3019 assert compiled().numpy() == 2.0 3020 ``` 3021 3022 Finally, because each input signature is bound to a unique graph, if your 3023 Python function constructs `tf.Variable` objects, then each graph constructed 3024 for that Python function will reference a unique set of variables. To 3025 circumvent this problem, we recommend against compiling Python functions that 3026 create `tf.Variable` objects. Instead, Python functions should either 3027 lexically close over `tf.Variable` objects or accept them as arguments, 3028 preferably encapsulated in an object-oriented container. If you must create 3029 variables inside your Python function and you want each graph generated for it 3030 to reference the same set of variables, add logic to your Python function that 3031 ensures that variables are only created the first time it is called and are 3032 reused for every subsequent invocation; note that this is precisely what 3033 `tf.keras.layers.Layer` objects do, so we recommend using them to represent 3034 variable-bearing computations whenever possible. 3035 3036 Args: 3037 func: function to be compiled. If `func` is None, returns a 3038 decorator that can be invoked with a single argument - `func`. The 3039 end result is equivalent to providing all the arguments up front. 3040 In other words, defun(input_signature=...)(func) is equivalent to 3041 defun(func, input_signature=...). The former allows 3042 the following use case: 3043 @tf.contrib.eager.defun(input_signature=...) 3044 def foo(...): 3045 ... 3046 3047 input_signature: A possibly nested sequence of 3048 `tf.contrib.eager.TensorSpec` objects specifying the shapes and dtypes of 3049 the Tensors that will be supplied to this function. If `None`, a separate 3050 function is instantiated for each inferred input signature. If a 3051 signature is specified, every input to `func` must be a `Tensor`, and 3052 `func` cannot accept `**kwargs`. 3053 autograph: Whether `func` should be compiled before 3054 constructing the graph. See https://www.tensorflow.org/guide/autograph 3055 for more information. 3056 experimental_autograph_options: Experimental knobs (in the form of a tuple 3057 of tensorflow.autograph.Feature values) to control behavior when 3058 autograph=True. 3059 reduce_retracing: When True, `tf.function` uses 3060 `tf.types.experimental.TraceType` to trace supertypes of arguments to 3061 reduce the number of traces. 3062 3063 Returns: 3064 If `func` is not None, returns a callable that will execute the compiled 3065 function (and return zero or more `tf.Tensor` objects). 3066 If `func` is None, returns a decorator that, when invoked with a single 3067 `func` argument, returns a callable equivalent to the case above. 3068 3069 Raises: 3070 TypeError: If `input_signature` is neither `None` nor a sequence of 3071 `tf.contrib.eager.TensorSpec` objects. 3072 """ 3073 return defun_with_attributes( 3074 func=func, 3075 input_signature=input_signature, 3076 autograph=autograph, 3077 experimental_autograph_options=experimental_autograph_options, 3078 reduce_retracing=reduce_retracing) 3079 3080 3081@tf_export("__internal__.function.defun_with_attributes", v1=[]) 3082def defun_with_attributes(func=None, 3083 input_signature=None, 3084 attributes=None, 3085 autograph=True, 3086 experimental_autograph_options=None, 3087 jit_compile=None, 3088 reduce_retracing=False, 3089 experimental_follow_type_hints=False): 3090 """Compiles a Python function into a callable TensorFlow graph. 3091 3092 This function supports adding extra function attributes. See detailed 3093 documentation in defun(). Currently this is not exposed in public API since we 3094 don't expect user to directly use attributes, and attribute won't work by 3095 itself. This assumption might change in future. 3096 3097 Args: 3098 func: function to be compiled. 3099 input_signature: same as defun()'s input_signature. 3100 attributes: A dictionary of arguments which will be added to function def as 3101 attributes. Currently only support primitive types as value, and only 3102 allowlisted attribute name is allowed. Unallowlisted attribute name or 3103 unsupported value will result into ValueError. `func_name` is also one of 3104 the allowlisted argument which is a python string, and sets the name for 3105 this `ConcreteFunction` in the graph. 3106 autograph: same as defun()'s autograph. 3107 experimental_autograph_options: same as defun()'s 3108 experimental_autograph_options. 3109 jit_compile: same as defun()'s jit_compile. 3110 reduce_retracing: same as defun()'s reduce_retracing 3111 experimental_follow_type_hints: see `tf.function`. 3112 3113 Returns: 3114 Same as the return value of defun, with attributes added to the function in 3115 graph. 3116 """ 3117 3118 # TODO(apassos): deal with captured global state. Deal with control flow. 3119 def decorated(function): 3120 try: 3121 if attributes: 3122 name = attributes.pop("func_name", function.__name__) 3123 else: 3124 name = function.__name__ 3125 except AttributeError: 3126 name = "function" 3127 return tf_decorator.make_decorator( 3128 function, 3129 Function( 3130 function, 3131 name, 3132 input_signature=input_signature, 3133 attributes=attributes, 3134 autograph=autograph, 3135 autograph_options=experimental_autograph_options, 3136 jit_compile=jit_compile, 3137 reduce_retracing=reduce_retracing, 3138 experimental_follow_type_hints=experimental_follow_type_hints)) 3139 3140 # This code path is for the `foo = tfe.defun(foo, ...)` use case 3141 if func is not None: 3142 return decorated(func) 3143 3144 # This code path is for the 3145 # 3146 # @tfe.defun(...) 3147 # def foo(...): 3148 # ... 3149 # 3150 # use case, which is equivalent to `foo = tfe.defun(...)(foo)` 3151 return decorated 3152 3153 3154# When a method is bound to objects of this type, it allows AutoGraph to 3155# recover a weak reference the original method's self pointer, so that it can 3156# execute it consistent with class_method_to_instance_method's 3157# bound_method_wrapper. 3158# TODO(b/119246461): This is not pretty. Use a descriptor instead? 3159class TfMethodTarget: 3160 """Binding target for methods replaced by function and defun.""" 3161 3162 __slots__ = ("weakrefself_target__", "weakrefself_func__") 3163 3164 def __init__(self, target, original_python_function): 3165 self.weakrefself_target__ = target 3166 self.weakrefself_func__ = weakref.ref(original_python_function) 3167 3168 @property 3169 def target(self): 3170 return self.weakrefself_target__() 3171 3172 @property 3173 def target_class(self): 3174 true_self = self.weakrefself_target__() 3175 if tf_inspect.isclass(true_self): 3176 # Class method 3177 return true_self 3178 else: 3179 return true_self.__class__ 3180 3181 def call(self, args, kwargs): 3182 wrapped_fn = self.weakrefself_func__() 3183 return wrapped_fn(self.weakrefself_target__(), *args, **kwargs) 3184 3185 3186def class_method_to_instance_method(original_function, instance): 3187 """Constructs a new `Function` with `self` bound.""" 3188 weak_instance = weakref.ref(instance) 3189 3190 # Note: while we could bind to a weakref proxy instead, that causes the 3191 # bound method to be unhashable. 3192 bound_method = types_lib.MethodType( 3193 original_function.python_function, 3194 TfMethodTarget(weak_instance, original_function.python_function)) 3195 3196 # original_function is expected to be of one of the two `Function` types 3197 # (defined either in function.py or def_function.py). 3198 assert hasattr(original_function, "_name") 3199 assert hasattr(original_function, "_autograph") 3200 assert hasattr(original_function, "_function_spec") 3201 assert hasattr(original_function, "python_function") 3202 3203 weak_bound_method_wrapper = None 3204 def bound_method_wrapper(*args, **kwargs): 3205 """Wraps either a dummy MethodType or a converted AutoGraph function.""" 3206 # __wrapped__ allows AutoGraph to swap in a converted function. 3207 strong_bound_method_wrapper = weak_bound_method_wrapper() 3208 wrapped_fn = strong_bound_method_wrapper.__wrapped__ 3209 3210 if wrapped_fn is strong_bound_method_wrapper.__original_wrapped__: 3211 # If __wrapped__ was not replaced, then call original_function. 3212 # TODO(mdan): For better consistency, use the wrapper's call(). 3213 wrapped_fn = original_function.python_function 3214 return wrapped_fn(weak_instance(), *args, **kwargs) 3215 3216 # If __wrapped__ was replaced, then it is always an unbound function. 3217 # However, the replacer is still responsible for attaching self properly. 3218 # TODO(mdan): Is it possible to do it here instead? 3219 return wrapped_fn(*args, **kwargs) 3220 weak_bound_method_wrapper = weakref.ref(bound_method_wrapper) 3221 3222 # pylint: disable=protected-access 3223 # We make a dummy MethodType object to generate the correct bound method 3224 # signature. The actual call is to a function with a weak reference to 3225 # `instance`. 3226 instance_func = type(original_function)( 3227 tf_decorator.make_decorator(bound_method, bound_method_wrapper), 3228 name=original_function._name, 3229 autograph=original_function._autograph, 3230 input_signature=original_function.input_signature, 3231 reduce_retracing=original_function._reduce_retracing, 3232 jit_compile=original_function._jit_compile) 3233 # pylint: enable=protected-access 3234 3235 # We wrap the bound method with tf_decorator so inspection works correctly 3236 wrapped_instance_func = tf_decorator.make_decorator(bound_method, 3237 instance_func) 3238 return wrapped_instance_func 3239 3240 3241class ConcreteFunctionGarbageCollector: 3242 """Cleans up reference cycles when a `ConcreteFunction` goes out of scope.""" 3243 3244 __slots__ = ["_func_graph"] 3245 3246 def __init__(self, func_graph): 3247 self._func_graph = func_graph 3248 3249 def release(self): 3250 """Call off the FuncGraph deletion.""" 3251 self._func_graph = None 3252 3253 def __del__(self): 3254 if func_graph_module is None or memory is None or self._func_graph is None: 3255 return 3256 try: 3257 func_graph_module.dismantle_func_graph(self._func_graph) 3258 except: # pylint: disable=bare-except 3259 pass 3260 3261 3262class _Marker(object): 3263 """Markers used to pretty-print nested args in function signatures.""" 3264 3265 __slots__ = ["_s"] 3266 3267 def __init__(self, s): 3268 self._s = s 3269 3270 def __repr__(self): 3271 return str(self._s) 3272 3273 3274def _structure_summary(structure): 3275 """Displays a summary of the nesting structure of the given value.""" 3276 3277 def type_name(x): 3278 if isinstance(x, type_spec.TypeSpec): 3279 return x.value_type.__name__ 3280 else: 3281 return type(x).__name__ 3282 3283 markers = [_Marker(type_name(v)) for v in nest.flatten(structure)] 3284 return str(nest.pack_sequence_as(structure, markers)) 3285 3286 3287def _contains_type_spec(value): 3288 return any(isinstance(x, type_spec.TypeSpec) for x in nest.flatten(value)) 3289