xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/gradients_util.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implements the graph generation for computation of gradients."""
16
17import collections
18import contextlib
19
20from tensorflow.core.framework import attr_value_pb2
21from tensorflow.python import pywrap_tfe
22from tensorflow.python.eager import backprop
23from tensorflow.python.eager import backprop_util
24from tensorflow.python.eager import context
25from tensorflow.python.framework import composite_tensor
26from tensorflow.python.framework import composite_tensor_gradient
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import function as framework_function
29from tensorflow.python.framework import indexed_slices
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.framework.func_graph import FuncGraph
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import control_flow_ops
35from tensorflow.python.ops import control_flow_state
36from tensorflow.python.ops import control_flow_util
37from tensorflow.python.ops import default_gradient
38from tensorflow.python.ops import functional_ops
39from tensorflow.python.ops import math_ops
40from tensorflow.python.ops import resource_variable_ops
41from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
42from tensorflow.python.platform import tf_logging as logging
43from tensorflow.python.util import compat
44from tensorflow.python.util import object_identity
45from tensorflow.python.util import variable_utils
46from tensorflow.python.util.compat import collections_abc
47from tensorflow.python.util.tf_export import tf_export
48
49
50def _MarkReachedOps(from_ops, reached_ops, func_graphs):
51  """Mark all ops reached from "from_ops".
52
53  Args:
54    from_ops: list of Operations.
55    reached_ops: set of Operations.
56    func_graphs: list of FuncGraphs. This method will traverse through
57      these functions if they capture from_ops or any reachable ops.
58  """
59  queue = collections.deque()
60  queue.extend(from_ops)
61  while queue:
62    op = queue.popleft()
63    if op not in reached_ops:
64      reached_ops.add(op)
65      for output in op.outputs:
66        if backprop_util.IsTrainable(output):
67          queue.extend(_Consumers(output, func_graphs))
68
69
70def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
71                  xs_set):
72  """Initialize the pending count for ops between two lists of Operations.
73
74  'pending_count[op]' indicates the number of backprop inputs
75  to this operation.
76
77  Args:
78    to_ops: list of Operations.
79    from_ops: list of Operations.
80    colocate_gradients_with_ops: Python bool.  See docstring of gradients().
81    func_graphs: list of FuncGraphs. This method will traverse through
82      these functions if they capture from_ops or any reachable ops. This is
83      useful if to_ops occur in a function and from_ops are in an outer function
84      or graph.
85    xs_set: ObjectIdentitySet of Tensors.
86
87  Returns:
88    A tuple containing: (1) the subset of to_ops reachable from from_ops by a
89    path of zero or more backpropagatable tensors, (2) a mapping from operation
90    to the number of backprop inputs to that op, and (3) a ControlFlowState
91    object which is not None if the ops between from_ops and to_ops contain
92    control flow loops.
93  """
94  # Mark reachable ops from from_ops.
95  reached_ops = set()
96  _MarkReachedOps(from_ops, reached_ops, func_graphs)
97  # X in reached_ops iff X is reachable from from_ops by a path of zero or more
98  # backpropagatable tensors.
99
100  reachable_to_ops = set(op for op in to_ops if op in reached_ops)
101
102  # Mark between ops.
103  between_ops = set()
104  between_op_list = []
105  queue = collections.deque()
106  queue.extend(to_ops)
107  while queue:
108    op = queue.popleft()
109    # We are interested in this op.
110    if op in reached_ops:
111      between_ops.add(op)
112      between_op_list.append(op)
113      # Remove the op from the reached set so we won't add its inputs again.
114      reached_ops.remove(op)
115      for inp in _NonEagerInputs(op, xs_set):
116        queue.append(inp.op)
117  # X in between_ops iff X is on a path of zero or more backpropagatable tensors
118  # between from_ops and to_ops
119
120  # 'loop_state' is None if there are no while loops.
121  loop_state = control_flow_state.MaybeCreateControlFlowState(
122      between_op_list, between_ops, colocate_gradients_with_ops)
123
124  # Initialize pending count for between ops.
125  pending_count = collections.defaultdict(int)
126  for op in between_op_list:
127    for x in _NonEagerInputs(op, xs_set):
128      if x.op in between_ops:
129        pending_count[x.op] += 1
130
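  # Illustrative sketch (hypothetical three-op graph, not executed): for
  #   y = a + b,  a = x * 2.0,  b = x * 3.0
  # both multiplies and the add lie between y.op and x.op, so
  # pending_count[x.op] == 2: during backprop x.op only becomes ready once the
  # gradient contributions from both multiplies have been processed.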
131  return reachable_to_ops, pending_count, loop_state
132
133
134def _AsList(x):
135  return x if isinstance(x, (list, tuple)) else [x]
136
137
138def _DefaultGradYs(grad_ys,
139                   ys,
140                   colocate_gradients_with_ops,
141                   gradient_uid="__unsupported__"):
142  """Fill in default values for grad_ys.
143
144  Args:
145    grad_ys: List of gradients, can contain None.
146    ys: List of tensors.
147    colocate_gradients_with_ops: If True, try colocating gradients with
148      the corresponding op.
149    gradient_uid: A unique identifier within the graph indicating
150      which invocation of gradients is being executed. Used to cluster
151      ops for compilation.
152
153  Returns:
154    A list of gradients to use, without None.
155
156  Raises:
157    ValueError: If sizes of gradients and inputs don't match
158    TypeError: If type of any gradient is not valid for its input.
159  """
160  if len(grad_ys) != len(ys):
161    raise ValueError(f"Length mismatch. Passed {len(grad_ys)} grad_ys for "
162                     f"{len(ys)} ys")
163  grad_ys = ops.convert_n_to_tensor_or_indexed_slices(grad_ys, name="grad_y")
164  new_grad_ys = []
165  for i, (y, grad_y) in enumerate(zip(ys, grad_ys)):
166    with _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops):
167      if grad_y is None:
168        if y.dtype.is_complex:
169          raise TypeError(
170              f"Gradients of complex tensors ({y}) must set grad_ys (y.dtype = "
171              f"{dtypes.as_dtype(y.dtype).name})")
172        new_grad_ys.append(
173            array_ops.ones(
174                array_ops.shape(y), dtype=y.dtype, name="grad_ys_%d" % i))
175        continue
176      if y.dtype.is_floating or y.dtype.is_integer:
177        if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
178          raise TypeError(
179              f"Gradient type {dtypes.as_dtype(grad_y.dtype).name} generated "
180              f"for real or integer-valued tensor {y} with type "
181              f"{dtypes.as_dtype(y.dtype).name} must be real or integer")
182      elif y.dtype.is_complex:
183        if not grad_y.dtype.is_complex:
184          raise TypeError(
185              f"Gradient type {dtypes.as_dtype(grad_y.dtype).name} generated "
186              f"for complex-valued tensor {y} with type "
187              f"{dtypes.as_dtype(y.dtype).name} must be complex")
188      elif y.dtype == dtypes.variant:
189        if grad_y.dtype != dtypes.variant:
190          raise TypeError(
191              f"Gradient type {dtypes.as_dtype(grad_y.dtype).name} generated "
192              f"for variant tensor {y} with type "
193              f"{dtypes.as_dtype(y.dtype).name} must be variant")
194      elif y.dtype == dtypes.resource:
195        # We assume y is the handle of a ResourceVariable. The gradient of a
196        # ResourceVariable should be a numeric value, not another resource.
197        if grad_y.dtype == dtypes.resource:
198          raise TypeError(f"Input gradient {grad_y} for resource tensor {y} "
199                          "should not be a resource")
200      else:
201        raise TypeError(
202            f"Tensor {y} with type {dtypes.as_dtype(y.dtype).name} must be "
203            "numeric to obtain a default gradient")
204      # Create a grad_y tensor in the name scope of the gradient.
205      # Required for TensorArrays to identify which gradient call a
206      # grad_y value is coming from.
207      if isinstance(grad_y, indexed_slices.IndexedSlices):
208        new_grad_ys.append(
209            indexed_slices.IndexedSlices(
210                indices=(array_ops.identity(
211                    grad_y.indices, name="grad_ys_%d_indices" % i)
212                         if isinstance(grad_y.indices, ops.Tensor) else
213                         grad_y.indices),
214                values=(array_ops.identity(
215                    grad_y.values, name="grad_ys_%d_values" % i) if isinstance(
216                        grad_y.values, ops.Tensor) else grad_y.values),
217                dense_shape=(array_ops.identity(
218                    grad_y.dense_shape, name="grad_ys_%d_shape" % i)
219                             if isinstance(grad_y.dense_shape, ops.Tensor) else
220                             grad_y.dense_shape)))
221      else:
222        new_grad_ys.append(array_ops.identity(grad_y, name="grad_ys_%d" % i))
223
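  # Illustrative sketch (not executed here): for a float `y` passed with
  # grad_ys=None, the default gradient built above is a ones tensor with y's
  # shape and dtype, so tf.gradients(y, x) behaves like
  # tf.gradients(y, x, grad_ys=tf.ones_like(y)).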
224  return new_grad_ys
225
226
227def _VerifyGeneratedGradients(grads, op):
228  """Verify that gradients are valid in number and type.
229
230  Args:
231    grads: List of generated gradients.
232    op: Operation for which the gradients were generated.
233
234  Raises:
235    ValueError: if sizes of gradients and inputs don't match.
236    TypeError: if type of any gradient is not valid for its input.
237  """
238  # While ops have inputs added to them during the gradient computation, so we
239  # skip the below check. See while_v2 for details.
240  if op.type == "While" or op.type == "StatelessWhile":
241    return
242
243  if len(grads) != len(op.inputs):
244    raise ValueError(f"Num gradients {len(grads)} generated for op "
245                     f"{op.node_def} do not match num inputs {len(op.inputs)}")
246
247
248def _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set):
249  """The set of ops that terminate the gradient computation.
250
251  This computes the frontier of the forward graph *before* which backprop
252  should stop. Operations in the returned set will not be differentiated.
253  This set is defined as the subset of `from_ops` containing ops that have
254  no predecessor in `from_ops`. `pending_count` is the result of
255  `_PendingCount(to_ops, from_ops, ...)`. An 'op' has predecessors in
256  `from_ops` iff pending_count[op] > 0.
257
258  In addition, none of `stop_gradient_ops` will be differentiated.
259
260  Args:
261    from_ops: list of Operations.
262    stop_gradient_ops: list of Operations never to backprop through.
263    pending_count: mapping from operation to number of backprop inputs.
264    xs_set: ObjectIdentitySet of Tensors.
265
266  Returns:
267    The set of operations.
268  """
269  stop_ops = set()
270  for op in from_ops:
271    is_stop_op = True
272    for inp in _NonEagerInputs(op, xs_set):
273      if pending_count[inp.op] > 0:
274        is_stop_op = False
275        break
276    if is_stop_op:
277      stop_ops.add(op)
278  stop_ops.update(op for op in stop_gradient_ops)
279  return stop_ops
280
281
282@contextlib.contextmanager
283def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):  # pylint: disable=invalid-name
284  """Context to colocate with `op` if `colocate_gradients_with_ops`."""
285  if colocate_gradients_with_ops:
286    with ops._colocate_with_for_gradient(op, gradient_uid):  # pylint: disable=protected-access
287      yield
288  else:
289    yield
290
291
292def _IsPartitionedCall(op):
293  return op.type == "PartitionedCall" or op.type == "StatefulPartitionedCall"
294
295
296def _SymGrad(op, out_grads):
297  """Backprop through a function call node op given its outputs' gradients."""
298  f_in = [x for x in op.inputs] + out_grads
299  f_types = [default_gradient.get_zeros_dtype(x) for x in op.inputs]
300  f = attr_value_pb2.NameAttrList()
301  if _IsPartitionedCall(op):
302    f.name = op.get_attr("f").name
303  else:
304    f.name = op.type
305  for k in op.node_def.attr:
306    f.attr[k].CopyFrom(op.node_def.attr[k])
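  # The SymbolicGradient node below takes the forward call's inputs followed by
  # the gradients of its outputs, and emits one gradient per forward input
  # (with dtypes `f_types`); `f` carries the forward function's name and attrs
  # so the runtime can instantiate its gradient function.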
307  in_grads = functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f)
308  return in_grads
309
310
311def _MaybeCompile(scope, op, func, grad_fn):
312  """Compile the calculation in grad_fn if op was marked as compiled."""
313  scope = scope.rstrip("/").replace("/", "_")
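  # e.g. a gradient name scope of "gradients/dense/" becomes "gradients_dense";
  # when gradients are compiled separately, this is combined with the forward
  # _XlaScope below into something like "<xla_scope>_grad_gradients_dense".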
314  if func is not None:
315    xla_compile = func.definition.attr["_XlaCompile"].b
316    xla_separate_compiled_gradients = func.definition.attr[
317        "_XlaSeparateCompiledGradients"].b
318    xla_scope = func.definition.attr["_XlaScope"].s.decode()
319  else:
320    try:
321      xla_compile = op.get_attr("_XlaCompile")
322      xla_separate_compiled_gradients = op.get_attr(
323          "_XlaSeparateCompiledGradients")
324      xla_scope = op.get_attr("_XlaScope").decode()
325    except ValueError:
326      xla_compile = False
327
328  if not xla_compile:
329    return grad_fn()  # Exit early
330
331  # If the gradients are supposed to be compiled separately, we give them a
332  # _XlaScope name that is based on the name_scope of the gradients.  Otherwise
333  # they just inherit the existing _XlaScope name, which lets them be merged
334  # together with the non-gradient computation.
335  if xla_separate_compiled_gradients:
336    xla_grad_scope = "%s_grad_%s" % (xla_scope, scope)
337  else:
338    xla_grad_scope = xla_scope
339
340  attrs = {
341      "_XlaCompile": attr_value_pb2.AttrValue(b=xla_compile),
342      "_XlaScope": attr_value_pb2.AttrValue(s=xla_grad_scope.encode())
343  }
344  with ops.get_default_graph()._attr_scope(attrs):  # pylint: disable=protected-access
345    return grad_fn()
346
347
348def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set):
349  """Raises an error if we backprop through a loop var."""
350  # Find the nearest op in 'from_ops' reachable from 'op' to provide a more
351  # helpful error message.
352  target_op = None
353  queue = collections.deque([op])
354  visited = set()
355  while queue:
356    curr_op = queue.popleft()
357    if curr_op in visited: continue
358    visited.add(curr_op)
359    if curr_op in from_ops:
360      target_op = curr_op
361      break
362    queue.extend(t.op for t in _NonEagerInputs(curr_op, xs_set))
363  assert target_op
364  raise ValueError(
365      "Cannot compute gradient inside while loop with respect to op "
366      f"'{target_op.name}'. We do not support taking the gradient wrt or "
367      "through the initial value of a loop variable. Gradients can be computed "
368      "through loop invariants or wrt the input parameters to the loop body.")
369
370
371def _IsFunction(graph):
372  return (isinstance(graph, FuncGraph) or
373          isinstance(graph, framework_function._FuncGraph))  # pylint: disable=protected-access
374
375
376def _Captures(func_graph):
377  if isinstance(func_graph, FuncGraph):
378    return func_graph.captures
379  else:
380    assert isinstance(func_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
381    return func_graph.captures
382
383
384def _MaybeCaptured(t):
385  """If t is a captured value placeholder, returns the original captured value.
386
387  Args:
388    t: Tensor
389
390  Returns:
391    A tensor, potentially from a different Graph/FuncGraph.
392  """
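  # Illustrative sketch (hypothetical names): if an outer tensor `outer_t` is
  # captured by a FuncGraph, the function body sees a Placeholder `p`;
  # _MaybeCaptured(p) follows the capture mapping, possibly through several
  # nested functions, and returns `outer_t`.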
393  # pylint: disable=protected-access
394  if (not isinstance(t, ops.EagerTensor) and
395      _IsFunction(t.op.graph) and t.op.type == "Placeholder"):
396    for input_t, placeholder_t in _Captures(t.op.graph):
397      if t is placeholder_t:
398        return _MaybeCaptured(input_t)
399  # pylint: enable=protected-access
400  return t
401
402
403def _NonEagerInputs(op, xs_set):
404  """Returns the inputs of op, crossing closure boundaries where necessary.
405
406  Does not return any captured EagerTensors, i.e., the number of tensors
407  returned may be less than the actual number of inputs.
408
409  Args:
410    op: Operation
411    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.
412
413  Returns:
414    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
415    is in a FuncGraph and has captured inputs.
416  """
417  return [t for t in _Inputs(op, xs_set) if not isinstance(t, ops.EagerTensor)]
418
419
420# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
421# _GradientsHelper a class with xs as a member variable.
422def _Inputs(op, xs_set):
423  """Returns the inputs of op, crossing closure boundaries where necessary.
424
425  Args:
426    op: Operation
427    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.
428
429  Returns:
430    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
431    is in a FuncGraph and has captured inputs.
432  """
433  if _IsFunction(op.graph):  # pylint: disable=protected-access
434    inputs = []
435    for t in op.inputs:
436      # If we're differentiating w.r.t. `t`, do not attempt to traverse through
437      # it to a captured value. The algorithm needs to "see" `t` in this case,
438      # even if it's a function input for a captured value, whereas usually we'd
439      # like to traverse through these closures as if the captured value was the
440      # direct input to op.
441      if t not in xs_set:
442        t = _MaybeCaptured(t)
443      inputs.append(t)
444    return inputs
445  else:
446    return op.inputs
447
448
449def _Consumers(t, func_graphs):
450  """Returns the consumers of t, crossing closure boundaries where necessary.
451
452  Args:
453    t: Tensor
454    func_graphs: a list of FuncGraphs that may have captured t.
455
456  Returns:
457    A list of Operations. The operations will be from the current graph and/or
458    func_graphs.
459  """
460  consumers = t.consumers()
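  # If a FuncGraph has captured `t`, ops inside that function consume the
  # corresponding capture placeholder rather than `t` itself, so recurse
  # through the placeholder to pick up those consumers as well.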
461  for func in func_graphs:
462    for input_t, placeholder in _Captures(func):
463      if input_t is t:
464        consumers.extend(_Consumers(placeholder, func_graphs))
465  return consumers
466
467
468def _GradientsHelper(ys,
469                     xs,
470                     grad_ys=None,
471                     name="gradients",
472                     colocate_gradients_with_ops=False,
473                     gate_gradients=False,
474                     aggregation_method=None,
475                     stop_gradients=None,
476                     unconnected_gradients=UnconnectedGradients.NONE,
477                     src_graph=None):
478  """Implementation of gradients()."""
479  if context.executing_eagerly():
480    raise RuntimeError("tf.gradients is not supported when eager execution "
481                       "is enabled. Use tf.GradientTape instead.")
482  ys = variable_utils.convert_variables_to_tensors(_AsList(ys))
483  xs = [
484      x.handle if resource_variable_ops.is_resource_variable(x) else x
485      for x in _AsList(xs)
486  ]
487  if grad_ys is not None:
488    grad_ys = _AsList(grad_ys)
489
490  # Handle CompositeTensors.
491  if (any(isinstance(x, composite_tensor.CompositeTensor) for x in xs) or
492      any(isinstance(y, composite_tensor.CompositeTensor) for y in ys)):
493    flat_xs = composite_tensor_gradient.get_flat_tensors_for_gradients(xs)
494    flat_ys = composite_tensor_gradient.get_flat_tensors_for_gradients(ys)
495    flat_grad_ys = (
496        None if grad_ys is None else
497        composite_tensor_gradient.get_flat_tensors_for_gradients(grad_ys))
498    flat_grads = _GradientsHelper(flat_ys, flat_xs, flat_grad_ys, name,
499                                  colocate_gradients_with_ops, gate_gradients,
500                                  aggregation_method, stop_gradients,
501                                  unconnected_gradients, src_graph)
502    return composite_tensor_gradient.replace_flat_tensors_for_gradients(
503        xs, flat_grads)
504
505  if src_graph is None:
506    src_graph = ops.get_default_graph()
507  try:
508    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
509  except ValueError:
510    raise ValueError(
511        f"Unknown value for unconnected_gradients: '{unconnected_gradients}'")
512
513  # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
514  # ancestor graphs. This is necessary for correctly handling captured values.
515  func_graphs = []
516  curr_graph = src_graph
517  while _IsFunction(curr_graph):
518    func_graphs.append(curr_graph)
519    if isinstance(curr_graph, FuncGraph):
520      curr_graph = curr_graph.outer_graph
521    else:
522      assert isinstance(curr_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
523      curr_graph = curr_graph._outer_graph  # pylint: disable=protected-access
524
525  stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
526  if grad_ys is None:
527    grad_ys = [None] * len(ys)
528
529  with ops.name_scope(
530      name, "gradients",
531      list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
532    # Get a uid for this call to gradients that can be used to help
533    # cluster ops for compilation.
534    gradient_uid = ops.get_default_graph().unique_name("uid")
535    ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
536    xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
537        xs, name="x", as_ref=True)
538    xs_set = object_identity.ObjectIdentitySet(xs)
539    grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
540                             gradient_uid)
541
542    # The approach we take here is as follows: Create a list of all ops in the
543    # subgraph between the ys and xs.  Visit these ops in reverse order of ids
544    # to ensure that when we visit an op the gradients w.r.t its outputs have
545    # been collected.  Then aggregate these gradients if needed, call the op's
546    # gradient function, and add the generated gradients to the gradients for
547    # its input.
548
549    # Initialize the pending count for ops in the connected subgraph from ys
550    # to the xs.
551    to_ops = [t.op for t in ys]
552    from_ops = [t.op for t in xs]
553    stop_gradient_ops = [t.op for t in stop_gradients]
554    reachable_to_ops, pending_count, loop_state = _PendingCount(
555        to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs_set)
556
557    # Iterate over the collected ops.
558    #
559    # grads: op => list of gradients received on each output endpoint of the
560    # op.  The gradients for each endpoint are initially collected as a list.
561    # When it is time to call the op's gradient function, for each endpoint we
562    # aggregate the list of received gradients into a Add() Operation if there
563    # is more than one.
564    grads = {}
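    # Illustrative sketch of the structure (hypothetical op): grads[some_op]
    # could be [[g_a, g_b], []], meaning output 0 has received two gradient
    # contributions (summed later by _AggregatedGrads) while output 1 has
    # received none so far.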
565
566    # Add the initial gradients for the ys.
567    for y, grad_y in zip(ys, grad_ys):
568      _SetGrad(grads, y, grad_y)
569
570    # Initialize queue with to_ops.
571    queue = collections.deque()
572    # Add the ops in 'to_ops' into the queue.
573    to_ops_set = set()
574    for op in to_ops:
575      # 'ready' handles the case where one output gradient relies on
576      # another output's gradient.
577      ready = (pending_count[op] == 0)
578      if ready and op not in to_ops_set and op in reachable_to_ops:
579        to_ops_set.add(op)
580        queue.append(op)
581
582    if loop_state:
583      loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
584      for y in loop_exits:
585        if backprop_util.IsTrainable(y):
586          _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
587          queue.append(y.op)
588
589    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set)
590    while queue:
591      # generate gradient subgraph for op.
592      op = queue.popleft()
593      with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):
594        if loop_state:
595          loop_state.EnterGradWhileContext(op, before=True)
596        out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state,
597                                     aggregation_method)
598        if loop_state:
599          loop_state.ExitGradWhileContext(op, before=True)
600
601        grad_fn = None
602        func_call = None
603        is_partitioned_call = _IsPartitionedCall(op)
604        # pylint: disable=protected-access
605        is_func_call = (
606            src_graph._is_function(op.type) or is_partitioned_call)
607        # pylint: enable=protected-access
608        has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
609        if has_out_grads and (op not in stop_ops):
610          try:
611            grad_fn = ops.get_gradient_function(op)
612          except LookupError:
613            if is_func_call:
614              if is_partitioned_call:
615                func_name = compat.as_bytes(op.get_attr("f").name)
616                func_call = src_graph._get_function(  # pylint: disable=protected-access
617                    func_name)
618                # When a graph is imported, the FunctionDefs are not copied over
619                # to each sub-graph so we recursively search the outer graphs
620                # for the FunctionDef.
621                if not func_call and hasattr(src_graph, "outer_graph"):
622                  graph = src_graph.outer_graph
623                  while graph is not None:
624                    func_call = graph._get_function(func_name)  # pylint: disable=protected-access
625                    if func_call is not None:
626                      break
627                    if hasattr(graph, "outer_graph"):
628                      graph = graph.outer_graph
629                    else:
630                      break
631              else:
632                func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
633              # Note that __defun is not set if the graph is
634              # imported. If it's set, we prefer to access the original
635              # defun.
636              func_call = getattr(op, "__defun", func_call)
637              grad_fn = func_call.python_grad_func
638            else:
639              raise LookupError(
640                  "No gradient defined for operation "
641                  f"'{op.name}' (op type: {op.type}). "
642                  "In general every operation must have an associated "
643                  "`@tf.RegisterGradient` for correct autodiff, which this "
644                  "op is lacking. If you want to pretend this "
645                  "operation is a constant in your program, you may insert "
646                  "`tf.stop_gradient`. This can be useful to silence the "
647                  "error in cases where you know gradients are not needed, "
648                  "e.g. the forward pass of tf.custom_gradient. "
649                  "Please see more details in "
650                  "https://www.tensorflow.org/api_docs/python/tf/custom_gradient.")  # pylint: disable=line-too-long
651        if loop_state:
652          loop_state.EnterGradWhileContext(op, before=False)
653
654        # NOTE(skyewm): We don't support computing gradients wrt a loop variable
655        # unless it's within the context of a single iteration (i.e. the
656        # gradient is wrt to the loop parameter in the body function, not wrt or
657        # through the initial value). This means if we're in a while loop
658        # context, we should never see a switch node from this context.
659        # pylint: disable=protected-access
660        if (control_flow_util.IsSwitch(op) and
661            op._control_flow_context is not None and
662            op._control_flow_context.IsWhileContext() and
663            op._control_flow_context ==
664            ops.get_default_graph()._get_control_flow_context()):
665          _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set)
666        # pylint: enable=protected-access
667
668        if (grad_fn or is_func_call) and has_out_grads:
669          # NOTE: If _AggregatedGrads didn't compute a value for the i'th
670          # output, it means that the cost does not depend on output[i],
671          # therefore dC/doutput[i] is 0.
672          for i, out_grad in enumerate(out_grads):
673            if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
674                (not grad_fn and is_func_call)
675                or backprop_util.IsTrainable(op.outputs[i])):
676              # Only trainable outputs or outputs for a function call that
677              # will use SymbolicGradient get a zero gradient. Gradient
678              # functions should ignore the gradient for other outputs.
679              # TODO(apassos) gradients of resource handles might be an
680              # issue here because of zeros.
681              if loop_state:
682                out_grads[i] = loop_state.ZerosLikeV1WhileLoop(op, i)
683              elif default_gradient.supports_default_grad(op.outputs[i]):
684                # TODO(b/143286622): The supports_default_grad check is needed
685                # because While op emits non-differentiable resource tensors
686                # as outputs. Remove this check when that is not the case.
687                out_grads[i] = control_flow_state.ZerosLike(op, i)
688          with ops.name_scope(op.name + "_grad"):
689            # pylint: disable=protected-access
690            with src_graph._original_op(op):
691              # pylint: enable=protected-access
692              if grad_fn:
693                # If grad_fn was found, do not use SymbolicGradient even for
694                # functions.
695                in_grads = _MaybeCompile(grad_scope, op, func_call,
696                                         lambda: grad_fn(op, *out_grads))
697              else:
698                # For function call ops, we add a 'SymbolicGradient'
699                # node to the graph to compute gradients.
700                in_grads = _MaybeCompile(grad_scope, op, func_call,
701                                         lambda: _SymGrad(op, out_grads))
702              in_grads = _AsList(in_grads)
703              _VerifyGeneratedGradients(in_grads, op)
704              if gate_gradients and len([x for x in in_grads
705                                         if x is not None]) > 1:
706                with ops.device(None):
707                  with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
708                      None,
709                      gradient_uid,
710                      ignore_existing=True):
711                    in_grads = control_flow_ops.tuple(in_grads)
712          _LogOpGradients(op, out_grads, in_grads)
713        else:
714          # If no grad_fn is defined or none of out_grads is available,
715          # just propagate a list of None backwards.
716          in_grads = [None] * len(_Inputs(op, xs_set))
717        # Note: we don't filter out eager inputs here because the inputs need to
718        # line up with in_grads.
719        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs_set), in_grads)):
720          if in_grad is not None:
721            if (isinstance(in_grad, ops.Tensor) and
722                t_in.dtype != dtypes.resource):
723              try:
724                in_grad.set_shape(t_in.get_shape())
725              except ValueError:
726                raise ValueError(
727                    "Incompatible shapes between op input and calculated "
728                    f"input gradient. Forward operation: {op.name}. Input "
729                    f"index: {i}. Original input shape: {t_in.shape}. "
730                    f"Calculated input gradient shape: {in_grad.shape}")
731            if not isinstance(t_in, ops.EagerTensor):
732              _SetGrad(grads, t_in, in_grad)
733        if loop_state:
734          loop_state.ExitGradWhileContext(op, before=False)
735
736      # Update pending count for the inputs of op and enqueue ready ops.
737      _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
738                                    xs_set)
739
740  if loop_state:
741    loop_state.PostProcessing()
742  return [_GetGrad(grads, x, unconnected_gradients) for x in xs]
743
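# Illustrative usage sketch (graph mode only, hypothetical shapes): this helper
# backs the public tf.gradients API, e.g.
#   x = tf.compat.v1.placeholder(tf.float32, [None, 3])
#   y = tf.reduce_sum(x * x)
#   dx, = tf.gradients(y, [x])  # dx is symbolically 2 * x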
744
745def _HasAnyNotNoneGrads(grads, op):
746  """Return true iff op has real gradient."""
747  out_grads = _GetGrads(grads, op)
748  for out_grad in out_grads:
749    if isinstance(out_grad, (ops.Tensor, indexed_slices.IndexedSlices)):
750      return True
751    if out_grad and isinstance(out_grad, collections_abc.Sequence):
752      if any(g is not None for g in out_grad):
753        return True
754  return False
755
756
757def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
758                                  xs_set):
759  """Update pending count for the inputs of op and enqueue ready ops."""
760  for x in _NonEagerInputs(op, xs_set):
761    pending_count[x.op] -= 1
762    ready = (pending_count[x.op] == 0)
763    if loop_state and not ready:
764      ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op)
765    if ready:
766      if control_flow_util.IsLoopExit(x.op):
767        # If x is an exit without a real gradient, defer processing it.
768        grad_state = loop_state.GetGradState(x.op, before=False)
769        grad_state.deferred_exits.append(x)
770        grad_state.pending_exits_count -= 1
771        if grad_state.pending_exits_count == 0:
772          # We now have all the exits so process them.
773          has_not_none_grad = False
774          for y in grad_state.deferred_exits:
775            if _HasAnyNotNoneGrads(grads, y.op):
776              has_not_none_grad = True
777              queue.append(y.op)
778            else:
779              grad_state.unused_exits.append(y)
780          if has_not_none_grad:
781            # For an unused exit, if it has trainable outputs, backprop
782            # a zero gradient. Otherwise, just ignore it.
783            for y in grad_state.unused_exits:
784              if backprop_util.IsTrainable(y):
785                _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
786              queue.append(y.op)
787          else:
788            # All exits are "unused" so use None as gradient.
789            for y in grad_state.unused_exits:
790              queue.append(y.op)
791      else:
792        queue.append(x.op)
793
794
795def _SetGrad(grads, t, grad):
796  """Sets gradient "grad" in "grads" for tensor "t"."""
797  op = t.op
798  op_grads = grads.get(op)
799  if not op_grads:
800    op_grads = [[] for _ in range(len(op.outputs))]
801    grads[op] = op_grads
802  t_grads = op_grads[t.value_index]
803  if isinstance(t_grads, list):
804    t_grads.append(grad)
805  else:
806    assert control_flow_util.IsLoopSwitch(op)
807    op_grads[t.value_index] = grad
808
809
810def _ZerosLike(t):
811  t_dtype = default_gradient.get_zeros_dtype(t)
812  if t.dtype == dtypes.resource:
813    return array_ops.zeros(
814        resource_variable_ops.variable_shape(t), dtype=t_dtype)
815  else:
816    return array_ops.zeros_like(t, dtype=t_dtype)
817
818
819def _GetGrad(grads, t, unconnected_gradients):
820  """Gets gradient for tensor "t"."""
821  op = t.op
822  op_grads = grads.get(op)
823  if not op_grads:
824    if unconnected_gradients == UnconnectedGradients.ZERO:
825      return _ZerosLike(t)
826    elif unconnected_gradients == UnconnectedGradients.NONE:
827      return None
828    else:
829      raise ValueError(
830          f"Unknown value for unconnected_gradients: '{unconnected_gradients}'")
831
832  t_grad = op_grads[t.value_index]
833  # This can happen if some other output of `t.op` has non-None grad.
834  if unconnected_gradients == UnconnectedGradients.ZERO and t_grad is None:
835    return _ZerosLike(t)
836
837  assert not isinstance(
838      t_grad, list), ("gradients list should have been aggregated by now.")
839  return t_grad
840
841
842def _GetGrads(grads, op):
843  """Gets all gradients for op."""
844  if op in grads:
845    return grads[op]
846  else:
847    return [[] for _ in range(len(op.outputs))]
848
849
850def _AccumulatorShape(inputs):
851  shape = tensor_shape.unknown_shape()
852  for i in inputs:
853    if isinstance(i, ops.Tensor):
854      shape = shape.merge_with(i.get_shape())
855  return shape
856
857
858def _LogOpGradients(op, out_grads, in_grads):
859  """Log the in and out grads of an op."""
860  logging.vlog(1, "Gradient for '" + op.name + "'")
861
862  def _FilterGrad(x):
863    if x is None:
864      return False
865    if isinstance(x, (list, tuple)):
866      return bool(x)
867    else:
868      return True
869
870  logging.vlog(1, "  in  --> %s",
871               ", ".join(x.name for x in out_grads if _FilterGrad(x)))
872  logging.vlog(1, "  out --> %s",
873               ", ".join(x.name for x in in_grads if _FilterGrad(x)))
874
875
876def _MultiDeviceAddN(tensor_list, gradient_uid):
877  """Adds tensors from potentially multiple devices."""
878  # Basic function structure comes from control_flow_ops.group().
879  # Sort tensors according to their devices.
880  tensors_on_device = collections.defaultdict(lambda: [])
881  for tensor in tensor_list:
882    tensors_on_device[tensor.device].append(tensor)
883
884  # For each device, add the tensors on that device first.
885  # Then gather the partial sums from multiple devices.
886  # TODO(sjhwang): Create hierarchical aggregation tree as pbar's suggestion.
887  # E.g., aggregate per GPU, then per task, and so on.
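  # Illustrative sketch (hypothetical devices): with gradient tensors on
  # "/GPU:0" and "/GPU:1", each device's tensors are summed with add_n first,
  # and the two partial sums are combined by the final add_n below.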
888  summands = []
889
890  def DeviceKey(dev):
891    return "" if dev is None else dev
892
893  for dev in sorted(tensors_on_device, key=DeviceKey):
894    tensors = tensors_on_device[dev]
895    with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
896        tensors[0].op,
897        gradient_uid,
898        ignore_existing=True):
899      summands.append(math_ops.add_n(tensors))
900
901  return math_ops.add_n(summands)
902
903
904@tf_export("AggregationMethod")
905class AggregationMethod:
906  """A class listing aggregation methods used to combine gradients.
907
908  Computing partial derivatives can require aggregating gradient
909  contributions. This class lists the various methods that can
910  be used to combine gradients in the graph.
911
912  The following aggregation methods are part of the stable API for
913  aggregating gradients:
914
915  *  `ADD_N`: All of the gradient terms are summed as part of one
916     operation using the "AddN" op (see `tf.add_n`). This
917     method has the property that all gradients must be ready and
918     buffered separately in memory before any aggregation is performed.
919  *  `DEFAULT`: The system-chosen default aggregation method.
920
921  The following aggregation methods are experimental and may not
922  be supported in future releases:
923
924  * `EXPERIMENTAL_TREE`: Gradient terms are summed in pairs using
925    the "AddN" op. This method of summing gradients may reduce
926    performance, but it can improve memory utilization because the
927    gradients can be released earlier.
928
929  """
930  ADD_N = 0
931  DEFAULT = ADD_N
932  # The following are experimental and may not be supported in future releases.
933  EXPERIMENTAL_TREE = 1
934  EXPERIMENTAL_ACCUMULATE_N = 2  # Handled the same as EXPERIMENTAL_TREE (= 1).
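  # Illustrative usage sketch (hypothetical tensors y and x): the method is
  # passed through tf.gradients, e.g.
  #   tf.gradients(y, x,
  #                aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE)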
935
936
937def _AggregatedGrads(grads,
938                     op,
939                     gradient_uid,
940                     loop_state,
941                     aggregation_method=None):
942  """Get the aggregated gradients for op.
943
944  Args:
945    grads: The map of memoized gradients.
946    op: The op to get gradients for.
947    gradient_uid: A unique identifier within the graph indicating
948      which invocation of gradients is being executed. Used to cluster
949      ops for compilation.
950    loop_state: An object for maintaining the state of the while loops in the
951                graph. It is of type ControlFlowState. None if the graph
952                contains no while loops.
953    aggregation_method: Specifies the method used to combine gradient terms.
954      Accepted values are constants defined in the class `AggregationMethod`.
955
956  Returns:
957    A list of gradients, one for each output of `op`. If the gradient
958      for a particular output is a list, this function aggregates it
959      before returning.
960
961  Raises:
962    TypeError: if the incoming grads are not Tensors or IndexedSlices.
963    ValueError: if the arguments are invalid.
964
965  """
966  if aggregation_method is None:
967    aggregation_method = AggregationMethod.DEFAULT
968  valid_aggregation_methods = [
969      AggregationMethod.ADD_N, AggregationMethod.EXPERIMENTAL_TREE,
970      AggregationMethod.EXPERIMENTAL_ACCUMULATE_N]
971  if aggregation_method not in valid_aggregation_methods:
972    raise ValueError(
973        f"Invalid `aggregation_method` specified {aggregation_method}. "
974        f"Accepted values are {valid_aggregation_methods}.")
975  out_grads = _GetGrads(grads, op)
976  for i, out_grad in enumerate(out_grads):
977    if loop_state:
978      if isinstance(out_grad, (ops.Tensor, indexed_slices.IndexedSlices)):
979        assert control_flow_util.IsLoopSwitch(op)
980        continue
981    # Grads have to be Tensors or IndexedSlices
982    if (isinstance(out_grad, collections_abc.Sequence) and not all(
983        isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices))
984        for g in out_grad
985        if g is not None)):
986      raise TypeError(f"Invalid gradient {out_grad} [index = {i}]. Gradients "
987                      "have to be either all Tensors or all IndexedSlices")
988    # Aggregate multiple gradients, and convert [] to None.
989    if out_grad:
990      if len(out_grad) < 2:
991        used = "nop"
992        out_grads[i] = out_grad[0]
993      elif all(isinstance(g, ops.Tensor) for g in out_grad if g is not None):
994        tensor_shape = _AccumulatorShape(out_grad)
995        if aggregation_method in [
996            AggregationMethod.EXPERIMENTAL_TREE,
997            AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
998        ]:
999          # Aggregate all gradients by doing pairwise sums: this may
1000          # reduce performance, but it can improve memory because the
1001          # gradients can be released earlier.
1002          #
1003          # TODO(vrv): Consider replacing this with a version of
1004          # tf.AddN() that eagerly frees its inputs as soon as they are
1005          # ready, so the order of this tree does not become a problem.
1006          used = "tree"
1007          with ops.name_scope(op.name + "_gradient_sum"):
1008            running_sum = out_grad[0]
1009            for grad in out_grad[1:]:
1010              running_sum = math_ops.add_n([running_sum, grad])
1011            out_grads[i] = running_sum
1012        else:
1013          used = "add_n"
1014          out_grads[i] = _MultiDeviceAddN(out_grad, gradient_uid)
1015        logging.vlog(2, "  _AggregatedGrads %d x %s using %s", len(out_grad),
1016                     tensor_shape, used)
1017      else:
1018        out_grads[i] = backprop.aggregate_indexed_slices_gradients(out_grad)
1019    else:  # not out_grad
1020      # out_grads[i] is [], thus its aggregation is simply None.
1021      out_grads[i] = None
1022  return out_grads
1023
1024
1025# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are
1026# unfortunately too slow to use here.
1027POSSIBLE_GRADIENT_TYPES_NONE = 0
1028POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1
1029POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2
1030
1031
1032def PossibleTapeGradientTypes(tensors):
1033  """Determines whether and how `tensors` may require tape gradients."""
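  # The returned int corresponds to the POSSIBLE_GRADIENT_TYPES_* constants
  # defined above (NONE, FIRST_ORDER, or HIGHER_ORDER).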
1034  return pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(tensors)
1035