xref: /aosp_15_r20/external/tensorflow/tensorflow/python/eager/backprop.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Code for backpropagation using the tape utilities."""
16
17# TODO(b/159343581): Properly support CompositeTensor in all functions in this
18# file.
19
20import functools
21import operator
22
23from tensorflow.python import pywrap_tfe
24from tensorflow.python.eager import backprop_util
25from tensorflow.python.eager import context
26from tensorflow.python.eager import execute
27from tensorflow.python.eager import imperative_grad
28from tensorflow.python.eager import tape
29from tensorflow.python.framework import composite_tensor
30from tensorflow.python.framework import composite_tensor_gradient
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import indexed_slices
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import tensor_shape
36from tensorflow.python.framework import tensor_util
37from tensorflow.python.framework import type_spec
38from tensorflow.python.ops import array_ops
39from tensorflow.python.ops import check_ops
40from tensorflow.python.ops import control_flow_util
41from tensorflow.python.ops import default_gradient
42from tensorflow.python.ops import gen_array_ops
43from tensorflow.python.ops import gen_math_ops
44from tensorflow.python.ops import math_ops
45from tensorflow.python.ops import resource_variable_ops
46from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
47from tensorflow.python.platform import tf_logging as logging
48from tensorflow.python.util import _pywrap_utils
49from tensorflow.python.util import nest
50from tensorflow.python.util import tf_contextlib
51from tensorflow.python.util import tf_inspect
52from tensorflow.python.util import variable_utils
53from tensorflow.python.util.lazy_loader import LazyLoader
54from tensorflow.python.util.tf_export import tf_export
55
56
# Note that we need to lazy load the following two modules to avoid creating
# circular dependencies.
# TODO(b/119775953): fix the circular dependencies.
# `pfor_ops` stands in for tensorflow.python.ops.parallel_for.control_flow_ops.
pfor_ops = LazyLoader(
    "pfor_ops", globals(),
    "tensorflow.python.ops.parallel_for.control_flow_ops")

# `function` stands in for tensorflow.python.eager.function.
function = LazyLoader("function", globals(),
                      "tensorflow.python.eager.function")
66
# Memoizes attr-type lookups, keyed by (op_type, attr_name).
_op_attr_type_cache = {}


def op_attr_type(op_type, attr_name):
  """Returns the attr type of `attr_name` on op `op_type`, caching the result.

  Args:
    op_type: the op type name to query.
    attr_name: the name of the attribute on that op.

  Returns:
    The value reported by `pywrap_tfe.TFE_OpNameGetAttrType` for this pair.
    Repeated calls with the same pair are served from
    `_op_attr_type_cache`.
  """
  cache_key = (op_type, attr_name)
  # Use a membership test rather than `.get()`: a cached value may be falsy.
  if cache_key in _op_attr_type_cache:
    return _op_attr_type_cache[cache_key]

  # Cache miss: the lookup needs a live context handle.
  context.ensure_initialized()
  ctx_handle = context.context()._handle  # pylint: disable=protected-access
  resolved = pywrap_tfe.TFE_OpNameGetAttrType(ctx_handle, op_type, attr_name)
  _op_attr_type_cache[cache_key] = resolved
  return resolved
79
80
def make_attr(attr_type, value):
  """Converts a raw attr `value` into the form implied by `attr_type`.

  Args:
    attr_type: the attr type as an integer, or a one-element list holding that
      integer for list-valued attrs.
    value: the raw attribute value.

  Returns:
    A dtype (or list of dtypes) for type attrs, a shape proto (or list of
    shape protos) for shape attrs, and otherwise `value` with every `str`
    leaf encoded to bytes.
  """
  # pybind11 enums do not return the raw value like SWIG enums do. They are
  # useful when comparing amongst each other but not direct integers as we are
  # doing in most tests.
  # https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types
  # TODO(amitpatankar): After all SWIG transitions, convert the enum comparisons
  # from integer value to class.
  type_enum = int(pywrap_tfe.TF_ATTR_TYPE)
  if attr_type == type_enum:
    return dtypes.as_dtype(value)
  if attr_type == [type_enum]:
    return [dtypes.as_dtype(item) for item in value]

  shape_enum = int(pywrap_tfe.TF_ATTR_SHAPE)
  if attr_type == shape_enum:
    return tensor_shape.as_shape(value).as_proto()
  if attr_type == [shape_enum]:
    return [tensor_shape.as_shape(item).as_proto() for item in value]

  def _encode_if_str(item):
    return item.encode() if isinstance(item, str) else item

  return nest.map_structure(_encode_if_str, value)
99
100
class _MockOp:
  """Pretends to be a tf.Operation for the gradient functions.

  Holds the pieces of an executed op (attrs, inputs, outputs, type name and
  the indices of inputs to skip) so gradient functions can read them through
  the attribute names they expect.
  """

  def __init__(self, attrs, inputs, outputs, typ, skip_input_indices):
    self.type = typ
    self.attrs = attrs
    self.inputs = inputs
    self.outputs = outputs
    self.skip_input_indices = skip_input_indices

  def get_attr(self, attr):
    """Returns the converted value of attribute `attr`.

    `self.attrs` is scanned as a flat sequence of alternating names and
    values.

    Raises:
      KeyError: if no attribute named `attr` is present.
    """
    attr_type = op_attr_type(self.type, attr)
    flat_attrs = self.attrs
    for name_pos in range(0, len(flat_attrs), 2):
      if flat_attrs[name_pos] == attr:
        return make_attr(attr_type, flat_attrs[name_pos + 1])
    raise KeyError(attr)

  def _get_control_flow_context(self):
    """Always raises: graph control flow is not supported through this mock."""
    raise NotImplementedError(
        "tf.GradientTape.gradients() does not support graph control flow "
        "operations like tf.cond or tf.while at this time. Use tf.gradients() "
        "instead. If you need this feature, please file a feature request at "
        "https://github.com/tensorflow/tensorflow/issues/new"
    )
125
126
def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
                       out_grads, skip_input_indices, forward_pass_name_scope):
  """Looks up and invokes the registered gradient function for an op.

  Args:
    op_name: the name of the op to be differentiated.
    attr_tuple: the attrs, as a tuple.
    num_inputs: the number of inputs to the op.
    inputs: inputs to the original operation.
    outputs: outputs to the original operation.
    out_grads: gradients of the operation wrt its outputs.
    skip_input_indices: a tuple that is passed to the gradient function,
      indicating which inputs to skip calculating the gradient for
    forward_pass_name_scope: the namescope of the op in the forward pass.

  Returns:
    The gradients with respect to the inputs of the function, as a list. When
    the registry lookup yields `None`, a list of `num_inputs` `None`s.
  """
  grad_fn = ops._gradient_registry.lookup(op_name)  # pylint: disable=protected-access
  if grad_fn is None:
    return [None] * num_inputs

  mock_op = _MockOp(attr_tuple, inputs, outputs, op_name, skip_input_indices)

  # This does not work with v1 TensorArrays.
  use_gradient_scope = (
      ops.executing_eagerly_outside_functions() or
      control_flow_util.EnableControlFlowV2(ops.get_default_graph()))
  if not use_gradient_scope:
    return grad_fn(mock_op, *out_grads)

  # Nest the backward ops under "gradient_tape/<forward scope>/".
  scope_name = "gradient_tape/"
  if forward_pass_name_scope:
    scope_name = scope_name + forward_pass_name_scope + "/"
  with ops.name_scope(scope_name):
    return grad_fn(mock_op, *out_grads)
160
161
# Hand `_gradient_function` to the pywrap_tfe extension at import time.
# NOTE(review): presumably the native tape calls it back to obtain per-op
# gradients -- confirm against the pywrap_tfe implementation.
pywrap_tfe.TFE_Py_RegisterGradientFunction(_gradient_function)
163
164
def _must_record_gradient():
  """Returns True unless `TFE_Py_TapeSetIsEmpty` reports an empty tape set."""
  tape_set_is_empty = pywrap_tfe.TFE_Py_TapeSetIsEmpty()
  return not tape_set_is_empty
167
168
@tf_export("__internal__.record_gradient", v1=[])
def record_gradient(op_name, inputs, attrs, outputs):
  """Explicitly record the gradient for a given op.

  The current name scope is captured and forwarded together with the op
  description.

  Args:
    op_name: The op name as listed in the `OpDef` for the op.
    inputs: A list of tensor inputs to the op.
    attrs: The op attributes as a flattened list of alternating attribute names
      and attribute values.
    outputs: A list of tensor outputs from the op.
  """
  current_scope = ops.get_name_scope()
  pywrap_tfe.TFE_Py_RecordGradient(
      op_name, inputs, attrs, outputs, current_scope)
182
183
# Install this module's implementations as attributes of the `execute` module.
# NOTE(review): presumably this lets `execute` record gradients without
# importing this module directly -- confirm against execute.py.
execute.must_record_gradient = _must_record_gradient
execute.record_gradient = record_gradient
186
187
def implicit_val_and_grad(f):
  """Returns a function which differentiates f with respect to variables.

  The wrapped function returns the value and the gradient of f when called with
  the same arguments. The gradient is with respect to all trainable TFE
  variables accessed by `f`.

  This function is useful when the exact set of variables to differentiate with
  is not known ahead of time.

  Example:

  ```python
  dense_layer = tf.compat.v1.layers.Dense(1)
  def loss(x, y):
    return tf.reduce_sum(tf.square(dense_layer(x) - y))

  # Obtain the gradient function.
  val_grad_fn = tfe.implicit_value_and_gradients(loss)

  # Invoke the gradient function with concrete values of x and y.
  x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  y = tf.constant([[10.0], [20.0]])
  value, grads_and_vars = val_grad_fn(x, y)
  print('Value of loss: %s' % value)

  # Apply the gradients to Variables.
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  optimizer.apply_gradients(grads_and_vars)
  ```

  Args:
    f: function to be differentiated. If `f` returns a scalar, this scalar will
      be differentiated. If `f` returns a tensor or list of tensors, by default
      a scalar will be computed by adding all their values to produce a single
      scalar.

  Returns:
    A function which, when called, returns a tuple pair.
    Its first element is the value to which the function evaluates.
    Its second element is a list of (gradient, variable) pairs.

  Raises:
    ValueError: if `f` returns None, if no variables were watched while `f`
      ran, or if a variable handle is a packed EagerTensor.
  """
  # TODO(cais): Remove calls to tf.constant() once the gradients functions
  # accept lists and np.ndarrays.

  def grad_fn(*args, **kwds):
    """Computes the gradient of the wrapped function."""
    this_tape = tape.push_new_tape()
    try:
      end_node = f(*args, **kwds)
      if end_node is None:
        # Callable objects (e.g. functools.partial) may have no `__name__`;
        # fall back to repr so the ValueError is not masked by an
        # AttributeError.
        fn_name = getattr(f, "__name__", None)
        if fn_name is None:
          fn_name = repr(f)
        raise ValueError("Cannot differentiate a function that returns None; "
                         "did you forget to return a value from {}?".format(
                             fn_name))
    finally:
      # Always pop the tape, even when `f` raises.
      tape.pop_tape(this_tape)
    # Note: variables are returned in construction order. This ensures unique
    # order across executions.
    variables = this_tape.watched_variables()
    if not variables:
      raise ValueError("No trainable variables were accessed while the "
                       "function was being computed.")

    sources = [v.handle for v in variables]
    for s in sources:
      if getattr(s, "is_packed", False):
        raise ValueError(
            "GradientTape.gradient is not supported on packed EagerTensors yet."
        )
    grad = imperative_grad.imperative_grad(this_tape, nest.flatten(end_node),
                                           sources)
    return end_node, list(zip(grad, variables))

  return grad_fn
265
266
def implicit_grad(f):
  """Returns a function which differentiates f with respect to variables.

  The wrapped function returns the gradient of f when called with the same
  arguments. The gradient is with respect to all trainable TFE variables
  accessed by `f`. It is the second element of what
  `implicit_val_and_grad(f)` returns.

  This function is useful when the exact set of variables to differentiate with
  is not known ahead of time.

  Example:

  ```python
  dense_layer = tf.compat.v1.layers.Dense(1)
  def loss(x, y):
    return tf.reduce_sum(tf.square(dense_layer(x) - y))

  # Obtain the gradient function.
  grad_fn = tfe.implicit_gradients(loss)

  # Invoke the gradient function with concrete values of x and y.
  x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  y = tf.constant([[10.0], [20.0]])
  grads_and_vars = grad_fn(x, y)

  # Apply the gradients to Variables.
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  optimizer.apply_gradients(grads_and_vars)
  ```

  Args:
    f: function to be differentiated. If `f` returns a scalar, this scalar will
      be differentiated. If `f` returns a tensor or list of tensors, by default
      a scalar will be computed by adding all their values to produce a single
      scalar.

  Returns:
    A function which, when called, returns a list of (gradient, variable) pairs.
  """
  # TODO(cais): Remove calls to tf.constant() once the gradients functions
  # accept lists and np.ndarrays.

  def grad_fn(*args, **kwds):
    """Computes the gradient of the wrapped function."""
    val_and_grad = implicit_val_and_grad(f)
    _, grads_and_vars = val_and_grad(*args, **kwds)
    return grads_and_vars

  return grad_fn
314
315
def _get_arg_spec(f, params, param_args):
  """The positions of the parameters of f to be differentiated in param_args.

  Args:
    f: the callable whose parameters are being selected.
    params: None, a sequence of parameter names, or a sequence of integer
      positions.
    param_args: the positional arguments the callable will be invoked with;
      only its length is used, and only when argument names are unavailable
      or empty.

  Returns:
    A `range` covering all parameters when `params` is None (excluding a
    leading "self"), a list of positions when `params` holds names, or
    `params` itself when it already holds integers.

  Raises:
    ValueError: if `params` mixes strings and integers, or if names were given
      but the arguments of `f` could not be inspected.
  """
  try:
    args = tf_inspect.getfullargspec(f).args
  except TypeError as e:
    # TypeError can happen when f is a callable object.
    if params is None:
      return range(len(param_args))
    elif all(isinstance(x, int) for x in params):
      return params
    raise ValueError("Either callable provided is not a function or could not "
                     "inspect its arguments by name: %s. Original error: %s"
                     % (f, e))
  if params is None:
    if not args:
      return range(len(param_args))
    if args[0] == "self":
      return range(len(args) - 1)
    else:
      return range(len(args))
  elif all(isinstance(x, str) for x in params):
    return [args.index(n) for n in params]
  elif all(isinstance(x, int) for x in params):
    return params
  else:
    # Wrap `params` in a 1-tuple: a bare tuple operand would be unpacked by
    # `%` and raise TypeError instead of the intended ValueError.
    raise ValueError(
        "params must be all strings or all integers; got %s." % (params,))
343
344
def gradients_function(f, params=None):
  """Returns a function which differentiates f with respect to params.

  Example:
  ```python
  # f(x, y) = (x ^ 3) * y - x * (y ^ 2)
  # Therefore, the 1st order derivatives are:
  #   df / dx = 3 * (x ^ 2) * y - y ^ 2
  #   df / dy = x ^ 3 - 2 * x * y
  # The 2nd order derivatives with respect to x is:
  #   d^2 f / (dx)^2 = 6 * x * y
  def f(x, y):
    return x * x * x * y - x * y * y

  # Obtain a function that returns 1st order gradients.
  grad_fn = tfe.gradients_function(f)

  x = 2.0
  y = 3.0

  # Invoke the 1st order gradient function.
  x_grad, y_grad = grad_fn(x, y)
  assert x_grad.numpy() == 3 * (2 ** 2) * 3 - 3 ** 2
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3

  # Obtain a function that returns the 2nd order gradient with respect to x.
  gradgrad_fn = tfe.gradients_function(lambda x, y: grad_fn(x, y)[0])

  # Invoke the 2nd order gradient function.
  x_gradgrad = gradgrad_fn(x, y)[0]
  assert x_gradgrad.numpy() == 6 * 2 * 3

  # To obtain a callable that returns the gradient(s) of `f` with respect to a
  # subset of its inputs, use the `params` keyword argument with
  # `gradients_function()`.
  ygrad_fn = tfe.gradients_function(f, params=[1])

  (y_grad,) = ygrad_fn(x, y)
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
  ```

  Note that only tensors with real or complex dtypes are differentiable.

  Args:
    f: function to be differentiated. If `f` returns a scalar, this scalar will
      be differentiated. If `f` returns a tensor or list of tensors, by default
      a scalar will be computed by adding all their values to produce a single
      scalar. If desired, the tensors can be elementwise multiplied by the
      tensors passed as the `dy` keyword argument to the returned gradient
      function.
    params: list of parameter names of f or list of integers indexing the
      parameters with respect to which we'll differentiate. Passing None
      differentiates with respect to all parameters.

  Returns:
    function which, when called, returns the gradient of `f` with respect to
    all of `params` (the value of `f` is computed but discarded). The function
    takes an extra optional keyword argument `dy`. Setting it allows
    computation of vector jacobian products for vectors other than the vector
    of ones.

  Raises:
    ValueError: if the params are not all strings or all integers.
  """

  def decorated(*args, **kwds):
    """Computes the gradient of the decorated function."""
    value_and_grad = val_and_grad_function(f, params=params)
    # Index 0 holds the function value, index 1 the gradients.
    return value_and_grad(*args, **kwds)[1]

  return decorated
416
417
def _ensure_unique_tensor_objects(parameter_positions, args):
  """Make each of the parameter_positions in args a unique ops.Tensor object.

  Ensure that each parameter is treated independently.
  For example:

  def f(x, y): return x * y
  g = gradients_function(f)
  one = tf.constant(1.)

  g(one, one) should return [1., 1.]
  (even though the two arguments are the same Tensor object).

  Args:
    parameter_positions: List of indices into args defining the arguments to
      differentiate against.
    args: A list of arguments to the function to be differentiated.

  Returns:
    args, possibly edited in-place.
  """
  seen_ids = set()
  for position, candidate in enumerate(args):
    if position not in parameter_positions:
      continue
    candidate_id = ops.tensor_id(candidate)
    if candidate_id not in seen_ids:
      seen_ids.add(candidate_id)
      continue
    # Same tensor seen at an earlier position: replace it with an identity op
    # so this position gets its own tensor object.
    args[position] = gen_array_ops.identity(args[position])
  return args
448
449
def val_and_grad_function(f, params=None):
  """Returns a function that computes f and its derivative w.r.t. params.

  Example:
  ```python
  # f(x, y) = (x ^ 3) * y - x * (y ^ 2)
  # Therefore, the 1st order derivatives are:
  #   df / dx = 3 * (x ^ 2) * y - y ^ 2
  #   df / dy = x ^ 3 - 2 * x * y
  def f(x, y):
    return x * x * x * y - x * y * y

  # Obtain a function that returns the function value and the 1st order
  # gradients.
  val_grads_fn = tfe.value_and_gradients_function(f)

  x = 2.0
  y = 3.0

  # Invoke the value-and-gradients function.
  f_val, (x_grad, y_grad) = val_grads_fn(x, y)
  assert f_val.numpy() == (2 ** 3) * 3 - 2 * (3 ** 2)
  assert x_grad.numpy() == 3 * (2 ** 2) * 3 - 3 ** 2
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3

  # To obtain a callable that returns the value of `f` and the gradient(s) of
  # `f` with respect to a subset of its inputs, use the `params` keyword
  # argument with `value_and_gradients_function()`.
  val_ygrad_fn = tfe.value_and_gradients_function(f, params=[1])

  f_val, (y_grad,) = val_ygrad_fn(x, y)
  assert f_val.numpy() == (2 ** 3) * 3 - 2 * (3 ** 2)
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
  ```

  Args:
    f: function to be differentiated. If `f` returns a scalar, this scalar will
      be differentiated. If `f` returns a tensor or list of tensors, by default
      a scalar will be computed by adding all their values to produce a single
      scalar. If desired, the tensors can be elementwise multiplied by the
      tensors passed as the `dy` keyword argument to the returned gradient
      function.
    params: list of parameter names of f or list of integers indexing the
      parameters with respect to which we'll differentiate. Passing `None`
      differentiates with respect to all parameters.

  Returns:
    function which, when called, returns the value of f and the gradient
    of f with respect to all of `params`. The function takes an extra optional
    keyword argument "dy". Setting it allows computation of vector jacobian
    products for vectors other than the vector of ones.

  Raises:
    ValueError: if the params are not all strings or all integers, or if the
      returned function is called with any keyword argument other than `dy`.
  """

  def decorated(*args, **kwds):
    """Computes the value and gradient of the decorated function."""
    output_grads = kwds.pop("dy", None)
    # Anything left once `dy` is removed is an unsupported keyword argument.
    if kwds:
      raise ValueError("Functions to be differentiated cannot "
                       "receive keyword arguments.")
    # `kwds` is empty at this point, so only positionals are forwarded.
    value, vjp_fn = make_vjp(f, params)(*args)
    gradients = vjp_fn(dy=output_grads)
    return value, gradients

  return decorated
516
517
def make_vjp(f, params=None, persistent=True):
  """Returns a function that computes f and its vjp w.r.t. params.

  The term "vjp" here is an abbreviation for vector-jacobian product.

  Args:
    f: the function to be differentiated.
    params: the parameters (numbers or names) to differentiate with respect to.
      A value of None will differentiate with respect to all parameters.
    persistent: Boolean controlling whether the VJP function can be re-used.
      Must be True or False.

  Returns:
    A function, which when called, returns a tuple (value, vjp), where:
    - value is the result of calling f.
    - vjp is a function, which takes a vector as an argument and
      returns the product of that vector with the Jacobian of f.
      Providing no argument to vjp is equivalent to providing a
      vector of ones.

    For example,
    ```python
    def f(x):
      return x * x

    wrapped_fn = tfe.make_vjp(f)
    result, vjp = wrapped_fn(tf.constant(3.0))
    # result is 9.0
    vjp()  # the vjp function returns 6.0
    ```

  Raises:
    ValueError: if `f` returns None, or if a differentiated argument is a
      packed EagerTensor.
  """

  def decorated(*args, **kwds):
    """Computes the value and gradient of the decorated function."""
    parameter_positions = _get_arg_spec(f, params, args)
    assert not kwds, "The gradient function can't take keyword arguments."
    this_tape = tape.push_new_tape(persistent=persistent)
    try:
      sources = []
      # Only the differentiated positions are converted to tensors; other
      # arguments are passed to `f` untouched.
      args = [
          ops.convert_to_tensor(arg) if i in parameter_positions else arg
          for i, arg in enumerate(args)
      ]
      args = _ensure_unique_tensor_objects(parameter_positions, args)
      for i in parameter_positions:
        if getattr(args[i], "is_packed", False):
          raise ValueError(
              "GradientTape.gradient is not supported on packed EagerTensors "
              "yet.")
        sources.append(args[i])
        tape.watch(this_tape, args[i])
      result = f(*args)
      if result is None:
        # Callable objects (e.g. functools.partial) may have no `__name__`;
        # fall back to repr so the ValueError is not masked by an
        # AttributeError.
        fn_name = getattr(f, "__name__", None)
        if fn_name is None:
          fn_name = repr(f)
        raise ValueError("Cannot differentiate a function that returns None; "
                         "did you forget to return a value from {}?".format(
                             fn_name))
      # Pass every output through an identity op while the tape is still
      # active, then rebuild the original nested structure.
      flat_result = nest.flatten(result)
      flat_result = [gen_array_ops.identity(x) for x in flat_result]
      result = nest.pack_sequence_as(result, flat_result)
    finally:
      # Always pop the tape, even when `f` raises.
      tape.pop_tape(this_tape)
    def vjp(dy=None):
      """Returns the gradients of `result` w.r.t. `sources`, weighted by dy."""
      if dy is not None:
        dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
      return imperative_grad.imperative_grad(
          this_tape, nest.flatten(result), sources, output_gradients=dy)

    return result, vjp

  return decorated
592
593
def flatten_nested_indexed_slices(grad):
  """Collapses an `IndexedSlices` whose values are `IndexedSlices` themselves.

  Args:
    grad: an `IndexedSlices`; its `values` must be a `Tensor` or another
      `IndexedSlices`.

  Returns:
    `grad` unchanged when its values are already a `Tensor`; otherwise a new
    single-level `IndexedSlices` whose indices are `grad.indices` gathered at
    the flattened inner indices.
  """
  assert isinstance(grad, indexed_slices.IndexedSlices)
  inner = grad.values
  if isinstance(inner, ops.Tensor):
    return grad

  assert isinstance(inner, indexed_slices.IndexedSlices)
  flat_inner = flatten_nested_indexed_slices(inner)
  outer_indices = array_ops.gather(grad.indices, flat_inner.indices)
  return indexed_slices.IndexedSlices(
      flat_inner.values, outer_indices, flat_inner.dense_shape)
603
604
def aggregate_indexed_slices_gradients(grads):
  """Aggregates gradients containing `IndexedSlices`s.

  Args:
    grads: a sequence of gradients; entries may be `Tensor`, `IndexedSlices`
      or None.

  Returns:
    None when `grads` is empty or contains only None entries; the single entry
    when `grads` has exactly one element; a dense sum when any remaining entry
    is a `Tensor`; otherwise an `IndexedSlices` built by concatenating the
    values and indices of all remaining entries.
  """
  if len(grads) < 1:
    return None
  if len(grads) == 1:
    return grads[0]
  grads = [g for g in grads if g is not None]
  # Every entry was None: there is nothing to aggregate, and falling through
  # would fail on `grads[0]` below.
  if not grads:
    return None
  # If any gradient is a `Tensor`, sum them up and return a dense tensor
  # object.
  if any(isinstance(g, ops.Tensor) for g in grads):
    return math_ops.add_n(grads)

  # The following `_as_indexed_slices_list` casts ids of IndexedSlices into
  # int64. It is to make sure the inputs of `concat` all have same the data
  # type.
  grads = math_ops._as_indexed_slices_list(grads)  # pylint: disable=protected-access

  grads = [flatten_nested_indexed_slices(x) for x in grads]
  # Form IndexedSlices out of the concatenated values and indices.
  concat_grad = indexed_slices.IndexedSlices(
      array_ops.concat([x.values for x in grads], axis=0),
      array_ops.concat([x.indices for x in grads], axis=0),
      grads[0].dense_shape)

  return concat_grad
630
631
def _aggregate_grads(gradients):
  """Combines several gradients for the same target into one.

  Args:
    gradients: A non-empty list of 'Tensor' or 'IndexedSlices' gradients.

  Returns:
    The sole element when there is exactly one gradient. If 'gradients' only
    has 'Tensor', returns an aggregated 'Tensor'. Otherwise returns the result
    of `aggregate_indexed_slices_gradients`.
  """
  assert gradients, "No gradients to aggregate"

  # A single gradient needs no aggregation.
  if len(gradients) == 1:
    return gradients[0]

  all_dense = all(isinstance(grad, ops.Tensor) for grad in gradients)
  if all_dense:
    return gen_math_ops.add_n(gradients)

  supported_types = (ops.Tensor, indexed_slices.IndexedSlices)
  assert all(isinstance(grad, supported_types) for grad in gradients)
  return aggregate_indexed_slices_gradients(gradients)
653
654
def _num_elements(grad):
  """The number of elements in the `grad` tensor.

  For `IndexedSlices` the shape of `grad.values` is used. Returns 0 when the
  shape is unavailable or has any unknown dimension.

  Raises:
    ValueError: if `grad` is neither a `Tensor` nor an `IndexedSlices`.
  """
  if isinstance(grad, ops.Tensor):
    dims = grad._shape_tuple()  # pylint: disable=protected-access
  elif isinstance(grad, indexed_slices.IndexedSlices):
    dims = grad.values._shape_tuple()  # pylint: disable=protected-access
  else:
    raise ValueError("`grad` not a Tensor or IndexedSlices.")

  fully_known = dims is not None and None not in dims
  if not fully_known:
    return 0

  # Product of all dimensions; an empty shape tuple yields 1.
  total = 1
  for dim in dims:
    total = total * dim
  return total
666
667
def _fast_fill(value, shape, dtype):
  """Returns `array_ops.fill` of `value` over `shape`, with `value` as dtype.

  The shape is passed to `fill` as an int32 constant.
  """
  dims = constant_op.constant(shape, dtype=dtypes.int32)
  fill_value = constant_op.constant(value, dtype=dtype)
  return array_ops.fill(dims, fill_value)
672
673
def _zeros(shape, dtype):
  """Helper to return (possibly cached) zero tensors in eager mode.

  Returns None for string and resource dtypes. Outside eager execution this
  defers to `array_ops.zeros`; in eager mode results are stored in the
  context's zeros cache under (shape, dtype, device).
  """
  # Note: variants will use _zeros_like
  if dtype == dtypes.string or dtype == dtypes.resource:
    return None

  ctx = context.context()
  if not ctx.executing_eagerly():
    return array_ops.zeros(shape, dtype)

  device = ctx.device_name

  # A tensor-valued shape is keyed by its `ref()` rather than by the tensor.
  shape_key = shape.ref() if tensor_util.is_tf_type(shape) else shape
  cache_key = (shape_key, dtype, device)

  hit = ctx.zeros_cache().get(cache_key)
  if hit is not None:
    return hit

  zero = False if dtypes.as_dtype(dtype).is_bool else 0
  fresh = _fast_fill(zero, shape, dtype)
  ctx.zeros_cache().put(cache_key, fresh)
  return fresh
700
701
def _ones(shape, dtype):
  """Helper to return a tensor of ones (True for bool) of `shape` and `dtype`.

  Returns None for the string dtype. Outside eager execution this defers to
  `array_ops.ones`; in eager mode a scalar shape yields a plain constant and
  any other shape goes through `_fast_fill`.
  """
  resolved = dtypes.as_dtype(dtype)
  if resolved == dtypes.string:
    return None

  if not context.executing_eagerly():
    return array_ops.ones(shape, dtype)

  one = True if resolved.is_bool else 1

  if shape == ():  # pylint: disable=g-explicit-bool-comparison
    return constant_op.constant(one, dtype=dtype)
  return _fast_fill(one, shape, dtype)
718
719
# Bundle this module's gradient helpers (element counting, aggregation,
# zeros/ones construction) into a VSpace and register it with the pywrap_tfe
# extension at import time.
_default_vspace = imperative_grad.VSpace(
    num_elements_fn=_num_elements,
    aggregate_fn=_aggregate_grads,
    zeros_fn=_zeros,
    ones_fn=_ones,
    zeros_like_fn=default_gradient.zeros_like,
    ones_like_fn=default_gradient.ones_like,
    graph_shape_fn=gen_array_ops.shape)
pywrap_tfe.TFE_Py_RegisterVSpace(_default_vspace)
729
730
def _handle_or_self(x):
  """Returns `x.handle` for a resource variable, and `x` itself otherwise."""
  is_variable = resource_variable_ops.is_resource_variable(x)
  return x.handle if is_variable else x
736
737
def _extract_tensors_and_variables(tensor):
  """Yields the tensors and variables found in a (possibly nested) input.

  The input is flattened with `nest.flatten`. Composite tensors are broken
  into their components via their type spec and searched recursively.

  Raises:
    ValueError: if a flattened element is not a tensor, a variable or a
      `CompositeTensor`.
  """
  for obj in nest.flatten(tensor):
    is_leaf = _pywrap_utils.IsTensor(obj) or _pywrap_utils.IsVariable(obj)
    if is_leaf:
      yield obj
      continue
    if not isinstance(obj, composite_tensor.CompositeTensor):
      raise ValueError(f"Passed in object {obj} of type {type(obj).__name__!r}"
                       f", not tf.Tensor or tf.Variable or ExtensionType.")
    spec = type_spec.type_spec_from_value(obj)
    components = spec._to_components(obj)  # pylint: disable=protected-access
    yield from _extract_tensors_and_variables(components)
749
750
751@tf_export("GradientTape", "autodiff.GradientTape", v1=["GradientTape"])
752class GradientTape:
753  """Record operations for automatic differentiation.
754
755  Operations are recorded if they are executed within this context manager and
756  at least one of their inputs is being "watched".
757
758  Trainable variables (created by `tf.Variable` or `tf.compat.v1.get_variable`,
759  where `trainable=True` is default in both cases) are automatically watched.
760  Tensors can be manually watched by invoking the `watch` method on this context
761  manager.
762
763  For example, consider the function `y = x * x`. The gradient at `x = 3.0` can
764  be computed as:
765
766  >>> x = tf.constant(3.0)
767  >>> with tf.GradientTape() as g:
768  ...   g.watch(x)
769  ...   y = x * x
770  >>> dy_dx = g.gradient(y, x)
771  >>> print(dy_dx)
772  tf.Tensor(6.0, shape=(), dtype=float32)
773
774  GradientTapes can be nested to compute higher-order derivatives. For example,
775
776  >>> x = tf.constant(5.0)
777  >>> with tf.GradientTape() as g:
778  ...   g.watch(x)
779  ...   with tf.GradientTape() as gg:
780  ...     gg.watch(x)
781  ...     y = x * x
782  ...   dy_dx = gg.gradient(y, x)  # dy_dx = 2 * x
783  >>> d2y_dx2 = g.gradient(dy_dx, x)  # d2y_dx2 = 2
784  >>> print(dy_dx)
785  tf.Tensor(10.0, shape=(), dtype=float32)
786  >>> print(d2y_dx2)
787  tf.Tensor(2.0, shape=(), dtype=float32)
788
789  By default, the resources held by a GradientTape are released as soon as
790  GradientTape.gradient() method is called. To compute multiple gradients over
791  the same computation, create a persistent gradient tape. This allows multiple
792  calls to the gradient() method as resources are released when the tape object
793  is garbage collected. For example:
794
795  >>> x = tf.constant(3.0)
796  >>> with tf.GradientTape(persistent=True) as g:
797  ...   g.watch(x)
798  ...   y = x * x
799  ...   z = y * y
800  >>> dz_dx = g.gradient(z, x)  # (4*x^3 at x = 3)
801  >>> print(dz_dx)
802  tf.Tensor(108.0, shape=(), dtype=float32)
803  >>> dy_dx = g.gradient(y, x)
804  >>> print(dy_dx)
805  tf.Tensor(6.0, shape=(), dtype=float32)
806
807  By default GradientTape will automatically watch any trainable variables that
808  are accessed inside the context. If you want fine grained control over which
809  variables are watched you can disable automatic tracking by passing
810  `watch_accessed_variables=False` to the tape constructor:
811
812  >>> x = tf.Variable(2.0)
813  >>> w = tf.Variable(5.0)
814  >>> with tf.GradientTape(
815  ...     watch_accessed_variables=False, persistent=True) as tape:
816  ...   tape.watch(x)
817  ...   y = x ** 2  # Gradients will be available for `x`.
818  ...   z = w ** 3  # No gradients will be available as `w` isn't being watched.
819  >>> dy_dx = tape.gradient(y, x)
820  >>> print(dy_dx)
821  tf.Tensor(4.0, shape=(), dtype=float32)
822  >>> # No gradients will be available as `w` isn't being watched.
823  >>> dz_dw = tape.gradient(z, w)
824  >>> print(dz_dw)
825  None
826
827  Note that when using models you should ensure that your variables exist when
828  using `watch_accessed_variables=False`. Otherwise it's quite easy to make your
829  first iteration not have any gradients:
830
831  ```python
832  a = tf.keras.layers.Dense(32)
833  b = tf.keras.layers.Dense(32)
834
835  with tf.GradientTape(watch_accessed_variables=False) as tape:
836    tape.watch(a.variables)  # Since `a.build` has not been called at this point
837                             # `a.variables` will return an empty list and the
838                             # tape will not be watching anything.
839    result = b(a(inputs))
840    tape.gradient(result, a.variables)  # The result of this computation will be
841                                        # a list of `None`s since a's variables
842                                        # are not being watched.
843  ```
844
845  Note that only tensors with real or complex dtypes are differentiable.
846  """
847
848  def __init__(self, persistent=False, watch_accessed_variables=True):
849    """Creates a new GradientTape.
850
851    Args:
852      persistent: Boolean controlling whether a persistent gradient tape
853        is created. False by default, which means at most one call can
854        be made to the gradient() method on this object.
855      watch_accessed_variables: Boolean controlling whether the tape will
856        automatically `watch` any (trainable) variables accessed while the tape
857        is active. Defaults to True meaning gradients can be requested from any
858        result computed in the tape derived from reading a trainable `Variable`.
859        If False users must explicitly `watch` any `Variable`s they want to
860        request gradients from.
861    """
862    self._tape = None
863    self._persistent = persistent
864    self._watch_accessed_variables = watch_accessed_variables
865    self._watched_variables = ()
866    self._recording = False
867
868  def __enter__(self):
869    """Enters a context inside which operations are recorded on this tape."""
870    self._push_tape()
871    return self
872
873  def __exit__(self, typ, value, traceback):
874    """Exits the recording context, no further operations are traced."""
875    if self._recording:
876      self._pop_tape()
877
878  def _push_tape(self):
879    """Pushes a new tape onto the tape stack."""
880    if self._recording:
881      raise ValueError("Tape is still recording, This can happen if you try to "
882                       "re-enter an already-active tape.")
883    if self._tape is None:
884      self._tape = tape.push_new_tape(
885          persistent=self._persistent,
886          watch_accessed_variables=self._watch_accessed_variables)
887    else:
888      tape.push_tape(self._tape)
889    self._recording = True
890
891  def _pop_tape(self):
892    if not self._recording:
893      raise ValueError("Tape is not recording.")
894    tape.pop_tape(self._tape)
895    self._recording = False
896
897  @tf_contextlib.contextmanager
898  def _ensure_recording(self):
899    """Ensures that this tape is recording."""
900    if not self._recording:
901      try:
902        self._push_tape()
903        yield
904      finally:
905        self._pop_tape()
906    else:
907      yield
908
909  # TODO(b/209081027): Add a variable in composite tensor test case after
910  # variables become composite tensors.
911  def watch(self, tensor):
912    """Ensures that `tensor` is being traced by this tape.
913
914    Args:
915      tensor: a Tensor/Variable or list of Tensors/Variables.
916
917    Raises:
918      ValueError: if it encounters something that is not a tensor.
919    """
920    for t in _extract_tensors_and_variables(tensor):
921      if not backprop_util.IsTrainable(t):
922        logging.log_first_n(
923            logging.WARN, "The dtype of the watched tensor must be "
924            "floating (e.g. tf.float32), got %r", 5, t.dtype)
925      if hasattr(t, "handle"):
926        # There are many variable-like objects, all of them currently have
927        # `handle` attribute that points to a tensor. If this changes,
928        # internals of watch_variable need to change as well.
929        tape.watch_variable(self._tape, t)
930      else:
931        tape.watch(self._tape, t)
932
933  @tf_contextlib.contextmanager
934  def stop_recording(self):
935    """Temporarily stops recording operations on this tape.
936
937    Operations executed while this context manager is active will not be
938    recorded on the tape. This is useful for reducing the memory used by tracing
939    all computations.
940
941    For example:
942
943    >>> x = tf.constant(4.0)
944    >>> with tf.GradientTape() as tape:
945    ...   with tape.stop_recording():
946    ...     y = x ** 2
947    >>> dy_dx = tape.gradient(y, x)
948    >>> print(dy_dx)
949    None
950
951    Yields:
952      None
953    Raises:
954      RuntimeError: if the tape is not currently recording.
955    """
956    if self._tape is None:
957      raise RuntimeError(
958          "Trying to stop recording a tape which is not recording.")
959    self._pop_tape()
960    try:
961      yield
962    finally:
963      self._push_tape()
964
965  def reset(self):
966    """Clears all information stored in this tape.
967
968    Equivalent to exiting and reentering the tape context manager with a new
969    tape. For example, the two following code blocks are equivalent:
970
971    ```
972    with tf.GradientTape() as t:
973      loss = loss_fn()
974    with tf.GradientTape() as t:
975      loss += other_loss_fn()
976    t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn
977
978
979    # The following is equivalent to the above
980    with tf.GradientTape() as t:
981      loss = loss_fn()
982      t.reset()
983      loss += other_loss_fn()
984    t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn
985    ```
986
987    This is useful if you don't want to exit the context manager for the tape,
988    or can't because the desired reset point is inside a control flow construct:
989
990    ```
991    with tf.GradientTape() as t:
992      loss = ...
993      if loss > k:
994        t.reset()
995    ```
996    """
997    self._pop_tape()
998    self._tape = None
999    self._push_tape()
1000
1001  def watched_variables(self):
1002    """Returns variables watched by this tape in order of construction."""
1003    if self._tape is not None:
1004      self._watched_variables = self._tape.watched_variables()
1005    return self._watched_variables
1006
1007  def gradient(self,
1008               target,
1009               sources,
1010               output_gradients=None,
1011               unconnected_gradients=UnconnectedGradients.NONE):
1012    """Computes the gradient using operations recorded in context of this tape.
1013
1014    Note: Unless you set `persistent=True` a GradientTape can only be used to
1015    compute one set of gradients (or jacobians).
1016
1017    In addition to Tensors, gradient also supports RaggedTensors. For example,
1018
1019    >>> x = tf.ragged.constant([[1.0, 2.0], [3.0]])
1020    >>> with tf.GradientTape() as g:
1021    ...   g.watch(x)
1022    ...   y = x * x
1023    >>> g.gradient(y, x)
1024    <tf.RaggedTensor [[2.0, 4.0], [6.0]]>
1025
1026    Args:
1027      target: a list or nested structure of Tensors or Variables or
1028        CompositeTensors to be differentiated.
1029      sources: a list or nested structure of Tensors or Variables or
1030        CompositeTensors. `target` will be differentiated against elements in
1031        `sources`.
1032      output_gradients: a list of gradients, one for each differentiable
1033        element of target. Defaults to None.
1034      unconnected_gradients: a value which can either hold 'none' or 'zero' and
1035        alters the value which will be returned if the target and sources are
1036        unconnected. The possible values and effects are detailed in
1037        'UnconnectedGradients' and it defaults to 'none'.
1038
1039    Returns:
1040      a list or nested structure of Tensors (or IndexedSlices, or None, or
1041      CompositeTensor), one for each element in `sources`. Returned structure
1042      is the same as the structure of `sources`.
1043
1044    Raises:
1045      RuntimeError: If called on a used, non-persistent tape.
1046      RuntimeError: If called inside the context of the tape.
1047      TypeError: If the target is a None object.
1048      ValueError: If the target is a variable or if unconnected gradients is
1049       called with an unknown value.
1050    """
1051    if self._tape is None:
1052      raise RuntimeError("A non-persistent GradientTape can only be used to "
1053                         "compute one set of gradients (or jacobians)")
1054    if self._recording:
1055      if not self._persistent:
1056        self._pop_tape()
1057      else:
1058        logging.log_first_n(
1059            logging.WARN, "Calling GradientTape.gradient on a persistent "
1060            "tape inside its context is significantly less "
1061            "efficient than calling it outside the context (it "
1062            "causes the gradient ops to be recorded on the "
1063            "tape, leading to increased CPU and memory usage). "
1064            "Only call GradientTape.gradient inside the "
1065            "context if you actually want to trace the "
1066            "gradient in order to compute higher order "
1067            "derivatives.", 1)
1068
1069    if target is None:
1070      raise TypeError("Argument `target` should be a list or nested structure"
1071                      " of Tensors, Variables or CompositeTensors to be "
1072                      "differentiated, but received None.")
1073
1074    flat_targets = []
1075    for t in nest.flatten(target):
1076      flat_targets.append(_handle_or_self(t))
1077    flat_targets = composite_tensor_gradient.get_flat_tensors_for_gradients(
1078        flat_targets)
1079    for t in flat_targets:
1080      if not backprop_util.IsTrainable(t):
1081        logging.vlog(
1082            1, "The dtype of the target tensor must be "
1083            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
1084            "got %r", t.dtype)
1085
1086    flat_sources_raw = nest.flatten(sources)
1087    flat_sources = []
1088    for t in flat_sources_raw:
1089      flat_sources.append(_handle_or_self(t))
1090    flat_sources = composite_tensor_gradient.get_flat_tensors_for_gradients(
1091        flat_sources)
1092    for t in flat_sources:
1093      if not backprop_util.IsTrainable(t):
1094        logging.vlog(
1095            1, "The dtype of the source tensor must be "
1096            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
1097            "got %r", t.dtype)
1098      if getattr(t, "is_packed", False):
1099        raise ValueError(
1100            "GradientTape.gradient is not supported on packed EagerTensors yet."
1101        )
1102
1103    if output_gradients is not None:
1104      output_gradients = nest.flatten(
1105          variable_utils.convert_variables_to_tensors(output_gradients))
1106      output_gradients = (
1107          composite_tensor_gradient.get_flat_tensors_for_gradients(
1108              output_gradients))
1109      output_gradients = [None if x is None else ops.convert_to_tensor(x)
1110                          for x in output_gradients]
1111
1112    flat_grad = imperative_grad.imperative_grad(
1113        self._tape,
1114        flat_targets,
1115        flat_sources,
1116        output_gradients=output_gradients,
1117        sources_raw=flat_sources_raw,
1118        unconnected_gradients=unconnected_gradients)
1119
1120    if not self._persistent:
1121      # Keep track of watched variables before setting tape to None
1122      self._watched_variables = self._tape.watched_variables()
1123      self._tape = None
1124
1125    flat_sources_raw = nest.map_structure(_handle_or_self, flat_sources_raw)
1126    flat_grad = composite_tensor_gradient.replace_flat_tensors_for_gradients(
1127        flat_sources_raw, flat_grad)
1128    grad = nest.pack_sequence_as(sources, flat_grad)
1129    return grad
1130
1131  def jacobian(self,
1132               target,
1133               sources,
1134               unconnected_gradients=UnconnectedGradients.NONE,
1135               parallel_iterations=None,
1136               experimental_use_pfor=True):
1137    """Computes the jacobian using operations recorded in context of this tape.
1138
1139    Note: Unless you set `persistent=True` a GradientTape can only be used to
1140    compute one set of gradients (or jacobians).
1141
1142    Note: By default the jacobian implementation uses parallel for (pfor), which
1143    creates a tf.function under the hood for each jacobian call. For better
1144    performance, and to avoid recompilation and vectorization rewrites on each
1145    call, enclose GradientTape code in @tf.function.
1146
1147    See[wikipedia
1148    article](http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant)
1149    for the definition of a Jacobian.
1150
1151    Example usage:
1152
1153    ```python
1154    with tf.GradientTape() as g:
1155      x  = tf.constant([1.0, 2.0])
1156      g.watch(x)
1157      y = x * x
1158    jacobian = g.jacobian(y, x)
1159    # jacobian value is [[2., 0.], [0., 4.]]
1160    ```
1161
1162    Args:
1163      target: Tensor to be differentiated.
1164      sources: a list or nested structure of Tensors or Variables. `target`
1165        will be differentiated against elements in `sources`.
1166      unconnected_gradients: a value which can either hold 'none' or 'zero' and
1167        alters the value which will be returned if the target and sources are
1168        unconnected. The possible values and effects are detailed in
1169        'UnconnectedGradients' and it defaults to 'none'.
1170      parallel_iterations: A knob to control how many iterations are dispatched
1171        in parallel. This knob can be used to control the total memory usage.
1172      experimental_use_pfor: If true, vectorizes the jacobian computation. Else
1173        falls back to a sequential while_loop. Vectorization can sometimes fail
1174        or lead to excessive memory usage. This option can be used to disable
1175        vectorization in such cases.
1176
1177    Returns:
1178      A list or nested structure of Tensors (or None), one for each element in
1179      `sources`. Returned structure is the same as the structure of `sources`.
1180      Note if any gradient is sparse (IndexedSlices), jacobian function
1181      currently makes it dense and returns a Tensor instead. This may change in
1182      the future.
1183
1184
1185    Raises:
1186      RuntimeError: If called on a used, non-persistent tape.
1187      RuntimeError: If called on a non-persistent tape with eager execution
1188        enabled and without enabling experimental_use_pfor.
1189      ValueError: If vectorization of jacobian computation fails.
1190    """
1191    if self._tape is None:
1192      raise RuntimeError("A non-persistent GradientTape can only be used to "
1193                         "compute one set of gradients (or jacobians)")
1194
1195    flat_sources = nest.flatten(sources)
1196    target_static_shape = target.shape
1197    target_shape = array_ops.shape(target)
1198    # Note that we push and pop the tape here and below. This is needed since we
1199    # need gradients through the enclosed operations.
1200    with self._ensure_recording():
1201      target = array_ops.reshape(target, [-1])
1202
1203    def loop_fn(i):
1204      with self._ensure_recording():
1205        y = array_ops.gather(target, i)
1206      return self.gradient(y, flat_sources,
1207                           unconnected_gradients=unconnected_gradients)
1208
1209    try:
1210      target_size = int(target.shape[0])
1211    except TypeError:
1212      target_size = array_ops.shape(target)[0]
1213
1214    if experimental_use_pfor:
1215      try:
1216        output = pfor_ops.pfor(loop_fn, target_size,
1217                               parallel_iterations=parallel_iterations)
1218      except ValueError as err:
1219        raise ValueError(
1220            "Encountered an exception while vectorizing the "
1221            "jacobian computation. Vectorization can be disabled by setting"
1222            " experimental_use_pfor to False.") from err
1223    else:
1224      if context.executing_eagerly() and not self._persistent:
1225        raise RuntimeError(
1226            "GradientTape must be created with persistent=True"
1227            " to compute the jacobian with eager execution enabled and with "
1228            " experimental_use_pfor set to False.")
1229      output = pfor_ops.for_loop(
1230          loop_fn, [target.dtype] * len(flat_sources), target_size,
1231          parallel_iterations=parallel_iterations)
1232
1233    for i, out in enumerate(output):
1234      if out is not None:
1235        new_shape = array_ops.concat(
1236            [target_shape, array_ops.shape(out)[1:]], axis=0)
1237        out = array_ops.reshape(out, new_shape)
1238        if context.executing_eagerly():
1239          out.set_shape(target_static_shape.concatenate(flat_sources[i].shape))
1240      output[i] = out
1241
1242    return nest.pack_sequence_as(sources, output)
1243
1244  def batch_jacobian(self,
1245                     target,
1246                     source,
1247                     unconnected_gradients=UnconnectedGradients.NONE,
1248                     parallel_iterations=None,
1249                     experimental_use_pfor=True):
1250    """Computes and stacks per-example jacobians.
1251
1252    See [wikipedia article](http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant)
1253    for the definition of a Jacobian. This function is essentially an efficient
1254    implementation of the following:
1255
1256    `tf.stack([self.jacobian(y[i], x[i]) for i in range(x.shape[0])])`.
1257
1258    Note that compared to `GradientTape.jacobian` which computes gradient of
1259    each output value w.r.t each input value, this function is useful when
1260    `target[i,...]` is independent of `source[j,...]` for `j != i`. This
1261    assumption allows more efficient computation as compared to
1262    `GradientTape.jacobian`. The output, as well as intermediate activations,
1263    are lower dimensional and avoid a bunch of redundant zeros which would
1264    result in the jacobian computation given the independence assumption.
1265
1266    Note: Unless you set `persistent=True` a GradientTape can only be used to
1267    compute one set of gradients (or jacobians).
1268
1269    Note: By default the batch_jacobian implementation uses parallel for (pfor),
1270    which creates a tf.function under the hood for each batch_jacobian call.
1271    For better performance, and to avoid recompilation and vectorization
1272    rewrites on each call, enclose GradientTape code in @tf.function.
1273
1274
1275    Example usage:
1276
1277    ```python
1278    with tf.GradientTape() as g:
1279      x = tf.constant([[1., 2.], [3., 4.]], dtype=tf.float32)
1280      g.watch(x)
1281      y = x * x
1282    batch_jacobian = g.batch_jacobian(y, x)
1283    # batch_jacobian is [[[2,  0], [0,  4]], [[6,  0], [0,  8]]]
1284    ```
1285
1286    Args:
1287      target: A tensor with rank 2 or higher and with shape [b, y1, ..., y_n].
1288        `target[i,...]` should only depend on `source[i,...]`.
1289      source: A tensor with rank 2 or higher and with shape [b, x1, ..., x_m].
1290      unconnected_gradients: a value which can either hold 'none' or 'zero' and
1291        alters the value which will be returned if the target and sources are
1292        unconnected. The possible values and effects are detailed in
1293        'UnconnectedGradients' and it defaults to 'none'.
1294      parallel_iterations: A knob to control how many iterations are dispatched
1295        in parallel. This knob can be used to control the total memory usage.
1296      experimental_use_pfor: If true, uses pfor for computing the Jacobian. Else
1297        uses a tf.while_loop.
1298
1299    Returns:
1300      A tensor `t` with shape [b, y_1, ..., y_n, x1, ..., x_m] where `t[i, ...]`
1301      is the jacobian of `target[i, ...]` w.r.t. `source[i, ...]`, i.e. stacked
1302      per-example jacobians.
1303
1304    Raises:
1305      RuntimeError: If called on a used, non-persistent tape.
1306      RuntimeError: If called on a non-persistent tape with eager execution
1307        enabled and without enabling experimental_use_pfor.
1308      ValueError: If vectorization of jacobian computation fails or if first
1309        dimension of `target` and `source` do not match.
1310    """
1311    if self._tape is None:
1312      raise RuntimeError("A non-persistent GradientTape can only be used to"
1313                         "compute one set of gradients (or jacobians)")
1314    target_shape = target.shape
1315    if target_shape.rank is None:
1316      dim = tensor_shape.Dimension(None)
1317    else:
1318      dim = target_shape.dims[0]
1319    if not (target_shape.with_rank_at_least(2) and
1320            source.shape.with_rank_at_least(2) and
1321            dim.is_compatible_with(source.shape[0])):
1322      raise ValueError(
1323          "Need first dimension of target shape (%s) and "
1324          "source shape (%s) to match." % (target.shape, source.shape))
1325    if target_shape.is_fully_defined():
1326      batch_size = int(target_shape[0])
1327      target_row_size = target_shape.num_elements() // batch_size
1328    else:
1329      target_shape = array_ops.shape(target)
1330      batch_size = target_shape[0]
1331      target_row_size = array_ops.size(target) // batch_size
1332    source_shape = array_ops.shape(source)
1333    # Flatten target to 2-D.
1334    # Note that we push and pop the tape here and below. This is needed since we
1335    # need gradients through the enclosed operations.
1336    with self._ensure_recording():
1337      with ops.control_dependencies(
1338          [check_ops.assert_equal(batch_size, source_shape[0])]):
1339        target = array_ops.reshape(target, [batch_size, target_row_size])
1340
1341    run_once = False
1342
1343    def loop_fn(i):
1344      nonlocal run_once
1345      if run_once and not self._persistent:
1346        if parallel_iterations is not None:
1347          raise RuntimeError(
1348              "GradientTape must be created with persistent=True"
1349              " to compute the batch_jacobian with parallel_iterations.")
1350        else:
1351          raise RuntimeError(
1352              "GradientTape must be created with persistent=True"
1353              " to compute the batch_jacobian.")
1354      run_once = True
1355
1356      with self._ensure_recording():
1357        y = array_ops.gather(target, i, axis=1)
1358      return self.gradient(y, source,
1359                           unconnected_gradients=unconnected_gradients)
1360
1361    if experimental_use_pfor:
1362      try:
1363        output = pfor_ops.pfor(loop_fn, target_row_size,
1364                               parallel_iterations=parallel_iterations)
1365      except ValueError as err:
1366        raise ValueError(
1367            "Encountered an exception while vectorizing the "
1368            "batch_jacobian computation. Vectorization can be disabled by "
1369            "setting experimental_use_pfor to False.") from err
1370    else:
1371      if context.executing_eagerly() and not self._persistent:
1372        raise RuntimeError(
1373            "GradientTape must be created with persistent=True"
1374            " to compute the batch_jacobian with eager execution enabled and "
1375            " with experimental_use_pfor set to False.")
1376      output = pfor_ops.for_loop(loop_fn, target.dtype, target_row_size,
1377                                 parallel_iterations=parallel_iterations)
1378    new_shape = array_ops.concat([target_shape, source_shape[1:]], axis=0)
1379    if output is None:
1380      # Note that this block is returning zeros when it could use `None` to
1381      # represent unconnected gradients. This is to maintain compatibility with
1382      # the previous behavior, which ignored `unconnected_gradients`.
1383      output = array_ops.zeros(new_shape, target.dtype)
1384      return output
1385    else:
1386      output = array_ops.reshape(output,
1387                                 [target_row_size, batch_size, -1])
1388      output = array_ops.transpose(output, [1, 0, 2])
1389
1390      output = array_ops.reshape(output, new_shape)
1391      return output
1392