# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Implements the graph generation for computation of gradients."""
16
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import ops
19from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
20from tensorflow.python.ops import array_ops
21from tensorflow.python.ops import check_ops  # pylint: disable=unused-import
22from tensorflow.python.ops import control_flow_grad  # pylint: disable=unused-import
23from tensorflow.python.ops import control_flow_ops
24from tensorflow.python.ops import gradients_util
25from tensorflow.python.ops import image_grad  # pylint: disable=unused-import
26from tensorflow.python.ops import linalg_grad  # pylint: disable=unused-import
27from tensorflow.python.ops import linalg_ops  # pylint: disable=unused-import
28from tensorflow.python.ops import logging_ops  # pylint: disable=unused-import
29from tensorflow.python.ops import manip_grad  # pylint: disable=unused-import
30from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
31from tensorflow.python.ops import math_ops
32from tensorflow.python.ops import optional_grad  # pylint: disable=unused-import
33from tensorflow.python.ops import random_grad  # pylint: disable=unused-import
34from tensorflow.python.ops import tensor_array_ops
35from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
36from tensorflow.python.util.tf_export import tf_export
37
38
39@tf_export(v1=["gradients"])
40def gradients(ys,
41              xs,
42              grad_ys=None,
43              name="gradients",
44              colocate_gradients_with_ops=False,
45              gate_gradients=False,
46              aggregation_method=None,
47              stop_gradients=None,
48              unconnected_gradients=UnconnectedGradients.NONE):
49  """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.
50
51  `ys` and `xs` are each a `Tensor` or a list of tensors.  `grad_ys`
52  is a list of `Tensor`, holding the gradients received by the
53  `ys`. The list must be the same length as `ys`.
54
55  `gradients()` adds ops to the graph to output the derivatives of `ys` with
56  respect to `xs`.  It returns a list of `Tensor` of length `len(xs)` where
57  each tensor is the `sum(dy/dx)` for y in `ys` and for x in `xs`.
58
59  `grad_ys` is a list of tensors of the same length as `ys` that holds
60  the initial gradients for each y in `ys`.  When `grad_ys` is None,
61  we fill in a tensor of '1's of the shape of y for each y in `ys`.  A
62  user can provide their own initial `grad_ys` to compute the
63  derivatives using a different initial gradient for each y (e.g., if
64  one wanted to weight the gradient differently for each value in
65  each y).
66
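  As a minimal sketch (values here are purely illustrative), passing a custom
  `grad_ys` scales the gradient that flows back from each y:

  ```python
  x = tf.constant(3.0)
  y = 2.0 * x
  # dy/dx is 2.0; weighting it with a grad_ys of 4.0 yields 8.0.
  g = tf.gradients([y], [x], grad_ys=[tf.constant(4.0)])
  ```
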
  `stop_gradients` is a `Tensor` or a list of tensors to be considered constant
  with respect to all `xs`. These tensors will not be backpropagated through,
  as though they had been explicitly disconnected using `stop_gradient`.  Among
  other things, this allows computation of partial derivatives as opposed to
  total derivatives. For example:

  ```python
  a = tf.constant(0.)
  b = 2 * a
  g = tf.gradients(a + b, [a, b], stop_gradients=[a, b])
  ```

  Here the partial derivatives `g` evaluate to `[1.0, 1.0]`, compared to the
  total derivatives `tf.gradients(a + b, [a, b])`, which take into account the
  influence of `a` on `b` and evaluate to `[3.0, 1.0]`.  Note that the above is
  equivalent to:

  ```python
  a = tf.stop_gradient(tf.constant(0.))
  b = tf.stop_gradient(2 * a)
  g = tf.gradients(a + b, [a, b])
  ```

  `stop_gradients` provides a way of stopping gradient after the graph has
  already been constructed, as compared to `tf.stop_gradient` which is used
  during graph construction.  When the two approaches are combined,
  backpropagation stops at both `tf.stop_gradient` nodes and nodes in
  `stop_gradients`, whichever is encountered first.

  All integer tensors are considered constant with respect to all `xs`, as if
  they were included in `stop_gradients`.

  `unconnected_gradients` determines the value returned for each x in xs if it
  is not connected to ys in the graph. By default this is `None`, to safeguard
  against errors. Mathematically these gradients are zero, which can be
  requested using the `'zero'` option. `tf.UnconnectedGradients` provides the
  following options and behaviors:

  ```python
  a = tf.ones([1, 2])
  b = tf.ones([3, 1])
  g1 = tf.gradients([b], [a], unconnected_gradients='none')
  sess.run(g1)  # [None]

  g2 = tf.gradients([b], [a], unconnected_gradients='zero')
  sess.run(g2)  # [array([[0., 0.]], dtype=float32)]
  ```

  As a practical example from the backpropagation phase, this function can be
  used to evaluate the derivatives of a cost function with respect to weights
  `Ws` and biases `bs`. The sample implementation below illustrates what it is
  typically used for:

  ```python
  Ws = tf.constant(0.)
  bs = 2 * Ws
  cost = Ws + bs  # This is just an example; please ignore the formula.
  g = tf.gradients(cost, [Ws, bs])
  dCost_dW, dCost_db = g
  ```

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grad_ys: Optional. A `Tensor` or list of tensors the same size as
      `ys` and holding the gradients computed for each y in `ys`.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'gradients'.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    gate_gradients: If True, add a tuple around the gradients returned
      for an operation.  This avoids some race conditions.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.
    stop_gradients: Optional. A `Tensor` or list of tensors not to differentiate
      through.
    unconnected_gradients: Optional. Specifies the gradient value returned when
      the given input tensors are unconnected. Accepted values are constants
      defined in the class `tf.UnconnectedGradients` and the default value is
      `none`.

  Returns:
    A list of `Tensor` of length `len(xs)` where each tensor is the `sum(dy/dx)`
    for y in `ys` and for x in `xs`.

  Raises:
    LookupError: if one of the operations between `x` and `y` does not
      have a registered gradient function.
    ValueError: if the arguments are invalid.
    RuntimeError: if called in Eager mode.

  """
  # Creating the gradient graph for control flow mutates Operations.
  # _mutation_lock ensures a Session.run call cannot occur between creating and
  # mutating new ops.
  # pylint: disable=protected-access
  with ops.get_default_graph()._mutation_lock():
    return gradients_util._GradientsHelper(
        ys, xs, grad_ys, name, colocate_gradients_with_ops,
        gate_gradients, aggregation_method, stop_gradients,
        unconnected_gradients)
  # pylint: enable=protected-access


@tf_export("gradients", v1=[])
def gradients_v2(ys,  # pylint: disable=invalid-name
                 xs,
                 grad_ys=None,
                 name="gradients",
                 gate_gradients=False,
                 aggregation_method=None,
                 stop_gradients=None,
                 unconnected_gradients=UnconnectedGradients.NONE):
  """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.

  `tf.gradients` is only valid in a graph context. In particular,
  it is valid in the context of a `tf.function` wrapper, where code
  executes as a graph.

  `ys` and `xs` are each a `Tensor` or a list of tensors.  `grad_ys`
  is a list of `Tensor`, holding the gradients received by the
  `ys`. The list must be the same length as `ys`.

  `gradients()` adds ops to the graph to output the derivatives of `ys` with
  respect to `xs`.  It returns a list of `Tensor` of length `len(xs)` where
  each tensor is the `sum(dy/dx)` for y in `ys` and for x in `xs`.

  `grad_ys` is a list of tensors of the same length as `ys` that holds
  the initial gradients for each y in `ys`.  When `grad_ys` is None,
  we fill in a tensor of '1's of the shape of y for each y in `ys`.  A
  user can provide their own initial `grad_ys` to compute the
  derivatives using a different initial gradient for each y (e.g., if
  one wanted to weight the gradient differently for each value in
  each y).

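  As a minimal sketch (values here are purely illustrative), passing a custom
  `grad_ys` scales the gradient that flows back from each y:

  >>> @tf.function
  ... def example():
  ...   x = tf.constant(3.0)
  ...   y = 2.0 * x
  ...   # dy/dx is 2.0; weighting it with a grad_ys of 4.0 yields 8.0.
  ...   return tf.gradients([y], [x], grad_ys=[tf.constant(4.0)])
  >>> example()
  [<tf.Tensor: shape=(), dtype=float32, numpy=8.0>]
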
  `stop_gradients` is a `Tensor` or a list of tensors to be considered constant
  with respect to all `xs`. These tensors will not be backpropagated through,
  as though they had been explicitly disconnected using `stop_gradient`.  Among
  other things, this allows computation of partial derivatives as opposed to
  total derivatives. For example:

  >>> @tf.function
  ... def example():
  ...   a = tf.constant(0.)
  ...   b = 2 * a
  ...   return tf.gradients(a + b, [a, b], stop_gradients=[a, b])
  >>> example()
  [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>]

  Here the partial derivatives `g` evaluate to `[1.0, 1.0]`, compared to the
  total derivatives `tf.gradients(a + b, [a, b])`, which take into account the
  influence of `a` on `b` and evaluate to `[3.0, 1.0]`.  Note that the above is
  equivalent to:

  >>> @tf.function
  ... def example():
  ...   a = tf.stop_gradient(tf.constant(0.))
  ...   b = tf.stop_gradient(2 * a)
  ...   return tf.gradients(a + b, [a, b])
  >>> example()
  [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>]

  `stop_gradients` provides a way of stopping gradient after the graph has
  already been constructed, as compared to `tf.stop_gradient` which is used
  during graph construction.  When the two approaches are combined,
  backpropagation stops at both `tf.stop_gradient` nodes and nodes in
  `stop_gradients`, whichever is encountered first.

  All integer tensors are considered constant with respect to all `xs`, as if
  they were included in `stop_gradients`.

  `unconnected_gradients` determines the value returned for each x in xs if it
  is not connected to ys in the graph. By default this is `None`, to safeguard
  against errors. Mathematically these gradients are zero, which can be
  requested using the `'zero'` option. `tf.UnconnectedGradients` provides the
  following options and behaviors:

  >>> @tf.function
  ... def example(use_zero):
  ...   a = tf.ones([1, 2])
  ...   b = tf.ones([3, 1])
  ...   if use_zero:
  ...     return tf.gradients([b], [a], unconnected_gradients='zero')
  ...   else:
  ...     return tf.gradients([b], [a], unconnected_gradients='none')
  >>> example(False)
  [None]
  >>> example(True)
  [<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0., 0.]], ...)>]

  As a practical example from the backpropagation phase, this function can be
  used to evaluate the derivatives of a cost function with respect to weights
  `Ws` and biases `bs`. The sample implementation below illustrates what it is
  typically used for:

  >>> @tf.function
  ... def example():
  ...   Ws = tf.constant(0.)
  ...   bs = 2 * Ws
  ...   cost = Ws + bs  # This is just an example. Please ignore the formulas.
  ...   g = tf.gradients(cost, [Ws, bs])
  ...   dCost_dW, dCost_db = g
  ...   return dCost_dW, dCost_db
  >>> example()
  (<tf.Tensor: shape=(), dtype=float32, numpy=3.0>,
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>)

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grad_ys: Optional. A `Tensor` or list of tensors the same size as
      `ys` and holding the gradients computed for each y in `ys`.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'gradients'.
    gate_gradients: If True, add a tuple around the gradients returned
      for an operation.  This avoids some race conditions.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.
    stop_gradients: Optional. A `Tensor` or list of tensors not to differentiate
      through.
    unconnected_gradients: Optional. Specifies the gradient value returned when
      the given input tensors are unconnected. Accepted values are constants
      defined in the class `tf.UnconnectedGradients` and the default value is
      `none`.

  Returns:
    A list of `Tensor` of length `len(xs)` where each tensor is the `sum(dy/dx)`
    for y in `ys` and for x in `xs`.

  Raises:
    LookupError: if one of the operations between `x` and `y` does not
      have a registered gradient function.
    ValueError: if the arguments are invalid.
    RuntimeError: if called in Eager mode.

  """
  # Creating the gradient graph for control flow mutates Operations.
  # _mutation_lock ensures a Session.run call cannot occur between creating and
  # mutating new ops.
  # pylint: disable=protected-access
  with ops.get_default_graph()._mutation_lock():
    return gradients_util._GradientsHelper(
        ys, xs, grad_ys, name, True, gate_gradients,
        aggregation_method, stop_gradients,
        unconnected_gradients)
  # pylint: enable=protected-access


# TODO(vrv): Make this available when we want to make it public.
def _hessian_vector_product(ys, xs, v):
  """Multiply the Hessian of `ys` wrt `xs` by `v`.

  This is an efficient construction that uses a backprop-like approach
  to compute the product between the Hessian and another vector. The
  Hessian is usually too large to be explicitly computed or even
  represented, but this method allows us to at least multiply by it
  for the same big-O cost as backprop.

  Implicit Hessian-vector products are the main practical, scalable way
  of using second derivatives with neural networks. They allow us to
  do things like construct Krylov subspaces and approximate conjugate
  gradient descent.

  Example: if `y` = 1/2 `x`^T A `x`, then `hessian_vector_product(y,
  x, v)` will return an expression that evaluates to the same values
  as 1/2 (A + A.T) `v`, i.e. A `v` when A is symmetric.

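  As a minimal sketch (shapes and values are purely illustrative, and this is
  only valid in a graph context), for `y = tf.reduce_sum(x * x)` the Hessian is
  `2 * I`, so the product is `2 * v`:

  ```python
  x = tf.constant([1.0, 2.0])
  v = [tf.constant([0.5, -0.5])]
  y = tf.reduce_sum(x * x)
  hvp = _hessian_vector_product(y, [x], v)  # Hv == 2 * v == [1., -1.]
  ```
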
  Args:
    ys: A scalar value, or a tensor or list of tensors to be summed to
        yield a scalar.
    xs: A list of tensors that we should construct the Hessian over.
    v: A list of tensors, with the same shapes as xs, that we want to
       multiply by the Hessian.

  Returns:
    A list of tensors (or if the list would be length 1, a single tensor)
    containing the product between the Hessian and `v`.

  Raises:
    ValueError: if `xs` and `v` have different lengths.

  """

  # Validate the input
  length = len(xs)
  if len(v) != length:
    raise ValueError("xs and v must have the same length.")

  # First backprop
  grads = gradients(ys, xs)

  assert len(grads) == length
  # Form the elementwise products grad * v, treating v as a constant; the
  # gradient of their (implicit) sum with respect to xs below is the
  # Hessian-vector product.
  elemwise_products = [
      math_ops.multiply(grad_elem, array_ops.stop_gradient(v_elem))
      for grad_elem, v_elem in zip(grads, v)
      if grad_elem is not None
  ]

  # Second backprop
  return gradients(elemwise_products, xs)


@tf_export(v1=["hessians"])
def hessians(ys,
             xs,
             name="hessians",
             colocate_gradients_with_ops=False,
             gate_gradients=False,
             aggregation_method=None):
  """Constructs the Hessian of sum of `ys` with respect to `x` in `xs`.

  `hessians()` adds ops to the graph to output the Hessian matrix of `ys`
  with respect to `xs`.  It returns a list of `Tensor` of length `len(xs)`
  where each tensor is the Hessian of `sum(ys)`.

  The Hessian is a matrix of second-order partial derivatives of a scalar
  tensor (see https://en.wikipedia.org/wiki/Hessian_matrix for more details).

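  A minimal usage sketch (shapes and values are illustrative), valid in a graph
  context:

  ```python
  x = tf.constant([1.0, 2.0, 3.0])
  y = tf.reduce_sum(x * x * x) / 6.0
  hess = tf.hessians(y, x)[0]  # shape [3, 3]; evaluates to tf.linalg.diag(x)
  ```
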
  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'hessians'.
    colocate_gradients_with_ops: See `gradients()` documentation for details.
    gate_gradients: See `gradients()` documentation for details.
    aggregation_method: See `gradients()` documentation for details.

  Returns:
    A list of Hessian matrices of `sum(ys)` for each `x` in `xs`.

  Raises:
    LookupError: if one of the operations between `xs` and `ys` does not
      have a registered gradient function.
  """
  xs = gradients_util._AsList(xs)  # pylint: disable=protected-access
  kwargs = {
      "colocate_gradients_with_ops": colocate_gradients_with_ops,
      "gate_gradients": gate_gradients,
      "aggregation_method": aggregation_method
  }
  # Compute first-order derivatives and iterate for each x in xs.
  hessians = []
  _gradients = gradients(ys, xs, **kwargs)
  for gradient, x in zip(_gradients, xs):
    # Change the shape to one dimension without graph branching.
    gradient = array_ops.reshape(gradient, [-1])

    # Declare an iterator and tensor array loop variables for the gradients.
    n = array_ops.size(x)
    loop_vars = [
        array_ops.constant(0, dtypes.int32),
        tensor_array_ops.TensorArray(x.dtype, n)
    ]
    # Iterate over all elements of the gradient and compute second order
    # derivatives.
    _, hessian = control_flow_ops.while_loop(
        lambda j, _: j < n,
        lambda j, result: (j + 1,
                           result.write(j, gradients(gradient[j], x)[0])),
        loop_vars
    )

    # Stack the per-element gradients and reshape the result to
    # shape(x) + shape(x).
    _shape = array_ops.shape(x)
    _reshaped_hessian = array_ops.reshape(hessian.stack(),
                                          array_ops.concat((_shape, _shape), 0))
    hessians.append(_reshaped_hessian)
  return hessians


@tf_export("hessians", v1=[])
def HessiansV2(ys,
               xs,
               gate_gradients=False,
               aggregation_method=None,
               name="hessians"):
  """Constructs the Hessian of sum of `ys` with respect to `x` in `xs`.

  `hessians()` adds ops to the graph to output the Hessian matrix of `ys`
  with respect to `xs`.  It returns a list of `Tensor` of length `len(xs)`
  where each tensor is the Hessian of `sum(ys)`.

  The Hessian is a matrix of second-order partial derivatives of a scalar
  tensor (see https://en.wikipedia.org/wiki/Hessian_matrix for more details).

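  A minimal usage sketch (values are illustrative); like `tf.gradients`, this
  is only valid in a graph context such as a `tf.function`:

  ```python
  @tf.function
  def example():
    x = tf.constant([1.0, 2.0])
    y = tf.reduce_sum(x * x)
    return tf.hessians(y, x)[0]  # 2 * identity, shape [2, 2]
  ```
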
  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    gate_gradients: See `gradients()` documentation for details.
    aggregation_method: See `gradients()` documentation for details.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'hessians'.

  Returns:
    A list of Hessian matrices of `sum(ys)` for each `x` in `xs`.

  Raises:
    LookupError: if one of the operations between `xs` and `ys` does not
      have a registered gradient function.
  """
  return hessians(
      ys,
      xs,
      name=name,
      colocate_gradients_with_ops=True,
      gate_gradients=gate_gradients,
      aggregation_method=aggregation_method)