# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements the graph generation for computation of gradients."""

import collections
import contextlib

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import composite_tensor_gradient
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_state
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import object_identity
from tensorflow.python.util import variable_utils
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


def _MarkReachedOps(from_ops, reached_ops, func_graphs):
  """Mark all ops reached from "from_ops".

  Args:
    from_ops: list of Operations.
    reached_ops: set of Operations.
    func_graphs: list of FuncGraphs. This method will traverse through
      these functions if they capture from_ops or any reachable ops.
  """
  queue = collections.deque()
  queue.extend(from_ops)
  while queue:
    op = queue.popleft()
    if op not in reached_ops:
      reached_ops.add(op)
      for output in op.outputs:
        if backprop_util.IsTrainable(output):
          queue.extend(_Consumers(output, func_graphs))
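

# Editor's note: the sketch below is illustrative only and is not used
# elsewhere in this module. It shows the forward-reachability sweep that the
# pending-count machinery relies on: every op reachable from `from_ops` along
# trainable tensors ends up in the returned set. The helper name is
# hypothetical.
def _example_mark_reached(from_ops):
  """Returns the set of ops reachable forward from `from_ops`."""
  reached = set()
  _MarkReachedOps(from_ops, reached, func_graphs=[])
  return reached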


def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
                  xs_set):
  """Initialize the pending count for ops between two lists of Operations.

  'pending_count[op]' indicates the number of backprop inputs
  to this operation.

  Args:
    to_ops: list of Operations.
    from_ops: list of Operations.
    colocate_gradients_with_ops: Python bool. See docstring of gradients().
    func_graphs: list of FuncGraphs. This method will traverse through
      these functions if they capture from_ops or any reachable ops. This is
      useful if to_ops occur in a function and from_ops are in an outer
      function or graph.
    xs_set: ObjectIdentitySet of Tensors.

  Returns:
    A tuple containing: (1) the subset of to_ops reachable from from_ops by a
    path of zero or more backpropagatable tensors, (2) a mapping from operation
    to the number of backprop inputs to that op, and (3) a ControlFlowState
    object which is not None if the ops between from_ops and to_ops contain
    control flow loops.
  """
  # Mark reachable ops from from_ops.
  reached_ops = set()
  _MarkReachedOps(from_ops, reached_ops, func_graphs)
  # X in reached_ops iff X is reachable from from_ops by a path of zero or more
  # backpropagatable tensors.

  reachable_to_ops = set(op for op in to_ops if op in reached_ops)

  # Mark between ops.
  between_ops = set()
  between_op_list = []
  queue = collections.deque()
  queue.extend(to_ops)
  while queue:
    op = queue.popleft()
    # We are interested in this op.
    if op in reached_ops:
      between_ops.add(op)
      between_op_list.append(op)
      # Clear the boolean so we won't add the inputs again.
      reached_ops.remove(op)
      for inp in _NonEagerInputs(op, xs_set):
        queue.append(inp.op)
  # X in between_ops iff X is on a path of zero or more backpropagatable
  # tensors between from_ops and to_ops.

  # 'loop_state' is None if there are no while loops.
  loop_state = control_flow_state.MaybeCreateControlFlowState(
      between_op_list, between_ops, colocate_gradients_with_ops)

  # Initialize pending count for between ops.
  pending_count = collections.defaultdict(int)
  for op in between_op_list:
    for x in _NonEagerInputs(op, xs_set):
      if x.op in between_ops:
        pending_count[x.op] += 1

  return reachable_to_ops, pending_count, loop_state


def _AsList(x):
  return x if isinstance(x, (list, tuple)) else [x]


def _DefaultGradYs(grad_ys,
                   ys,
                   colocate_gradients_with_ops,
                   gradient_uid="__unsupported__"):
  """Fill in default values for grad_ys.

  Args:
    grad_ys: List of gradients, can contain None.
    ys: List of tensors.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    gradient_uid: A unique identifier within the graph indicating
      which invocation of gradients is being executed. Used to cluster
      ops for compilation.

  Returns:
    A list of gradients to use, without None.

  Raises:
    ValueError: If sizes of gradients and inputs don't match.
    TypeError: If type of any gradient is not valid for its input.
  """
  if len(grad_ys) != len(ys):
    raise ValueError(f"Length mismatch. Passed {len(grad_ys)} grad_ys for "
                     f"{len(ys)} ys")
  grad_ys = ops.convert_n_to_tensor_or_indexed_slices(grad_ys, name="grad_y")
  new_grad_ys = []
  for i, (y, grad_y) in enumerate(zip(ys, grad_ys)):
    with _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops):
      if grad_y is None:
        if y.dtype.is_complex:
          raise TypeError(
              f"Gradients of complex tensors ({y}) must set grad_ys (y.dtype = "
              f"{dtypes.as_dtype(y.dtype).name})")
        new_grad_ys.append(
            array_ops.ones(
                array_ops.shape(y), dtype=y.dtype, name="grad_ys_%d" % i))
        continue
      if y.dtype.is_floating or y.dtype.is_integer:
        if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
          raise TypeError(
              f"Gradient type {dtypes.as_dtype(grad_y.dtype).name} generated "
              f"for real or integer-valued tensor {y} with type "
              f"{dtypes.as_dtype(y.dtype).name} must be real or integer")
      elif y.dtype.is_complex:
        if not grad_y.dtype.is_complex:
          raise TypeError(
              f"Gradient type {dtypes.as_dtype(grad_y.dtype).name} generated "
              f"for complex-valued tensor {y} with type "
              f"{dtypes.as_dtype(y.dtype).name} must be complex")
      elif y.dtype == dtypes.variant:
        if grad_y.dtype != dtypes.variant:
          raise TypeError(
              f"Gradient type {dtypes.as_dtype(grad_y.dtype).name} generated "
              f"for variant tensor {y} with type "
              f"{dtypes.as_dtype(y.dtype).name} must be variant")
      elif y.dtype == dtypes.resource:
        # We assume y is the handle of a ResourceVariable. The gradient of a
        # ResourceVariable should be a numeric value, not another resource.
        if grad_y.dtype == dtypes.resource:
          raise TypeError(f"Input gradient {grad_y} for resource tensor {y} "
                          "should not be a resource")
      else:
        raise TypeError(
            f"Tensor {y} with type {dtypes.as_dtype(y.dtype).name} must be "
            "numeric to obtain a default gradient")
      # Create a grad_y tensor in the name scope of the gradient.
      # Required for TensorArrays to identify which gradient call a
      # grad_y value is coming from.
      if isinstance(grad_y, indexed_slices.IndexedSlices):
        new_grad_ys.append(
            indexed_slices.IndexedSlices(
                indices=(array_ops.identity(
                    grad_y.indices, name="grad_ys_%d_indices" % i)
                         if isinstance(grad_y.indices, ops.Tensor) else
                         grad_y.indices),
                values=(array_ops.identity(
                    grad_y.values, name="grad_ys_%d_values" % i) if isinstance(
                        grad_y.values, ops.Tensor) else grad_y.values),
                dense_shape=(array_ops.identity(
                    grad_y.dense_shape, name="grad_ys_%d_shape" % i)
                             if isinstance(grad_y.dense_shape, ops.Tensor) else
                             grad_y.dense_shape)))
      else:
        new_grad_ys.append(array_ops.identity(grad_y, name="grad_ys_%d" % i))

  return new_grad_ys
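

# Editor's note: illustrative sketch, not used elsewhere in this module. When
# a caller passes grad_ys=None, _DefaultGradYs seeds backprop with a ones
# tensor shaped like each y, so gradients(ys, xs) behaves as if grad_ys were
# ones_like(y). The helper name is hypothetical.
def _example_default_grad_ys(y):
  """Returns the default seed gradient _DefaultGradYs creates for `y`."""
  (seed_grad,) = _DefaultGradYs(
      [None], [y], colocate_gradients_with_ops=False)
  # For a real-valued `y`, `seed_grad` is a ones tensor with y's shape/dtype.
  return seed_grad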


def _VerifyGeneratedGradients(grads, op):
  """Verify that gradients are valid in number and type.

  Args:
    grads: List of generated gradients.
    op: Operation for which the gradients were generated.

  Raises:
    ValueError: if sizes of gradients and inputs don't match.
    TypeError: if type of any gradient is not valid for its input.
  """
  # While ops have inputs added to them during the gradient computation, so we
  # skip the below check. See while_v2 for details.
  if op.type == "While" or op.type == "StatelessWhile":
    return

  if len(grads) != len(op.inputs):
    raise ValueError(f"Num gradients {len(grads)} generated for op "
                     f"{op.node_def} do not match num inputs {len(op.inputs)}")


def _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set):
  """The set of ops that terminate the gradient computation.

  This computes the frontier of the forward graph *before* which backprop
  should stop. Operations in the returned set will not be differentiated.
  This set is defined as the subset of `from_ops` containing ops that have
  no predecessor in `from_ops`. `pending_count` is the result of
  `_PendingCount(xs, from_ops)`. An 'op' has predecessors in `from_ops`
  iff pending_count[op] > 0.

  In addition, none of `stop_gradient_ops` will be differentiated.

  Args:
    from_ops: list of Operations.
    stop_gradient_ops: list of Operations never to backprop through.
    pending_count: mapping from operation to number of backprop inputs.
    xs_set: ObjectIdentitySet of Tensors.

  Returns:
    The set of operations.
  """
  stop_ops = set()
  for op in from_ops:
    is_stop_op = True
    for inp in _NonEagerInputs(op, xs_set):
      if pending_count[inp.op] > 0:
        is_stop_op = False
        break
    if is_stop_op:
      stop_ops.add(op)
  stop_ops.update(op for op in stop_gradient_ops)
  return stop_ops


@contextlib.contextmanager
def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):  # pylint: disable=invalid-name
  """Context to colocate with `op` if `colocate_gradients_with_ops`."""
  if colocate_gradients_with_ops:
    with ops._colocate_with_for_gradient(op, gradient_uid):  # pylint: disable=protected-access
      yield
  else:
    yield


def _IsPartitionedCall(op):
  return op.type == "PartitionedCall" or op.type == "StatefulPartitionedCall"


def _SymGrad(op, out_grads):
  """Backprop through a function call node op given its outputs' gradients."""
  f_in = [x for x in op.inputs] + out_grads
  f_types = [default_gradient.get_zeros_dtype(x) for x in op.inputs]
  f = attr_value_pb2.NameAttrList()
  if _IsPartitionedCall(op):
    f.name = op.get_attr("f").name
  else:
    f.name = op.type
  for k in op.node_def.attr:
    f.attr[k].CopyFrom(op.node_def.attr[k])
  in_grads = functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f)
  return in_grads
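

# Editor's note: illustrative sketch, not used elsewhere in this module. It
# mirrors the fallback logic in _GradientsHelper below: prefer an op's
# registered Python gradient function, and fall back to a SymbolicGradient
# node (via _SymGrad) for function call ops that have none. The helper name
# is hypothetical.
def _example_grad_fn_or_symbolic_gradient(call_op, out_grads):
  """Builds input gradients for `call_op` from its output gradients."""
  try:
    grad_fn = ops.get_gradient_function(call_op)
  except LookupError:
    # No registered gradient: emit a SymbolicGradient node instead.
    return _AsList(_SymGrad(call_op, out_grads))
  return _AsList(grad_fn(call_op, *out_grads))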


def _MaybeCompile(scope, op, func, grad_fn):
  """Compile the calculation in grad_fn if op was marked as compiled."""
  scope = scope.rstrip("/").replace("/", "_")
  if func is not None:
    xla_compile = func.definition.attr["_XlaCompile"].b
    xla_separate_compiled_gradients = func.definition.attr[
        "_XlaSeparateCompiledGradients"].b
    xla_scope = func.definition.attr["_XlaScope"].s.decode()
  else:
    try:
      xla_compile = op.get_attr("_XlaCompile")
      xla_separate_compiled_gradients = op.get_attr(
          "_XlaSeparateCompiledGradients")
      xla_scope = op.get_attr("_XlaScope").decode()
    except ValueError:
      xla_compile = False

  if not xla_compile:
    return grad_fn()  # Exit early

  # If the gradients are supposed to be compiled separately, we give them a
  # _XlaScope name that is based on the name_scope of the gradients. Otherwise
  # they just inherit the existing _XlaScope name, which lets them be merged
  # together with the non-gradient computation.
  if xla_separate_compiled_gradients:
    xla_grad_scope = "%s_grad_%s" % (xla_scope, scope)
  else:
    xla_grad_scope = xla_scope

  attrs = {
      "_XlaCompile": attr_value_pb2.AttrValue(b=xla_compile),
      "_XlaScope": attr_value_pb2.AttrValue(s=xla_grad_scope.encode())
  }
  with ops.get_default_graph()._attr_scope(attrs):  # pylint: disable=protected-access
    return grad_fn()


def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set):
  """Raises an error if we backprop through a loop var."""
  # Find the nearest 'to_op' reachable from 'op' to provide a more helpful
  # error message.
  target_op = None
  queue = collections.deque([op])
  visited = set()
  while queue:
    curr_op = queue.popleft()
    if curr_op in visited: continue
    visited.add(curr_op)
    if curr_op in from_ops:
      target_op = curr_op
      break
    queue.extend(t.op for t in _NonEagerInputs(curr_op, xs_set))
  assert target_op
  raise ValueError(
      "Cannot compute gradient inside while loop with respect to op "
      f"'{target_op.name}'. We do not support taking the gradient wrt or "
      "through the initial value of a loop variable. Gradients can be computed "
      "through loop invariants or wrt the input parameters to the loop body.")


def _IsFunction(graph):
  return (isinstance(graph, FuncGraph) or
          isinstance(graph, framework_function._FuncGraph))  # pylint: disable=protected-access


def _Captures(func_graph):
  if isinstance(func_graph, FuncGraph):
    return func_graph.captures
  else:
    assert isinstance(func_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
    return func_graph.captures


def _MaybeCaptured(t):
  """If t is a captured value placeholder, returns the original captured value.

  Args:
    t: Tensor

  Returns:
    A tensor, potentially from a different Graph/FuncGraph.
  """
  # pylint: disable=protected-access
  if (not isinstance(t, ops.EagerTensor) and
      _IsFunction(t.op.graph) and t.op.type == "Placeholder"):
    for input_t, placeholder_t in _Captures(t.op.graph):
      if t is placeholder_t:
        return _MaybeCaptured(input_t)
  # pylint: enable=protected-access
  return t


def _NonEagerInputs(op, xs_set):
  """Returns the inputs of op, crossing closure boundaries where necessary.

  Does not return any captured EagerTensors, i.e., the number of tensors
  returned may be less than the actual number of inputs.

  Args:
    op: Operation
    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.

  Returns:
    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
    is in a FuncGraph and has captured inputs.
  """
  return [t for t in _Inputs(op, xs_set) if not isinstance(t, ops.EagerTensor)]


# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
# _GradientsHelper a class with xs as a member variable.
def _Inputs(op, xs_set):
  """Returns the inputs of op, crossing closure boundaries where necessary.

  Args:
    op: Operation
    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.

  Returns:
    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
    is in a FuncGraph and has captured inputs.
  """
  if _IsFunction(op.graph):  # pylint: disable=protected-access
    inputs = []
    for t in op.inputs:
      # If we're differentiating w.r.t.
      # `t`, do not attempt to traverse through it to a captured value.
      # The algorithm needs to "see" `t` in this case, even if it's a function
      # input for a captured value, whereas usually we'd like to traverse
      # through these closures as if the captured value was the direct input
      # to op.
      if t not in xs_set:
        t = _MaybeCaptured(t)
      inputs.append(t)
    return inputs
  else:
    return op.inputs


def _Consumers(t, func_graphs):
  """Returns the consumers of t, crossing closure boundaries where necessary.

  Args:
    t: Tensor
    func_graphs: a list of FuncGraphs that may have captured t.

  Returns:
    A list of tensors. The tensors will be from the current graph and/or
    func_graphs.
  """
  consumers = t.consumers()
  for func in func_graphs:
    for input_t, placeholder in _Captures(func):
      if input_t is t:
        consumers.extend(_Consumers(placeholder, func_graphs))
  return consumers
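

# Editor's note: illustrative sketch, not used elsewhere in this module. For
# every tensor captured by a FuncGraph, the internal placeholder resolves back
# to the external tensor it captured, which is what _MaybeCaptured recovers
# when the gradient traversal crosses a closure boundary. The helper name is
# hypothetical.
def _example_resolve_captures(func_graph):
  """Maps each internal capture placeholder to its external source tensor."""
  resolved = {}
  for _, internal_placeholder in _Captures(func_graph):
    # _MaybeCaptured follows the capture chain outward, possibly through
    # nested FuncGraphs, until it reaches a non-captured tensor.
    resolved[internal_placeholder.name] = _MaybeCaptured(internal_placeholder)
  return resolved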


def _GradientsHelper(ys,
                     xs,
                     grad_ys=None,
                     name="gradients",
                     colocate_gradients_with_ops=False,
                     gate_gradients=False,
                     aggregation_method=None,
                     stop_gradients=None,
                     unconnected_gradients=UnconnectedGradients.NONE,
                     src_graph=None):
  """Implementation of gradients()."""
  if context.executing_eagerly():
    raise RuntimeError("tf.gradients is not supported when eager execution "
                       "is enabled. Use tf.GradientTape instead.")
  ys = variable_utils.convert_variables_to_tensors(_AsList(ys))
  xs = [
      x.handle if resource_variable_ops.is_resource_variable(x) else x
      for x in _AsList(xs)
  ]
  if grad_ys is not None:
    grad_ys = _AsList(grad_ys)

  # Handle CompositeTensors.
  if (any(isinstance(x, composite_tensor.CompositeTensor) for x in xs) or
      any(isinstance(y, composite_tensor.CompositeTensor) for y in ys)):
    flat_xs = composite_tensor_gradient.get_flat_tensors_for_gradients(xs)
    flat_ys = composite_tensor_gradient.get_flat_tensors_for_gradients(ys)
    flat_grad_ys = (
        None if grad_ys is None else
        composite_tensor_gradient.get_flat_tensors_for_gradients(grad_ys))
    flat_grads = _GradientsHelper(flat_ys, flat_xs, flat_grad_ys, name,
                                  colocate_gradients_with_ops, gate_gradients,
                                  aggregation_method, stop_gradients,
                                  unconnected_gradients, src_graph)
    return composite_tensor_gradient.replace_flat_tensors_for_gradients(
        xs, flat_grads)

  if src_graph is None:
    src_graph = ops.get_default_graph()
  try:
    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
  except ValueError:
    raise ValueError(
        f"Unknown value for unconnected_gradients: '{unconnected_gradients}'")

  # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
  # ancestor graphs. This is necessary for correctly handling captured values.
  func_graphs = []
  curr_graph = src_graph
  while _IsFunction(curr_graph):
    func_graphs.append(curr_graph)
    if isinstance(curr_graph, FuncGraph):
      curr_graph = curr_graph.outer_graph
    else:
      assert isinstance(curr_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
      curr_graph = curr_graph._outer_graph  # pylint: disable=protected-access

  stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
  if grad_ys is None:
    grad_ys = [None] * len(ys)

  with ops.name_scope(
      name, "gradients",
      list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
    # Get a uid for this call to gradients that can be used to help
    # cluster ops for compilation.
    gradient_uid = ops.get_default_graph().unique_name("uid")
    ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
    xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
        xs, name="x", as_ref=True)
    xs_set = object_identity.ObjectIdentitySet(xs)
    grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
                             gradient_uid)

    # The approach we take here is as follows: Create a list of all ops in the
    # subgraph between the ys and xs. Visit these ops in reverse order of ids
    # to ensure that when we visit an op the gradients w.r.t. its outputs have
    # been collected. Then aggregate these gradients if needed, call the op's
    # gradient function, and add the generated gradients to the gradients for
    # its input.

    # Initialize the pending count for ops in the connected subgraph from ys
    # to the xs.
    to_ops = [t.op for t in ys]
    from_ops = [t.op for t in xs]
    stop_gradient_ops = [t.op for t in stop_gradients]
    reachable_to_ops, pending_count, loop_state = _PendingCount(
        to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs_set)

    # Iterate over the collected ops.
    #
    # grads: op => list of gradients received on each output endpoint of the
    # op. The gradients for each endpoint are initially collected as a list.
    # When it is time to call the op's gradient function, for each endpoint we
    # aggregate the list of received gradients into an Add() Operation if there
    # is more than one.
    grads = {}

    # Add the initial gradients for the ys.
    for y, grad_y in zip(ys, grad_ys):
      _SetGrad(grads, y, grad_y)

    # Initialize queue with to_ops.
    queue = collections.deque()
    # Add the ops in 'to_ops' into the queue.
    to_ops_set = set()
    for op in to_ops:
      # 'ready' handles the case where one output gradient relies on
      # another output's gradient.
      ready = (pending_count[op] == 0)
      if ready and op not in to_ops_set and op in reachable_to_ops:
        to_ops_set.add(op)
        queue.append(op)

    if loop_state:
      loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
      for y in loop_exits:
        if backprop_util.IsTrainable(y):
          _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
          queue.append(y.op)

    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set)
    while queue:
      # generate gradient subgraph for op.
      op = queue.popleft()
      with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=True)
        out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state,
                                     aggregation_method)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=True)

        grad_fn = None
        func_call = None
        is_partitioned_call = _IsPartitionedCall(op)
        # pylint: disable=protected-access
        is_func_call = (
            src_graph._is_function(op.type) or is_partitioned_call)
        # pylint: enable=protected-access
        has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
        if has_out_grads and (op not in stop_ops):
          try:
            grad_fn = ops.get_gradient_function(op)
          except LookupError:
            if is_func_call:
              if is_partitioned_call:
                func_name = compat.as_bytes(op.get_attr("f").name)
                func_call = src_graph._get_function(  # pylint: disable=protected-access
                    func_name)
                # When a graph is imported, the FunctionDefs are not copied over
                # to each sub-graph so we recursively search the outer graphs
                # for the FunctionDef.
                if not func_call and hasattr(src_graph, "outer_graph"):
                  graph = src_graph.outer_graph
                  while graph is not None:
                    func_call = graph._get_function(func_name)  # pylint: disable=protected-access
                    if func_call is not None:
                      break
                    if hasattr(graph, "outer_graph"):
                      graph = graph.outer_graph
                    else:
                      break
              else:
                func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
              # Note that __defun is not set if the graph is
              # imported. If it's set, we prefer to access the original
              # defun.
              func_call = getattr(op, "__defun", func_call)
              grad_fn = func_call.python_grad_func
            else:
              raise LookupError(
                  "No gradient defined for operation "
                  f"'{op.name}' (op type: {op.type}). "
                  "In general every operation must have an associated "
                  "`@tf.RegisterGradient` for correct autodiff, which this "
                  "op is lacking. If you want to pretend this "
                  "operation is a constant in your program, you may insert "
                  "`tf.stop_gradient`. This can be useful to silence the "
                  "error in cases where you know gradients are not needed, "
                  "e.g. the forward pass of tf.custom_gradient. "
                  "Please see more details in "
                  "https://www.tensorflow.org/api_docs/python/tf/custom_gradient.")  # pylint: disable=line-too-long
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=False)

        # NOTE(skyewm): We don't support computing gradients wrt a loop
        # variable unless it's within the context of a single iteration (i.e.
        # the gradient is wrt the loop parameter in the body function, not wrt
        # or through the initial value). This means if we're in a while loop
        # context, we should never see a switch node from this context.
        # pylint: disable=protected-access
        if (control_flow_util.IsSwitch(op) and
            op._control_flow_context is not None and
            op._control_flow_context.IsWhileContext() and
            op._control_flow_context ==
            ops.get_default_graph()._get_control_flow_context()):
          _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set)
        # pylint: enable=protected-access

        if (grad_fn or is_func_call) and has_out_grads:
          # NOTE: If _AggregatedGrads didn't compute a value for the i'th
          # output, it means that the cost does not depend on output[i],
          # therefore dC/doutput[i] is 0.
          for i, out_grad in enumerate(out_grads):
            if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
                (not grad_fn and is_func_call)
                or backprop_util.IsTrainable(op.outputs[i])):
              # Only trainable outputs or outputs for a function call that
              # will use SymbolicGradient get a zero gradient. Gradient
              # functions should ignore the gradient for other outputs.
              # TODO(apassos) gradients of resource handles might be an
              # issue here because of zeros.
              if loop_state:
                out_grads[i] = loop_state.ZerosLikeV1WhileLoop(op, i)
              elif default_gradient.supports_default_grad(op.outputs[i]):
                # TODO(b/143286622): The supports_default_grad check is needed
                # because While op emits non-differentiable resource tensors
                # as outputs. Remove this check when that is not the case.
                out_grads[i] = control_flow_state.ZerosLike(op, i)
          with ops.name_scope(op.name + "_grad"):
            # pylint: disable=protected-access
            with src_graph._original_op(op):
              # pylint: enable=protected-access
              if grad_fn:
                # If grad_fn was found, do not use SymbolicGradient even for
                # functions.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: grad_fn(op, *out_grads))
              else:
                # For function call ops, we add a 'SymbolicGradient'
                # node to the graph to compute gradients.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: _SymGrad(op, out_grads))
              in_grads = _AsList(in_grads)
              _VerifyGeneratedGradients(in_grads, op)
              if gate_gradients and len([x for x in in_grads
                                         if x is not None]) > 1:
                with ops.device(None):
                  with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
                      None,
                      gradient_uid,
                      ignore_existing=True):
                    in_grads = control_flow_ops.tuple(in_grads)
          _LogOpGradients(op, out_grads, in_grads)
        else:
          # If no grad_fn is defined or none of out_grads is available,
          # just propagate a list of None backwards.
          in_grads = [None] * len(_Inputs(op, xs_set))
        # Note: we don't filter out eager inputs here because the inputs need
        # to line up with in_grads.
        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs_set), in_grads)):
          if in_grad is not None:
            if (isinstance(in_grad, ops.Tensor) and
                t_in.dtype != dtypes.resource):
              try:
                in_grad.set_shape(t_in.get_shape())
              except ValueError:
                raise ValueError(
                    "Incompatible shapes between op input and calculated "
                    f"input gradient. Forward operation: {op.name}. Input "
                    f"index: {i}. Original input shape: {t_in.shape}. "
                    f"Calculated input gradient shape: {in_grad.shape}")
            if not isinstance(t_in, ops.EagerTensor):
              _SetGrad(grads, t_in, in_grad)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=False)

      # Update pending count for the inputs of op and enqueue ready ops.
      _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
                                    xs_set)

  if loop_state:
    loop_state.PostProcessing()
  return [_GetGrad(grads, x, unconnected_gradients) for x in xs]


def _HasAnyNotNoneGrads(grads, op):
  """Return true iff op has real gradient."""
  out_grads = _GetGrads(grads, op)
  for out_grad in out_grads:
    if isinstance(out_grad, (ops.Tensor, indexed_slices.IndexedSlices)):
      return True
    if out_grad and isinstance(out_grad, collections_abc.Sequence):
      if any(g is not None for g in out_grad):
        return True
  return False


def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
                                  xs_set):
  """Update pending count for the inputs of op and enqueue ready ops."""
  for x in _NonEagerInputs(op, xs_set):
    pending_count[x.op] -= 1
    ready = (pending_count[x.op] == 0)
    if loop_state and not ready:
      ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op)
    if ready:
      if control_flow_util.IsLoopExit(x.op):
        # if x is an exit without real gradient, defer processing it.
        grad_state = loop_state.GetGradState(x.op, before=False)
        grad_state.deferred_exits.append(x)
        grad_state.pending_exits_count -= 1
        if grad_state.pending_exits_count == 0:
          # We now have all the exits so process them.
          has_not_none_grad = False
          for y in grad_state.deferred_exits:
            if _HasAnyNotNoneGrads(grads, y.op):
              has_not_none_grad = True
              queue.append(y.op)
            else:
              grad_state.unused_exits.append(y)
          if has_not_none_grad:
            # For an unused exit, if it has trainable outputs, backprop
            # a zero gradient. Otherwise, just ignore it.
            for y in grad_state.unused_exits:
              if backprop_util.IsTrainable(y):
                _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
              queue.append(y.op)
          else:
            # All exits are "unused" so use None as gradient.
            for y in grad_state.unused_exits:
              queue.append(y.op)
      else:
        queue.append(x.op)


def _SetGrad(grads, t, grad):
  """Sets gradient "grad" in "grads" for tensor "t"."""
  op = t.op
  op_grads = grads.get(op)
  if not op_grads:
    op_grads = [[] for _ in range(len(op.outputs))]
    grads[op] = op_grads
  t_grads = op_grads[t.value_index]
  if isinstance(t_grads, list):
    t_grads.append(grad)
  else:
    assert control_flow_util.IsLoopSwitch(op)
    op_grads[t.value_index] = grad


def _ZerosLike(t):
  t_dtype = default_gradient.get_zeros_dtype(t)
  if t.dtype == dtypes.resource:
    return array_ops.zeros(
        resource_variable_ops.variable_shape(t), dtype=t_dtype)
  else:
    return array_ops.zeros_like(t, dtype=t_dtype)
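

# Editor's note: illustrative sketch, not used elsewhere in this module. It
# shows how `unconnected_gradients` changes the result for an `x` that does
# not feed into `y`: with NONE the entry is None, with ZERO it is a zeros
# tensor shaped like `x`. The helper name is hypothetical.
def _example_unconnected_gradients(y, x_unconnected):
  """Returns dy/dx for an unconnected `x` under both handling modes."""
  (grad_none,) = _GradientsHelper(
      y, x_unconnected, unconnected_gradients=UnconnectedGradients.NONE)
  (grad_zero,) = _GradientsHelper(
      y, x_unconnected, unconnected_gradients=UnconnectedGradients.ZERO)
  return grad_none, grad_zero  # (None, zeros shaped like x_unconnected)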


def _GetGrad(grads, t, unconnected_gradients):
  """Gets gradient for tensor "t"."""
  op = t.op
  op_grads = grads.get(op)
  if not op_grads:
    if unconnected_gradients == UnconnectedGradients.ZERO:
      return _ZerosLike(t)
    elif unconnected_gradients == UnconnectedGradients.NONE:
      return None
    else:
      raise ValueError(
          f"Unknown value for unconnected_gradients: '{unconnected_gradients}'")

  t_grad = op_grads[t.value_index]
  # This can happen if some other output of `t.op` has non-None grad.
  if unconnected_gradients == UnconnectedGradients.ZERO and t_grad is None:
    return _ZerosLike(t)

  assert not isinstance(
      t_grad, list), ("gradients list should have been aggregated by now.")
  return t_grad


def _GetGrads(grads, op):
  """Gets all gradients for op."""
  if op in grads:
    return grads[op]
  else:
    return [[] for _ in range(len(op.outputs))]


def _AccumulatorShape(inputs):
  shape = tensor_shape.unknown_shape()
  for i in inputs:
    if isinstance(i, ops.Tensor):
      shape = shape.merge_with(i.get_shape())
  return shape


def _LogOpGradients(op, out_grads, in_grads):
  """Log the in and out grads of an op."""
  logging.vlog(1, "Gradient for '" + op.name + "'")

  def _FilterGrad(x):
    if x is None:
      return False
    if isinstance(x, (list, tuple)):
      return bool(x)
    else:
      return True

  logging.vlog(1, "  in  --> %s",
               ", ".join(x.name for x in out_grads if _FilterGrad(x)))
  logging.vlog(1, "  out --> %s",
               ", ".join(x.name for x in in_grads if _FilterGrad(x)))


def _MultiDeviceAddN(tensor_list, gradient_uid):
  """Adds tensors from potentially multiple devices."""
  # Basic function structure comes from control_flow_ops.group().
  # Sort tensors according to their devices.
  tensors_on_device = collections.defaultdict(lambda: [])
  for tensor in tensor_list:
    tensors_on_device[tensor.device].append(tensor)

  # For each device, add the tensors on that device first.
  # Then gather the partial sums from multiple devices.
  # TODO(sjhwang): Create hierarchical aggregation tree as pbar's suggestion.
  # E.g., aggregate per GPU, then per task, and so on.
  summands = []

  def DeviceKey(dev):
    return "" if dev is None else dev

  for dev in sorted(tensors_on_device, key=DeviceKey):
    tensors = tensors_on_device[dev]
    with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
        tensors[0].op,
        gradient_uid,
        ignore_existing=True):
      summands.append(math_ops.add_n(tensors))

  return math_ops.add_n(summands)
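

# Editor's note: illustrative sketch, not used elsewhere in this module. It
# requests tree-style aggregation for a graph where one tensor fans out to
# many consumers, so its gradient is a sum of many terms. `AggregationMethod`
# is defined below; the helper name here is hypothetical.
def _example_tree_aggregation(y, x):
  """Builds dy/dx, summing gradient terms pairwise instead of in one AddN."""
  return _GradientsHelper(
      y, x, aggregation_method=AggregationMethod.EXPERIMENTAL_TREE)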


@tf_export("AggregationMethod")
class AggregationMethod:
  """A class listing aggregation methods used to combine gradients.

  Computing partial derivatives can require aggregating gradient
  contributions. This class lists the various methods that can
  be used to combine gradients in the graph.

  The following aggregation methods are part of the stable API for
  aggregating gradients:

  * `ADD_N`: All of the gradient terms are summed as part of one
    operation using the "AddN" op (see `tf.add_n`). This
    method has the property that all gradients must be ready and
    buffered separately in memory before any aggregation is performed.
  * `DEFAULT`: The system-chosen default aggregation method.

  The following aggregation methods are experimental and may not
  be supported in future releases:

  * `EXPERIMENTAL_TREE`: Gradient terms are summed in pairs using
    the "AddN" op. This method of summing gradients may reduce
    performance, but it can improve memory utilization because the
    gradients can be released earlier.
  """
  ADD_N = 0
  DEFAULT = ADD_N
  # The following are experimental and may not be supported in future releases.
  EXPERIMENTAL_TREE = 1
  EXPERIMENTAL_ACCUMULATE_N = 2  # An alias for EXPERIMENTAL_TREE.


def _AggregatedGrads(grads,
                     op,
                     gradient_uid,
                     loop_state,
                     aggregation_method=None):
  """Get the aggregated gradients for op.

  Args:
    grads: The map of memoized gradients.
    op: The op to get gradients for.
    gradient_uid: A unique identifier within the graph indicating
      which invocation of gradients is being executed. Used to cluster
      ops for compilation.
    loop_state: An object for maintaining the state of the while loops in the
      graph. It is of type ControlFlowState. None if the graph
      contains no while loops.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.

  Returns:
    A list of gradients, one per output of `op`. If the gradients for a
    particular output are a list, this function aggregates them before
    returning.

  Raises:
    TypeError: if the incoming grads are not Tensors or IndexedSlices.
    ValueError: if the arguments are invalid.
  """
  if aggregation_method is None:
    aggregation_method = AggregationMethod.DEFAULT
  valid_aggregation_methods = [
      AggregationMethod.ADD_N, AggregationMethod.EXPERIMENTAL_TREE,
      AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
  ]
  if aggregation_method not in valid_aggregation_methods:
    raise ValueError(
        f"Invalid `aggregation_method` specified {aggregation_method}. "
        f"Accepted values are {valid_aggregation_methods}.")
  out_grads = _GetGrads(grads, op)
  for i, out_grad in enumerate(out_grads):
    if loop_state:
      if isinstance(out_grad, (ops.Tensor, indexed_slices.IndexedSlices)):
        assert control_flow_util.IsLoopSwitch(op)
        continue
    # Grads have to be Tensors or IndexedSlices.
    if (isinstance(out_grad, collections_abc.Sequence) and not all(
        isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices))
        for g in out_grad
        if g is not None)):
      raise TypeError(f"Invalid gradient {out_grad} [index = {i}]. Gradients "
                      "have to be either all Tensors or all IndexedSlices")
    # Aggregate multiple gradients, and convert [] to None.
    if out_grad:
      if len(out_grad) < 2:
        used = "nop"
        out_grads[i] = out_grad[0]
      elif all(isinstance(g, ops.Tensor) for g in out_grad if g is not None):
        tensor_shape = _AccumulatorShape(out_grad)
        if aggregation_method in [
            AggregationMethod.EXPERIMENTAL_TREE,
            AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
        ]:
          # Aggregate all gradients by doing pairwise sums: this may
          # reduce performance, but it can improve memory because the
          # gradients can be released earlier.
          #
          # TODO(vrv): Consider replacing this with a version of
          # tf.AddN() that eagerly frees its inputs as soon as they are
          # ready, so the order of this tree does not become a problem.
          used = "tree"
          with ops.name_scope(op.name + "_gradient_sum"):
            running_sum = out_grad[0]
            for grad in out_grad[1:]:
              running_sum = math_ops.add_n([running_sum, grad])
            out_grads[i] = running_sum
        else:
          used = "add_n"
          out_grads[i] = _MultiDeviceAddN(out_grad, gradient_uid)
        logging.vlog(2, "  _AggregatedGrads %d x %s using %s", len(out_grad),
                     tensor_shape, used)
      else:
        out_grads[i] = backprop.aggregate_indexed_slices_gradients(out_grad)  # pylint: disable=protected-access
    else:  # not out_grad
      # out_grads[i] is [], thus its aggregation is simply None.
      out_grads[i] = None
  return out_grads


# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are
# unfortunately too slow to use here.
POSSIBLE_GRADIENT_TYPES_NONE = 0
POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1
POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2


def PossibleTapeGradientTypes(tensors):
  """Determines whether and how `args` may require tape gradients."""
  return pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(tensors)
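

# Editor's note: illustrative sketch, not used elsewhere in this module. The
# gradient machinery above only runs while building a graph; under eager
# execution _GradientsHelper raises RuntimeError and tf.GradientTape should be
# used instead. The helper name is hypothetical.
def _example_graph_mode_usage():
  """Builds a symbolic d(x**2)/dx inside an explicit graph."""
  graph = ops.Graph()
  with graph.as_default():
    x = array_ops.placeholder(dtypes.float32, shape=[], name="x")
    y = math_ops.square(x)
    (dy_dx,) = _GradientsHelper(y, x)  # A tensor in `graph`, not a value.
  return graph, dy_dx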