# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements the graph generation for computation of gradients."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops  # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import image_grad  # pylint: disable=unused-import
from tensorflow.python.ops import linalg_grad  # pylint: disable=unused-import
from tensorflow.python.ops import linalg_ops  # pylint: disable=unused-import
from tensorflow.python.ops import logging_ops  # pylint: disable=unused-import
from tensorflow.python.ops import manip_grad  # pylint: disable=unused-import
from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import optional_grad  # pylint: disable=unused-import
from tensorflow.python.ops import random_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["gradients"])
def gradients(ys,
              xs,
              grad_ys=None,
              name="gradients",
              colocate_gradients_with_ops=False,
              gate_gradients=False,
              aggregation_method=None,
              stop_gradients=None,
              unconnected_gradients=UnconnectedGradients.NONE):
  """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.

  `ys` and `xs` are each a `Tensor` or a list of tensors. `grad_ys`
  is a list of `Tensor`, holding the gradients received by the
  `ys`. The list must be the same length as `ys`.

  `gradients()` adds ops to the graph to output the derivatives of `ys` with
  respect to `xs`. It returns a list of `Tensor` of length `len(xs)` where
  each tensor is the `sum(dy/dx)` for y in `ys` and for x in `xs`.

  `grad_ys` is a list of tensors of the same length as `ys` that holds
  the initial gradients for each y in `ys`. When `grad_ys` is None,
  we fill in a tensor of '1's of the shape of y for each y in `ys`. A
  user can provide their own initial `grad_ys` to compute the
  derivatives using a different initial gradient for each y (e.g., if
  one wanted to weight the gradient differently for each value in
  each y).

  `stop_gradients` is a `Tensor` or a list of tensors to be considered constant
  with respect to all `xs`. These tensors will not be backpropagated through,
  as though they had been explicitly disconnected using `stop_gradient`.
  Among other things, this allows computation of partial derivatives as opposed
  to total derivatives. For example:

  ```python
  a = tf.constant(0.)
  b = 2 * a
  g = tf.gradients(a + b, [a, b], stop_gradients=[a, b])
  ```

  Here the partial derivatives `g` evaluate to `[1.0, 1.0]`, compared to the
  total derivatives `tf.gradients(a + b, [a, b])`, which take into account the
  influence of `a` on `b` and evaluate to `[3.0, 1.0]`. Note that the above is
  equivalent to:

  ```python
  a = tf.stop_gradient(tf.constant(0.))
  b = tf.stop_gradient(2 * a)
  g = tf.gradients(a + b, [a, b])
  ```

  `stop_gradients` provides a way of stopping gradients after the graph has
  already been constructed, as compared to `tf.stop_gradient` which is used
  during graph construction. When the two approaches are combined,
  backpropagation stops at both `tf.stop_gradient` nodes and nodes in
  `stop_gradients`, whichever is encountered first.

  All integer tensors are considered constant with respect to all `xs`, as if
  they were included in `stop_gradients`.

  `unconnected_gradients` determines the value returned for each x in xs if it
  is unconnected in the graph to ys. By default this is None to safeguard
  against errors. Mathematically these gradients are zero, which can be
  requested using the `'zero'` option. `tf.UnconnectedGradients` provides the
  following options and behaviors:

  ```python
  a = tf.ones([1, 2])
  b = tf.ones([3, 1])
  g1 = tf.gradients([b], [a], unconnected_gradients='none')
  sess.run(g1)  # [None]

  g2 = tf.gradients([b], [a], unconnected_gradients='zero')
  sess.run(g2)  # [array([[0., 0.]], dtype=float32)]
  ```

  As a practical example from the backpropagation phase, this function can be
  used to evaluate the derivatives of a cost function with respect to weights
  `Ws` and biases `bs`. The sample implementation below illustrates what it is
  typically used for:

  ```python
  Ws = tf.constant(0.)
  bs = 2 * Ws
  cost = Ws + bs  # This is just an example; please ignore the formulas.
  g = tf.gradients(cost, [Ws, bs])
  dCost_dW, dCost_db = g
  ```

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grad_ys: Optional. A `Tensor` or list of tensors the same size as
      `ys` and holding the gradients computed for each y in `ys`.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'gradients'.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    gate_gradients: If True, add a tuple around the gradients returned
      for an operation. This avoids some race conditions.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.
    stop_gradients: Optional. A `Tensor` or list of tensors not to differentiate
      through.
    unconnected_gradients: Optional. Specifies the gradient value returned when
      the given input tensors are unconnected. Accepted values are constants
      defined in the class `tf.UnconnectedGradients` and the default value is
      `none`.

  Returns:
    A list of `Tensor` of length `len(xs)` where each tensor is the `sum(dy/dx)`
    for y in `ys` and for x in `xs`.

  Raises:
    LookupError: if one of the operations between `x` and `y` does not
      have a registered gradient function.
    ValueError: if the arguments are invalid.
    RuntimeError: if called in Eager mode.

  """
  # Creating the gradient graph for control flow mutates Operations.
  # _mutation_lock ensures a Session.run call cannot occur between creating and
  # mutating new ops.
  # pylint: disable=protected-access
  with ops.get_default_graph()._mutation_lock():
    return gradients_util._GradientsHelper(
        ys, xs, grad_ys, name, colocate_gradients_with_ops,
        gate_gradients, aggregation_method, stop_gradients,
        unconnected_gradients)
  # pylint: enable=protected-access


@tf_export("gradients", v1=[])
def gradients_v2(ys,  # pylint: disable=invalid-name
                 xs,
                 grad_ys=None,
                 name="gradients",
                 gate_gradients=False,
                 aggregation_method=None,
                 stop_gradients=None,
                 unconnected_gradients=UnconnectedGradients.NONE):
  """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.

  `tf.gradients` is only valid in a graph context. In particular,
  it is valid in the context of a `tf.function` wrapper, where code
  is executing as a graph.

  `ys` and `xs` are each a `Tensor` or a list of tensors. `grad_ys`
  is a list of `Tensor`, holding the gradients received by the
  `ys`. The list must be the same length as `ys`.

  `gradients()` adds ops to the graph to output the derivatives of `ys` with
  respect to `xs`. It returns a list of `Tensor` of length `len(xs)` where
  each tensor is the `sum(dy/dx)` for y in `ys` and for x in `xs`.

  `grad_ys` is a list of tensors of the same length as `ys` that holds
  the initial gradients for each y in `ys`. When `grad_ys` is None,
  we fill in a tensor of '1's of the shape of y for each y in `ys`. A
  user can provide their own initial `grad_ys` to compute the
  derivatives using a different initial gradient for each y (e.g., if
  one wanted to weight the gradient differently for each value in
  each y).

  `stop_gradients` is a `Tensor` or a list of tensors to be considered constant
  with respect to all `xs`. These tensors will not be backpropagated through,
  as though they had been explicitly disconnected using `stop_gradient`. Among
  other things, this allows computation of partial derivatives as opposed to
  total derivatives. For example:

  >>> @tf.function
  ... def example():
  ...   a = tf.constant(0.)
  ...   b = 2 * a
  ...   return tf.gradients(a + b, [a, b], stop_gradients=[a, b])
  >>> example()
  [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>]

  Here the partial derivatives `g` evaluate to `[1.0, 1.0]`, compared to the
  total derivatives `tf.gradients(a + b, [a, b])`, which take into account the
  influence of `a` on `b` and evaluate to `[3.0, 1.0]`. Note that the above is
  equivalent to:

  >>> @tf.function
  ... def example():
  ...   a = tf.stop_gradient(tf.constant(0.))
  ...   b = tf.stop_gradient(2 * a)
  ...   return tf.gradients(a + b, [a, b])
  >>> example()
  [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>]

  `stop_gradients` provides a way of stopping gradients after the graph has
  already been constructed, as compared to `tf.stop_gradient` which is used
  during graph construction.
  When the two approaches are combined,
  backpropagation stops at both `tf.stop_gradient` nodes and nodes in
  `stop_gradients`, whichever is encountered first.

  All integer tensors are considered constant with respect to all `xs`, as if
  they were included in `stop_gradients`.

  `unconnected_gradients` determines the value returned for each x in xs if it
  is unconnected in the graph to ys. By default this is None to safeguard
  against errors. Mathematically these gradients are zero, which can be
  requested using the `'zero'` option. `tf.UnconnectedGradients` provides the
  following options and behaviors:

  >>> @tf.function
  ... def example(use_zero):
  ...   a = tf.ones([1, 2])
  ...   b = tf.ones([3, 1])
  ...   if use_zero:
  ...     return tf.gradients([b], [a], unconnected_gradients='zero')
  ...   else:
  ...     return tf.gradients([b], [a], unconnected_gradients='none')
  >>> example(False)
  [None]
  >>> example(True)
  [<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0., 0.]], ...)>]

  As a practical example from the backpropagation phase, this function can be
  used to evaluate the derivatives of a cost function with respect to weights
  `Ws` and biases `bs`. The sample implementation below illustrates what it is
  typically used for:

  >>> @tf.function
  ... def example():
  ...   Ws = tf.constant(0.)
  ...   bs = 2 * Ws
  ...   cost = Ws + bs  # This is just an example. Please ignore the formulas.
  ...   g = tf.gradients(cost, [Ws, bs])
  ...   dCost_dW, dCost_db = g
  ...   return dCost_dW, dCost_db
  >>> example()
  (<tf.Tensor: shape=(), dtype=float32, numpy=3.0>,
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>)

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grad_ys: Optional. A `Tensor` or list of tensors the same size as
      `ys` and holding the gradients computed for each y in `ys`.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'gradients'.
    gate_gradients: If True, add a tuple around the gradients returned
      for an operation. This avoids some race conditions.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.
    stop_gradients: Optional. A `Tensor` or list of tensors not to differentiate
      through.
    unconnected_gradients: Optional. Specifies the gradient value returned when
      the given input tensors are unconnected. Accepted values are constants
      defined in the class `tf.UnconnectedGradients` and the default value is
      `none`.

  Returns:
    A list of `Tensor` of length `len(xs)` where each tensor is the `sum(dy/dx)`
    for y in `ys` and for x in `xs`.

  Raises:
    LookupError: if one of the operations between `x` and `y` does not
      have a registered gradient function.
    ValueError: if the arguments are invalid.
    RuntimeError: if called in Eager mode.

  """
  # Creating the gradient graph for control flow mutates Operations.
  # _mutation_lock ensures a Session.run call cannot occur between creating and
  # mutating new ops.
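  # Note: compared with the v1 `gradients` wrapper above, the call below passes
  # `True` in the position where the v1 wrapper passes
  # `colocate_gradients_with_ops`, so this endpoint always requests colocating
  # gradients with the corresponding ops.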
  # pylint: disable=protected-access
  with ops.get_default_graph()._mutation_lock():
    return gradients_util._GradientsHelper(
        ys, xs, grad_ys, name, True, gate_gradients,
        aggregation_method, stop_gradients,
        unconnected_gradients)
  # pylint: enable=protected-access


# TODO(vrv): Make this available when we want to make it public.
def _hessian_vector_product(ys, xs, v):
  """Multiply the Hessian of `ys` wrt `xs` by `v`.

  This is an efficient construction that uses a backprop-like approach
  to compute the product between the Hessian and another vector. The
  Hessian is usually too large to be explicitly computed or even
  represented, but this method allows us to at least multiply by it
  for the same big-O cost as backprop.

  Implicit Hessian-vector products are the main practical, scalable way
  of using second derivatives with neural networks. They allow us to
  do things like construct Krylov subspaces and approximate conjugate
  gradient descent.

  Example: if `y` = `x`^T A `x`, then `hessian_vector_product(y, x, v)`
  will return an expression that evaluates to the same values as
  (A + A.T) `v` (see the illustrative sketch at the end of this module).

  Args:
    ys: A scalar value, or a tensor or list of tensors to be summed to
      yield a scalar.
    xs: A list of tensors that we should construct the Hessian over.
    v: A list of tensors, with the same shapes as xs, that we want to
      multiply by the Hessian.

  Returns:
    A list of tensors (or if the list would be length 1, a single tensor)
    containing the product between the Hessian and `v`.

  Raises:
    ValueError: if `xs` and `v` have different lengths.

  """

  # Validate the input
  length = len(xs)
  if len(v) != length:
    raise ValueError("xs and v must have the same length.")

  # First backprop
  grads = gradients(ys, xs)

  assert len(grads) == length
  elemwise_products = [
      math_ops.multiply(grad_elem, array_ops.stop_gradient(v_elem))
      for grad_elem, v_elem in zip(grads, v)
      if grad_elem is not None
  ]

  # Second backprop
  return gradients(elemwise_products, xs)


@tf_export(v1=["hessians"])
def hessians(ys,
             xs,
             name="hessians",
             colocate_gradients_with_ops=False,
             gate_gradients=False,
             aggregation_method=None):
  """Constructs the Hessian of sum of `ys` with respect to `x` in `xs`.

  `hessians()` adds ops to the graph to output the Hessian matrix of `ys`
  with respect to `xs`. It returns a list of `Tensor` of length `len(xs)`
  where each tensor is the Hessian of `sum(ys)`.

  The Hessian is a matrix of second-order partial derivatives of a scalar
  tensor (see https://en.wikipedia.org/wiki/Hessian_matrix for more details).

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'hessians'.
    colocate_gradients_with_ops: See `gradients()` documentation for details.
    gate_gradients: See `gradients()` documentation for details.
    aggregation_method: See `gradients()` documentation for details.

  Returns:
    A list of Hessian matrices of `sum(ys)` for each `x` in `xs`.

  Raises:
    LookupError: if one of the operations between `xs` and `ys` does not
      have a registered gradient function.
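
  A minimal usage sketch (illustrative only; it assumes graph mode with a
  `tf.Session` named `sess`, which is not part of this API):

  ```python
  x = tf.constant([1., 2.])
  y = tf.reduce_sum(x ** 3)
  hess = tf.hessians(y, x)[0]
  sess.run(hess)  # Expect [[6., 0.], [0., 12.]], i.e. diag(6 * x).
  ```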
  """
  xs = gradients_util._AsList(xs)  # pylint: disable=protected-access
  kwargs = {
      "colocate_gradients_with_ops": colocate_gradients_with_ops,
      "gate_gradients": gate_gradients,
      "aggregation_method": aggregation_method
  }
  # Compute first-order derivatives and iterate for each x in xs.
  hessians = []
  _gradients = gradients(ys, xs, **kwargs)
  for gradient, x in zip(_gradients, xs):
    # Change the shape to one dimension without graph branching.
    gradient = array_ops.reshape(gradient, [-1])

    # Declare an iterator and tensor array loop variables for the gradients.
    n = array_ops.size(x)
    loop_vars = [
        array_ops.constant(0, dtypes.int32),
        tensor_array_ops.TensorArray(x.dtype, n)
    ]
    # Iterate over all elements of the gradient and compute second order
    # derivatives.
    _, hessian = control_flow_ops.while_loop(
        lambda j, _: j < n,
        lambda j, result: (j + 1,
                           result.write(j, gradients(gradient[j], x)[0])),
        loop_vars
    )

    _shape = array_ops.shape(x)
    _reshaped_hessian = array_ops.reshape(hessian.stack(),
                                          array_ops.concat((_shape, _shape), 0))
    hessians.append(_reshaped_hessian)
  return hessians


@tf_export("hessians", v1=[])
def HessiansV2(ys,
               xs,
               gate_gradients=False,
               aggregation_method=None,
               name="hessians"):
  """Constructs the Hessian of sum of `ys` with respect to `x` in `xs`.

  `hessians()` adds ops to the graph to output the Hessian matrix of `ys`
  with respect to `xs`. It returns a list of `Tensor` of length `len(xs)`
  where each tensor is the Hessian of `sum(ys)`.

  The Hessian is a matrix of second-order partial derivatives of a scalar
  tensor (see https://en.wikipedia.org/wiki/Hessian_matrix for more details).

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    gate_gradients: See `gradients()` documentation for details.
    aggregation_method: See `gradients()` documentation for details.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'hessians'.

  Returns:
    A list of Hessian matrices of `sum(ys)` for each `x` in `xs`.

  Raises:
    LookupError: if one of the operations between `xs` and `ys` does not
      have a registered gradient function.
  """
  return hessians(
      ys,
      xs,
      name=name,
      colocate_gradients_with_ops=True,
      gate_gradients=gate_gradients,
      aggregation_method=aggregation_method)
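

# Illustrative sketch (hypothetical helper, not part of the public API): it
# demonstrates the claim in the `_hessian_vector_product` docstring that for
# y = x^T A x the product evaluates to (A + A.T) v. The returned ops must be
# evaluated in a graph context (e.g. under `tf.function` or a v1 `Session`),
# since `gradients` raises a RuntimeError in eager mode.
def _hessian_vector_product_example():
  """Builds ops whose value is (A + A.T) v for the quadratic y = x^T A x."""
  a = ops.convert_to_tensor([[1., 2.], [3., 4.]])
  x = ops.convert_to_tensor([[1.], [1.]])
  v = ops.convert_to_tensor([[1.], [0.]])
  # y = x^T A x, a 1x1 tensor; `gradients` implicitly sums it to a scalar.
  y = math_ops.matmul(x, math_ops.matmul(a, x), transpose_a=True)
  hvp = _hessian_vector_product(y, [x], [v])
  # The single element of `hvp` should evaluate to (A + A.T) v = [[2.], [5.]].
  return hvp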