1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Code for backpropagation using the tape utilities.""" 16 17# TODO(b/159343581): Properly support CompositeTensor in all functions in this 18# file. 19 20import functools 21import operator 22 23from tensorflow.python import pywrap_tfe 24from tensorflow.python.eager import backprop_util 25from tensorflow.python.eager import context 26from tensorflow.python.eager import execute 27from tensorflow.python.eager import imperative_grad 28from tensorflow.python.eager import tape 29from tensorflow.python.framework import composite_tensor 30from tensorflow.python.framework import composite_tensor_gradient 31from tensorflow.python.framework import constant_op 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import indexed_slices 34from tensorflow.python.framework import ops 35from tensorflow.python.framework import tensor_shape 36from tensorflow.python.framework import tensor_util 37from tensorflow.python.framework import type_spec 38from tensorflow.python.ops import array_ops 39from tensorflow.python.ops import check_ops 40from tensorflow.python.ops import control_flow_util 41from tensorflow.python.ops import default_gradient 42from tensorflow.python.ops import gen_array_ops 43from tensorflow.python.ops import gen_math_ops 44from tensorflow.python.ops import math_ops 45from tensorflow.python.ops import resource_variable_ops 46from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients 47from tensorflow.python.platform import tf_logging as logging 48from tensorflow.python.util import _pywrap_utils 49from tensorflow.python.util import nest 50from tensorflow.python.util import tf_contextlib 51from tensorflow.python.util import tf_inspect 52from tensorflow.python.util import variable_utils 53from tensorflow.python.util.lazy_loader import LazyLoader 54from tensorflow.python.util.tf_export import tf_export 55 56 57# Note that we need to lazy load the following two modules to avoid creating 58# circular dependencies. 59# TODO(b/119775953): fix the circular dependencies. 60pfor_ops = LazyLoader( 61 "pfor_ops", globals(), 62 "tensorflow.python.ops.parallel_for.control_flow_ops") 63 64function = LazyLoader("function", globals(), 65 "tensorflow.python.eager.function") 66 67_op_attr_type_cache = {} 68 69 70def op_attr_type(op_type, attr_name): 71 try: 72 return _op_attr_type_cache[(op_type, attr_name)] 73 except KeyError: 74 context.ensure_initialized() 75 h = context.context()._handle # pylint: disable=protected-access 76 attr_type = pywrap_tfe.TFE_OpNameGetAttrType(h, op_type, attr_name) 77 _op_attr_type_cache[(op_type, attr_name)] = attr_type 78 return attr_type 79 80 81def make_attr(attr_type, value): 82 # pybind11 enums do not return the raw value like SWIG enums do. They are 83 # useful when comparing amongst each other but not direct integers as we are 84 # doing in most tests. 85 # https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types 86 # TODO(amitpatankar): After all SWIG transitions, convert the enum comparisons 87 # from integer value to class. 88 if attr_type == int(pywrap_tfe.TF_ATTR_TYPE): 89 return dtypes.as_dtype(value) 90 if attr_type == [int(pywrap_tfe.TF_ATTR_TYPE)]: 91 return [dtypes.as_dtype(v) for v in value] 92 if attr_type == int(pywrap_tfe.TF_ATTR_SHAPE): 93 return tensor_shape.as_shape(value).as_proto() 94 if attr_type == [int(pywrap_tfe.TF_ATTR_SHAPE)]: 95 return [tensor_shape.as_shape(v).as_proto() for v in value] 96 return nest.map_structure( 97 lambda v: v.encode() if isinstance(v, str) else v, 98 value) 99 100 101class _MockOp(object): 102 """Pretends to be a tf.Operation for the gradient functions.""" 103 104 def __init__(self, attrs, inputs, outputs, typ, skip_input_indices): 105 self.attrs = attrs 106 self.inputs = inputs 107 self.outputs = outputs 108 self.type = typ 109 self.skip_input_indices = skip_input_indices 110 111 def get_attr(self, attr): 112 typ = op_attr_type(self.type, attr) 113 for i in range(0, len(self.attrs), 2): 114 if self.attrs[i] == attr: 115 return make_attr(typ, self.attrs[i + 1]) 116 raise KeyError(attr) 117 118 def _get_control_flow_context(self): 119 raise NotImplementedError( 120 "tf.GradientTape.gradients() does not support graph control flow " 121 "operations like tf.cond or tf.while at this time. Use tf.gradients() " 122 "instead. If you need this feature, please file a feature request at " 123 "https://github.com/tensorflow/tensorflow/issues/new" 124 ) 125 126 127def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs, 128 out_grads, skip_input_indices, forward_pass_name_scope): 129 """Calls the gradient function of the op. 130 131 Args: 132 op_name: the name of the op to be differentiated. 133 attr_tuple: the attrs, as a tuple. 134 num_inputs: the number of inputs to the op. 135 inputs: inputs to the original operation. 136 outputs: outputs to the original operation. 137 out_grads: gradients of the operation wrt its outputs. 138 skip_input_indices: a tuple that is passed to the gradient function, 139 indicating which inputs to skip calculating the gradient for 140 forward_pass_name_scope: the namescope of the op in the forward pass. 141 142 Returns: 143 The gradients with respect to the inputs of the function, as a list. 144 """ 145 mock_op = _MockOp(attr_tuple, inputs, outputs, op_name, skip_input_indices) 146 grad_fn = ops._gradient_registry.lookup(op_name) # pylint: disable=protected-access 147 if grad_fn is None: 148 return [None] * num_inputs 149 150 # This does not work with v1 TensorArrays. 151 if ops.executing_eagerly_outside_functions( 152 ) or control_flow_util.EnableControlFlowV2(ops.get_default_graph()): 153 gradient_name_scope = "gradient_tape/" 154 if forward_pass_name_scope: 155 gradient_name_scope += forward_pass_name_scope + "/" 156 with ops.name_scope(gradient_name_scope): 157 return grad_fn(mock_op, *out_grads) 158 else: 159 return grad_fn(mock_op, *out_grads) 160 161 162pywrap_tfe.TFE_Py_RegisterGradientFunction(_gradient_function) 163 164 165def _must_record_gradient(): 166 return not pywrap_tfe.TFE_Py_TapeSetIsEmpty() 167 168 169@tf_export("__internal__.record_gradient", v1=[]) 170def record_gradient(op_name, inputs, attrs, outputs): 171 """Explicitly record the gradient for a given op. 172 173 Args: 174 op_name: The op name as listed in the `OpDef` for the op. 175 inputs: A list of tensor inputs to the op. 176 attrs: The op attributes as a flattened list of alternating attribute names 177 and attribute values. 178 outputs: A list of tensor outputs from the op. 179 """ 180 pywrap_tfe.TFE_Py_RecordGradient(op_name, inputs, attrs, outputs, 181 ops.get_name_scope()) 182 183 184execute.must_record_gradient = _must_record_gradient 185execute.record_gradient = record_gradient 186 187 188def implicit_val_and_grad(f): 189 """Returns a function which differentiates f with respect to variables. 190 191 The wrapped function returns the value and the gradient of f when called with 192 the same arguments. The gradient is with respect to all trainable TFE 193 variables accessed by `f`. 194 195 This function is useful when the exact set of variables to differentiate with 196 is not known ahead of time. 197 198 Example: 199 200 ```python 201 dense_layer = tf.compat.v1.layers.Dense(1) 202 def loss(x, y): 203 return tf.reduce_sum(tf.square(dense_layer(x) - y)) 204 205 # Obtain the gradient function. 206 val_grad_fn = tfe.implicit_value_and_gradients(loss) 207 208 # Invoke the gradient function with concrete values of x and y. 209 x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 210 y = tf.constant([[10.0], [20.0]]) 211 value, grads_and_vars = val_grad_fn(x, y) 212 print('Value of loss: %s' % value) 213 214 # Apply the gradients to Variables. 215 optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1) 216 optimizer.apply_gradients(grads_and_vars) 217 ``` 218 219 Args: 220 f: function to be differentiated. If `f` returns a scalar, this scalar will 221 be differentiated. If `f` returns a tensor or list of tensors, by default 222 a scalar will be computed by adding all their values to produce a single 223 scalar. 224 225 Returns: 226 A function which, when called, returns a tuple pair. 227 Its first element is the value to which the function evaluates. 228 Its second element is list of (gradient, variable) pairs. 229 230 Raises: 231 ValueError: if `f` returns None. 232 """ 233 # TODO(cais): Remove calls to tf.constant() once the gradients functions 234 # accept lists and np.ndarrays. 235 236 def grad_fn(*args, **kwds): 237 """Computes the gradient of the wrapped function.""" 238 this_tape = tape.push_new_tape() 239 try: 240 end_node = f(*args, **kwds) 241 if end_node is None: 242 raise ValueError("Cannot differentiate a function that returns None; " 243 "did you forget to return a value from {}?".format( 244 f.__name__)) 245 finally: 246 tape.pop_tape(this_tape) 247 # Note: variables are returned in construction order. This ensures unique 248 # order across executions. 249 variables = this_tape.watched_variables() 250 if not variables: 251 raise ValueError("No trainable variables were accessed while the " 252 "function was being computed.") 253 254 sources = [v.handle for v in variables] 255 for s in sources: 256 if getattr(s, "is_packed", False): 257 raise ValueError( 258 "GradientTape.gradient is not supported on packed EagerTensors yet." 259 ) 260 grad = imperative_grad.imperative_grad(this_tape, nest.flatten(end_node), 261 sources) 262 return end_node, list(zip(grad, variables)) 263 264 return grad_fn 265 266 267def implicit_grad(f): 268 """Returns a function which differentiates f with respect to variables. 269 270 The wrapped function returns the gradient of f when called with the same 271 arguments. The gradient is with respect to all trainable TFE variables 272 accessed by `f`. 273 274 This function is useful when the exact set of variables to differentiate with 275 is not known ahead of time. 276 277 Example: 278 279 ```python 280 dense_layer = tf.compat.v1.layers.Dense(1) 281 def loss(x, y): 282 return tf.reduce_sum(tf.square(dense_layer(x) - y)) 283 284 # Obtain the gradient function. 285 grad_fn = tfe.implicit_gradients(loss) 286 287 # Invoke the gradient function with concrete values of x and y. 288 x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 289 y = tf.constant([[10.0], [20.0]]) 290 grads_and_vars = grad_fn(x, y) 291 292 # Apply the gradients to Variables. 293 optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1) 294 optimizer.apply_gradients(grads_and_vars) 295 ``` 296 297 Args: 298 f: function to be differentiated. If `f` returns a scalar, this scalar will 299 be differentiated. If `f` returns a tensor or list of tensors, by default 300 a scalar will be computed by adding all their values to produce a single 301 scalar. 302 303 Returns: 304 A function which, when called, returns a list of (gradient, variable) pairs. 305 """ 306 # TODO(cais): Remove calls to tf.constant() once the gradients functions 307 # accept lists and np.ndarrays. 308 309 def grad_fn(*args, **kwds): 310 """Computes the gradient of the wrapped function.""" 311 return implicit_val_and_grad(f)(*args, **kwds)[1] 312 313 return grad_fn 314 315 316def _get_arg_spec(f, params, param_args): 317 """The positions of the parameters of f to be differentiated in param_args.""" 318 try: 319 args = tf_inspect.getfullargspec(f).args 320 except TypeError as e: 321 # TypeError can happen when f is a callable object. 322 if params is None: 323 return range(len(param_args)) 324 elif all(isinstance(x, int) for x in params): 325 return params 326 raise ValueError("Either callable provided is not a function or could not " 327 "inspect its arguments by name: %s. Original error: %s" 328 % (f, e)) 329 if params is None: 330 if not args: 331 return range(len(param_args)) 332 if args[0] == "self": 333 return range(len(args) - 1) 334 else: 335 return range(len(args)) 336 elif all(isinstance(x, str) for x in params): 337 return [args.index(n) for n in params] 338 elif all(isinstance(x, int) for x in params): 339 return params 340 else: 341 raise ValueError( 342 "params must be all strings or all integers; got %s." % params) 343 344 345def gradients_function(f, params=None): 346 """Returns a function which differentiates f with respect to params. 347 348 Example: 349 ```python 350 # f(x, y) = (x ^ 3) * y - x * (y ^ 2) 351 # Therefore, the 1st order derivatives are: 352 # df / dx = 3 * (x ^ 2) * y - y ^ 2 353 # df / dy = x ^ 3 - 2 * x * y 354 # The 2nd order derivatives with respect to x is: 355 # d^2 f / (dx)^2 = 6 * x * y 356 def f(x, y): 357 return x * x * x * y - x * y * y 358 359 # Obtain a function that returns 1st order gradients. 360 grad_fn = tfe.gradients_function(f) 361 362 x = 2.0 363 y = 3.0 364 365 # Invoke the 1st order gradient function. 366 x_grad, y_grad = grad_fn(x, y) 367 assert x_grad.numpy() == 3 * (2 ** 2) * 3 - 3 ** 2 368 assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3 369 370 # Obtain a function that returns the 2nd order gradient with respect to x. 371 gradgrad_fn = tfe.gradients_function(lambda x, y: grad_fn(x, y)[0]) 372 373 # Invoke the 2nd order gradient function. 374 x_gradgrad = gradgrad_fn(x, y)[0] 375 assert x_gradgrad.numpy() == 6 * 2 * 3 376 377 # To obtain a callable that returns the gradient(s) of `f` with respect to a 378 # subset of its inputs, use the `params` keyword argument with 379 # `gradients_function()`. 380 ygrad_fn = tfe.gradients_function(f, params=[1]) 381 382 (y_grad,) = ygrad_fn(x, y) 383 assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3 384 ``` 385 386 Note that only tensors with real or complex dtypes are differentiable. 387 388 Args: 389 f: function to be differentiated. If `f` returns a scalar, this scalar will 390 be differentiated. If `f` returns a tensor or list of tensors, by default 391 a scalar will be computed by adding all their values to produce a single 392 scalar. If desired, the tensors can be elementwise multiplied by the 393 tensors passed as the `dy` keyword argument to the returned gradient 394 function. 395 params: list of parameter names of f or list of integers indexing the 396 parameters with respect to which we'll differentiate. Passing None 397 differentiates with respect to all parameters. 398 399 Returns: 400 function which, when called, returns the value of f and the gradient 401 of `f` with respect to all of `params`. The function takes an extra optional 402 keyword argument `dy`. Setting it allows computation of vector jacobian 403 products for vectors other than the vector of ones. 404 405 Raises: 406 ValueError: if the params are not all strings or all integers. 407 """ 408 409 def decorated(*args, **kwds): 410 """Computes the gradient of the decorated function.""" 411 412 _, grad = val_and_grad_function(f, params=params)(*args, **kwds) 413 return grad 414 415 return decorated 416 417 418def _ensure_unique_tensor_objects(parameter_positions, args): 419 """Make each of the parameter_positions in args a unique ops.Tensor object. 420 421 Ensure that each parameter is treated independently. 422 For example: 423 424 def f(x, y): return x * y 425 g = gradients_function(f) 426 one = tf.constant(1.) 427 428 g(one, one) should return [1., 1.] 429 (even though the two arguments are the same Tensor object). 430 431 Args: 432 parameter_positions: List of indices into args defining the arguments to 433 differentiate against. 434 args: A list of arguments to the function to be differentiated. 435 436 Returns: 437 args, possibly edited in-place. 438 """ 439 s = set() 440 for (i, t) in enumerate(args): 441 if i in parameter_positions: 442 tid = ops.tensor_id(t) 443 if tid in s: 444 args[i] = gen_array_ops.identity(args[i]) 445 else: 446 s.add(tid) 447 return args 448 449 450def val_and_grad_function(f, params=None): 451 """Returns a function that computes f and its derivative w.r.t. params. 452 453 Example: 454 ```python 455 # f(x, y) = (x ^ 3) * y - x * (y ^ 2) 456 # Therefore, the 1st order derivatives are: 457 # df / dx = 3 * (x ^ 2) * y - y ^ 2 458 # df / dy = x ^ 3 - 2 * x * y 459 def f(x, y): 460 return x * x * x * y - x * y * y 461 462 # Obtain a function that returns the function value and the 1st order 463 # gradients. 464 val_grads_fn = tfe.value_and_gradients_function(f) 465 466 x = 2.0 467 y = 3.0 468 469 # Invoke the value-and-gradients function. 470 f_val, (x_grad, y_grad) = val_grads_fn(x, y) 471 assert f_val.numpy() == (2 ** 3) * 3 - 2 * (3 ** 2) 472 assert x_grad.numpy() == 3 * (2 ** 2) * 3 - 3 ** 2 473 assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3 474 475 # To obtain a callable that returns the value of `f` and the gradient(s) of 476 # `f` with respect to a subset of its inputs, use the `params` keyword 477 # argument with `value_and_gradients_function()`. 478 val_ygrad_fn = tfe.value_and_gradients_function(f, params=[1]) 479 480 f_val, (y_grad,) = val_ygrad_fn(x, y) 481 assert f_val.numpy() == (2 ** 3) * 3 - 2 * (3 ** 2) 482 assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3 483 ``` 484 485 Args: 486 f: function to be differentiated. If `f` returns a scalar, this scalar will 487 be differentiated. If `f` returns a tensor or list of tensors, by default 488 a scalar will be computed by adding all their values to produce a single 489 scalar. If desired, the tensors can be elementwise multiplied by the 490 tensors passed as the `dy` keyword argument to the returned gradient 491 function. 492 params: list of parameter names of f or list of integers indexing the 493 parameters with respect to which we'll differentiate. Passing `None` 494 differentiates with respect to all parameters. 495 496 Returns: 497 function which, when called, returns the value of f and the gradient 498 of f with respect to all of `params`. The function takes an extra optional 499 keyword argument "dy". Setting it allows computation of vector jacobian 500 products for vectors other than the vector of ones. 501 502 Raises: 503 ValueError: if the params are not all strings or all integers. 504 """ 505 506 def decorated(*args, **kwds): 507 """Computes the value and gradient of the decorated function.""" 508 dy = kwds.pop("dy", None) 509 if kwds: 510 raise ValueError("Functions to be differentiated cannot " 511 "receive keyword arguments.") 512 val, vjp = make_vjp(f, params)(*args, **kwds) 513 return val, vjp(dy=dy) 514 515 return decorated 516 517 518def make_vjp(f, params=None, persistent=True): 519 """Returns a function that computes f and its vjp w.r.t. 520 521 params. 522 523 The term "vjp" here is an abbreviation for vector-jacobian product. 524 525 Args: 526 f: the function to be differentiated. 527 params: the parameters (numbers or names) to differentiate with respect to. 528 A value of None will differentiate with respect to all parameters. 529 persistent: Boolean controlling whether the VJP function can be re-used. 530 Must be True or False. 531 532 Returns: 533 A function, which when called, returns a tuple (value, vjp), where: 534 - value is the result of calling f. 535 - vjp is a function, which takes a vector as an argument and 536 returns the product of that vector with the Jacobian of f. 537 Providing no argument to vjp is equivalent to providing a 538 vector of ones. 539 540 For example, 541 ```python 542 def f(x): 543 return x * x 544 545 wrapped_fn = tfe.make_vjp(f) 546 result, vjp = wrapped_fn(tf.constant(3.0)) 547 # result is 9.0 548 vjp() # the vjp function returns 6.0 549 550 Raises: 551 ValueError: if `f` returns None. 552 """ 553 554 def decorated(*args, **kwds): 555 """Computes the value and gradient of the decorated function.""" 556 parameter_positions = _get_arg_spec(f, params, args) 557 assert not kwds, "The gradient function can't take keyword arguments." 558 this_tape = tape.push_new_tape(persistent=persistent) 559 try: 560 sources = [] 561 args = [ 562 ops.convert_to_tensor(arg) if i in parameter_positions else arg 563 for i, arg in enumerate(args) 564 ] 565 args = _ensure_unique_tensor_objects(parameter_positions, args) 566 for i in parameter_positions: 567 if getattr(args[i], "is_packed", False): 568 raise ValueError( 569 "GradientTape.gradient is not supported on packed EagerTensors" 570 "yet.") 571 sources.append(args[i]) 572 tape.watch(this_tape, args[i]) 573 result = f(*args) 574 if result is None: 575 raise ValueError("Cannot differentiate a function that returns None; " 576 "did you forget to return a value from {}?".format( 577 f.__name__)) 578 flat_result = nest.flatten(result) 579 flat_result = [gen_array_ops.identity(x) for x in flat_result] 580 result = nest.pack_sequence_as(result, flat_result) 581 finally: 582 tape.pop_tape(this_tape) 583 def vjp(dy=None): 584 if dy is not None: 585 dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)] 586 return imperative_grad.imperative_grad( 587 this_tape, nest.flatten(result), sources, output_gradients=dy) 588 589 return result, vjp 590 591 return decorated 592 593 594def flatten_nested_indexed_slices(grad): 595 assert isinstance(grad, indexed_slices.IndexedSlices) 596 if isinstance(grad.values, ops.Tensor): 597 return grad 598 else: 599 assert isinstance(grad.values, indexed_slices.IndexedSlices) 600 g = flatten_nested_indexed_slices(grad.values) 601 return indexed_slices.IndexedSlices( 602 g.values, array_ops.gather(grad.indices, g.indices), g.dense_shape) 603 604 605def aggregate_indexed_slices_gradients(grads): 606 """Aggregates gradients containing `IndexedSlices`s.""" 607 if len(grads) < 1: 608 return None 609 if len(grads) == 1: 610 return grads[0] 611 grads = [g for g in grads if g is not None] 612 # If any gradient is a `Tensor`, sum them up and return a dense tensor 613 # object. 614 if any(isinstance(g, ops.Tensor) for g in grads): 615 return math_ops.add_n(grads) 616 617 # The following `_as_indexed_slices_list` casts ids of IndexedSlices into 618 # int64. It is to make sure the inputs of `concat` all have same the data 619 # type. 620 grads = math_ops._as_indexed_slices_list(grads) # pylint: disable=protected-access 621 622 grads = [flatten_nested_indexed_slices(x) for x in grads] 623 # Form IndexedSlices out of the concatenated values and indices. 624 concat_grad = indexed_slices.IndexedSlices( 625 array_ops.concat([x.values for x in grads], axis=0), 626 array_ops.concat([x.indices for x in grads], axis=0), 627 grads[0].dense_shape) 628 629 return concat_grad 630 631 632def _aggregate_grads(gradients): 633 """Aggregate gradients from multiple sources. 634 635 Args: 636 gradients: A list of 'Tensor' or 'IndexedSlices' gradients. 637 638 Returns: 639 If 'gradients' only has 'Tensor', returns an aggregated 'Tensor'. 640 Otherwise returns an aggregated 'IndexedSlices'. 641 """ 642 assert gradients, "No gradients to aggregate" 643 644 if len(gradients) == 1: 645 return gradients[0] 646 if all(isinstance(g, ops.Tensor) for g in gradients): 647 return gen_math_ops.add_n(gradients) 648 else: 649 assert all( 650 isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices)) 651 for g in gradients) 652 return aggregate_indexed_slices_gradients(gradients) 653 654 655def _num_elements(grad): 656 """The number of elements in the `grad` tensor.""" 657 if isinstance(grad, ops.Tensor): 658 shape_tuple = grad._shape_tuple() # pylint: disable=protected-access 659 elif isinstance(grad, indexed_slices.IndexedSlices): 660 shape_tuple = grad.values._shape_tuple() # pylint: disable=protected-access 661 else: 662 raise ValueError("`grad` not a Tensor or IndexedSlices.") 663 if shape_tuple is None or None in shape_tuple: 664 return 0 665 return functools.reduce(operator.mul, shape_tuple, 1) 666 667 668def _fast_fill(value, shape, dtype): 669 return array_ops.fill( 670 constant_op.constant(shape, dtype=dtypes.int32), 671 constant_op.constant(value, dtype=dtype)) 672 673 674def _zeros(shape, dtype): 675 """Helper to return (possibly cached) zero tensors in eager mode.""" 676 # Note: variants will use _zeros_like 677 if dtype == dtypes.string or dtype == dtypes.resource: 678 return None 679 680 ctx = context.context() 681 if not ctx.executing_eagerly(): 682 return array_ops.zeros(shape, dtype) 683 684 device = ctx.device_name 685 686 if tensor_util.is_tf_type(shape): 687 shape_key = shape.ref() 688 else: 689 shape_key = shape 690 cache_key = shape_key, dtype, device 691 cached = ctx.zeros_cache().get(cache_key) 692 if cached is None: 693 if dtypes.as_dtype(dtype).is_bool: 694 value = False 695 else: 696 value = 0 697 cached = _fast_fill(value, shape, dtype) 698 ctx.zeros_cache().put(cache_key, cached) 699 return cached 700 701 702def _ones(shape, dtype): 703 as_dtype = dtypes.as_dtype(dtype) 704 if as_dtype == dtypes.string: 705 return None 706 707 if not context.executing_eagerly(): 708 return array_ops.ones(shape, dtype) 709 710 if as_dtype.is_bool: 711 value = True 712 else: 713 value = 1 714 715 if shape == (): # pylint: disable=g-explicit-bool-comparison 716 return constant_op.constant(value, dtype=dtype) 717 return _fast_fill(value, shape, dtype) 718 719 720_default_vspace = imperative_grad.VSpace( 721 num_elements_fn=_num_elements, 722 aggregate_fn=_aggregate_grads, 723 zeros_fn=_zeros, 724 ones_fn=_ones, 725 zeros_like_fn=default_gradient.zeros_like, 726 ones_like_fn=default_gradient.ones_like, 727 graph_shape_fn=gen_array_ops.shape) 728pywrap_tfe.TFE_Py_RegisterVSpace(_default_vspace) 729 730 731def _handle_or_self(x): 732 """Unwrap resource variable/ndarray to return tensors.""" 733 if resource_variable_ops.is_resource_variable(x): 734 return x.handle 735 return x 736 737 738def _extract_tensors_and_variables(tensor): 739 """Extracts tensors and variables from the input object.""" 740 for obj in nest.flatten(tensor): 741 if _pywrap_utils.IsTensor(obj) or _pywrap_utils.IsVariable(obj): 742 yield obj 743 elif isinstance(obj, composite_tensor.CompositeTensor): 744 components = type_spec.type_spec_from_value(obj)._to_components(obj) # pylint: disable=protected-access 745 yield from _extract_tensors_and_variables(components) 746 else: 747 raise ValueError(f"Passed in object {obj} of type {type(obj).__name__!r}" 748 f", not tf.Tensor or tf.Variable or ExtensionType.") 749 750 751@tf_export("GradientTape", "autodiff.GradientTape", v1=["GradientTape"]) 752class GradientTape: 753 """Record operations for automatic differentiation. 754 755 Operations are recorded if they are executed within this context manager and 756 at least one of their inputs is being "watched". 757 758 Trainable variables (created by `tf.Variable` or `tf.compat.v1.get_variable`, 759 where `trainable=True` is default in both cases) are automatically watched. 760 Tensors can be manually watched by invoking the `watch` method on this context 761 manager. 762 763 For example, consider the function `y = x * x`. The gradient at `x = 3.0` can 764 be computed as: 765 766 >>> x = tf.constant(3.0) 767 >>> with tf.GradientTape() as g: 768 ... g.watch(x) 769 ... y = x * x 770 >>> dy_dx = g.gradient(y, x) 771 >>> print(dy_dx) 772 tf.Tensor(6.0, shape=(), dtype=float32) 773 774 GradientTapes can be nested to compute higher-order derivatives. For example, 775 776 >>> x = tf.constant(5.0) 777 >>> with tf.GradientTape() as g: 778 ... g.watch(x) 779 ... with tf.GradientTape() as gg: 780 ... gg.watch(x) 781 ... y = x * x 782 ... dy_dx = gg.gradient(y, x) # dy_dx = 2 * x 783 >>> d2y_dx2 = g.gradient(dy_dx, x) # d2y_dx2 = 2 784 >>> print(dy_dx) 785 tf.Tensor(10.0, shape=(), dtype=float32) 786 >>> print(d2y_dx2) 787 tf.Tensor(2.0, shape=(), dtype=float32) 788 789 By default, the resources held by a GradientTape are released as soon as 790 GradientTape.gradient() method is called. To compute multiple gradients over 791 the same computation, create a persistent gradient tape. This allows multiple 792 calls to the gradient() method as resources are released when the tape object 793 is garbage collected. For example: 794 795 >>> x = tf.constant(3.0) 796 >>> with tf.GradientTape(persistent=True) as g: 797 ... g.watch(x) 798 ... y = x * x 799 ... z = y * y 800 >>> dz_dx = g.gradient(z, x) # (4*x^3 at x = 3) 801 >>> print(dz_dx) 802 tf.Tensor(108.0, shape=(), dtype=float32) 803 >>> dy_dx = g.gradient(y, x) 804 >>> print(dy_dx) 805 tf.Tensor(6.0, shape=(), dtype=float32) 806 807 By default GradientTape will automatically watch any trainable variables that 808 are accessed inside the context. If you want fine grained control over which 809 variables are watched you can disable automatic tracking by passing 810 `watch_accessed_variables=False` to the tape constructor: 811 812 >>> x = tf.Variable(2.0) 813 >>> w = tf.Variable(5.0) 814 >>> with tf.GradientTape( 815 ... watch_accessed_variables=False, persistent=True) as tape: 816 ... tape.watch(x) 817 ... y = x ** 2 # Gradients will be available for `x`. 818 ... z = w ** 3 # No gradients will be available as `w` isn't being watched. 819 >>> dy_dx = tape.gradient(y, x) 820 >>> print(dy_dx) 821 tf.Tensor(4.0, shape=(), dtype=float32) 822 >>> # No gradients will be available as `w` isn't being watched. 823 >>> dz_dw = tape.gradient(z, w) 824 >>> print(dz_dw) 825 None 826 827 Note that when using models you should ensure that your variables exist when 828 using `watch_accessed_variables=False`. Otherwise it's quite easy to make your 829 first iteration not have any gradients: 830 831 ```python 832 a = tf.keras.layers.Dense(32) 833 b = tf.keras.layers.Dense(32) 834 835 with tf.GradientTape(watch_accessed_variables=False) as tape: 836 tape.watch(a.variables) # Since `a.build` has not been called at this point 837 # `a.variables` will return an empty list and the 838 # tape will not be watching anything. 839 result = b(a(inputs)) 840 tape.gradient(result, a.variables) # The result of this computation will be 841 # a list of `None`s since a's variables 842 # are not being watched. 843 ``` 844 845 Note that only tensors with real or complex dtypes are differentiable. 846 """ 847 848 def __init__(self, persistent=False, watch_accessed_variables=True): 849 """Creates a new GradientTape. 850 851 Args: 852 persistent: Boolean controlling whether a persistent gradient tape 853 is created. False by default, which means at most one call can 854 be made to the gradient() method on this object. 855 watch_accessed_variables: Boolean controlling whether the tape will 856 automatically `watch` any (trainable) variables accessed while the tape 857 is active. Defaults to True meaning gradients can be requested from any 858 result computed in the tape derived from reading a trainable `Variable`. 859 If False users must explicitly `watch` any `Variable`s they want to 860 request gradients from. 861 """ 862 self._tape = None 863 self._persistent = persistent 864 self._watch_accessed_variables = watch_accessed_variables 865 self._watched_variables = () 866 self._recording = False 867 868 def __enter__(self): 869 """Enters a context inside which operations are recorded on this tape.""" 870 self._push_tape() 871 return self 872 873 def __exit__(self, typ, value, traceback): 874 """Exits the recording context, no further operations are traced.""" 875 if self._recording: 876 self._pop_tape() 877 878 def _push_tape(self): 879 """Pushes a new tape onto the tape stack.""" 880 if self._recording: 881 raise ValueError("Tape is still recording, This can happen if you try to " 882 "re-enter an already-active tape.") 883 if self._tape is None: 884 self._tape = tape.push_new_tape( 885 persistent=self._persistent, 886 watch_accessed_variables=self._watch_accessed_variables) 887 else: 888 tape.push_tape(self._tape) 889 self._recording = True 890 891 def _pop_tape(self): 892 if not self._recording: 893 raise ValueError("Tape is not recording.") 894 tape.pop_tape(self._tape) 895 self._recording = False 896 897 @tf_contextlib.contextmanager 898 def _ensure_recording(self): 899 """Ensures that this tape is recording.""" 900 if not self._recording: 901 try: 902 self._push_tape() 903 yield 904 finally: 905 self._pop_tape() 906 else: 907 yield 908 909 # TODO(b/209081027): Add a variable in composite tensor test case after 910 # variables become composite tensors. 911 def watch(self, tensor): 912 """Ensures that `tensor` is being traced by this tape. 913 914 Args: 915 tensor: a Tensor/Variable or list of Tensors/Variables. 916 917 Raises: 918 ValueError: if it encounters something that is not a tensor. 919 """ 920 for t in _extract_tensors_and_variables(tensor): 921 if not backprop_util.IsTrainable(t): 922 logging.log_first_n( 923 logging.WARN, "The dtype of the watched tensor must be " 924 "floating (e.g. tf.float32), got %r", 5, t.dtype) 925 if hasattr(t, "handle"): 926 # There are many variable-like objects, all of them currently have 927 # `handle` attribute that points to a tensor. If this changes, 928 # internals of watch_variable need to change as well. 929 tape.watch_variable(self._tape, t) 930 else: 931 tape.watch(self._tape, t) 932 933 @tf_contextlib.contextmanager 934 def stop_recording(self): 935 """Temporarily stops recording operations on this tape. 936 937 Operations executed while this context manager is active will not be 938 recorded on the tape. This is useful for reducing the memory used by tracing 939 all computations. 940 941 For example: 942 943 >>> x = tf.constant(4.0) 944 >>> with tf.GradientTape() as tape: 945 ... with tape.stop_recording(): 946 ... y = x ** 2 947 >>> dy_dx = tape.gradient(y, x) 948 >>> print(dy_dx) 949 None 950 951 Yields: 952 None 953 Raises: 954 RuntimeError: if the tape is not currently recording. 955 """ 956 if self._tape is None: 957 raise RuntimeError( 958 "Trying to stop recording a tape which is not recording.") 959 self._pop_tape() 960 try: 961 yield 962 finally: 963 self._push_tape() 964 965 def reset(self): 966 """Clears all information stored in this tape. 967 968 Equivalent to exiting and reentering the tape context manager with a new 969 tape. For example, the two following code blocks are equivalent: 970 971 ``` 972 with tf.GradientTape() as t: 973 loss = loss_fn() 974 with tf.GradientTape() as t: 975 loss += other_loss_fn() 976 t.gradient(loss, ...) # Only differentiates other_loss_fn, not loss_fn 977 978 979 # The following is equivalent to the above 980 with tf.GradientTape() as t: 981 loss = loss_fn() 982 t.reset() 983 loss += other_loss_fn() 984 t.gradient(loss, ...) # Only differentiates other_loss_fn, not loss_fn 985 ``` 986 987 This is useful if you don't want to exit the context manager for the tape, 988 or can't because the desired reset point is inside a control flow construct: 989 990 ``` 991 with tf.GradientTape() as t: 992 loss = ... 993 if loss > k: 994 t.reset() 995 ``` 996 """ 997 self._pop_tape() 998 self._tape = None 999 self._push_tape() 1000 1001 def watched_variables(self): 1002 """Returns variables watched by this tape in order of construction.""" 1003 if self._tape is not None: 1004 self._watched_variables = self._tape.watched_variables() 1005 return self._watched_variables 1006 1007 def gradient(self, 1008 target, 1009 sources, 1010 output_gradients=None, 1011 unconnected_gradients=UnconnectedGradients.NONE): 1012 """Computes the gradient using operations recorded in context of this tape. 1013 1014 Note: Unless you set `persistent=True` a GradientTape can only be used to 1015 compute one set of gradients (or jacobians). 1016 1017 In addition to Tensors, gradient also supports RaggedTensors. For example, 1018 1019 >>> x = tf.ragged.constant([[1.0, 2.0], [3.0]]) 1020 >>> with tf.GradientTape() as g: 1021 ... g.watch(x) 1022 ... y = x * x 1023 >>> g.gradient(y, x) 1024 <tf.RaggedTensor [[2.0, 4.0], [6.0]]> 1025 1026 Args: 1027 target: a list or nested structure of Tensors or Variables or 1028 CompositeTensors to be differentiated. 1029 sources: a list or nested structure of Tensors or Variables or 1030 CompositeTensors. `target` will be differentiated against elements in 1031 `sources`. 1032 output_gradients: a list of gradients, one for each differentiable 1033 element of target. Defaults to None. 1034 unconnected_gradients: a value which can either hold 'none' or 'zero' and 1035 alters the value which will be returned if the target and sources are 1036 unconnected. The possible values and effects are detailed in 1037 'UnconnectedGradients' and it defaults to 'none'. 1038 1039 Returns: 1040 a list or nested structure of Tensors (or IndexedSlices, or None, or 1041 CompositeTensor), one for each element in `sources`. Returned structure 1042 is the same as the structure of `sources`. 1043 1044 Raises: 1045 RuntimeError: If called on a used, non-persistent tape. 1046 RuntimeError: If called inside the context of the tape. 1047 TypeError: If the target is a None object. 1048 ValueError: If the target is a variable or if unconnected gradients is 1049 called with an unknown value. 1050 """ 1051 if self._tape is None: 1052 raise RuntimeError("A non-persistent GradientTape can only be used to " 1053 "compute one set of gradients (or jacobians)") 1054 if self._recording: 1055 if not self._persistent: 1056 self._pop_tape() 1057 else: 1058 logging.log_first_n( 1059 logging.WARN, "Calling GradientTape.gradient on a persistent " 1060 "tape inside its context is significantly less " 1061 "efficient than calling it outside the context (it " 1062 "causes the gradient ops to be recorded on the " 1063 "tape, leading to increased CPU and memory usage). " 1064 "Only call GradientTape.gradient inside the " 1065 "context if you actually want to trace the " 1066 "gradient in order to compute higher order " 1067 "derivatives.", 1) 1068 1069 if target is None: 1070 raise TypeError("Argument `target` should be a list or nested structure" 1071 " of Tensors, Variables or CompositeTensors to be " 1072 "differentiated, but received None.") 1073 1074 flat_targets = [] 1075 for t in nest.flatten(target): 1076 flat_targets.append(_handle_or_self(t)) 1077 flat_targets = composite_tensor_gradient.get_flat_tensors_for_gradients( 1078 flat_targets) 1079 for t in flat_targets: 1080 if not backprop_util.IsTrainable(t): 1081 logging.vlog( 1082 1, "The dtype of the target tensor must be " 1083 "floating (e.g. tf.float32) when calling GradientTape.gradient, " 1084 "got %r", t.dtype) 1085 1086 flat_sources_raw = nest.flatten(sources) 1087 flat_sources = [] 1088 for t in flat_sources_raw: 1089 flat_sources.append(_handle_or_self(t)) 1090 flat_sources = composite_tensor_gradient.get_flat_tensors_for_gradients( 1091 flat_sources) 1092 for t in flat_sources: 1093 if not backprop_util.IsTrainable(t): 1094 logging.vlog( 1095 1, "The dtype of the source tensor must be " 1096 "floating (e.g. tf.float32) when calling GradientTape.gradient, " 1097 "got %r", t.dtype) 1098 if getattr(t, "is_packed", False): 1099 raise ValueError( 1100 "GradientTape.gradient is not supported on packed EagerTensors yet." 1101 ) 1102 1103 if output_gradients is not None: 1104 output_gradients = nest.flatten( 1105 variable_utils.convert_variables_to_tensors(output_gradients)) 1106 output_gradients = ( 1107 composite_tensor_gradient.get_flat_tensors_for_gradients( 1108 output_gradients)) 1109 output_gradients = [None if x is None else ops.convert_to_tensor(x) 1110 for x in output_gradients] 1111 1112 flat_grad = imperative_grad.imperative_grad( 1113 self._tape, 1114 flat_targets, 1115 flat_sources, 1116 output_gradients=output_gradients, 1117 sources_raw=flat_sources_raw, 1118 unconnected_gradients=unconnected_gradients) 1119 1120 if not self._persistent: 1121 # Keep track of watched variables before setting tape to None 1122 self._watched_variables = self._tape.watched_variables() 1123 self._tape = None 1124 1125 flat_sources_raw = nest.map_structure(_handle_or_self, flat_sources_raw) 1126 flat_grad = composite_tensor_gradient.replace_flat_tensors_for_gradients( 1127 flat_sources_raw, flat_grad) 1128 grad = nest.pack_sequence_as(sources, flat_grad) 1129 return grad 1130 1131 def jacobian(self, 1132 target, 1133 sources, 1134 unconnected_gradients=UnconnectedGradients.NONE, 1135 parallel_iterations=None, 1136 experimental_use_pfor=True): 1137 """Computes the jacobian using operations recorded in context of this tape. 1138 1139 Note: Unless you set `persistent=True` a GradientTape can only be used to 1140 compute one set of gradients (or jacobians). 1141 1142 Note: By default the jacobian implementation uses parallel for (pfor), which 1143 creates a tf.function under the hood for each jacobian call. For better 1144 performance, and to avoid recompilation and vectorization rewrites on each 1145 call, enclose GradientTape code in @tf.function. 1146 1147 See[wikipedia 1148 article](http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant) 1149 for the definition of a Jacobian. 1150 1151 Example usage: 1152 1153 ```python 1154 with tf.GradientTape() as g: 1155 x = tf.constant([1.0, 2.0]) 1156 g.watch(x) 1157 y = x * x 1158 jacobian = g.jacobian(y, x) 1159 # jacobian value is [[2., 0.], [0., 4.]] 1160 ``` 1161 1162 Args: 1163 target: Tensor to be differentiated. 1164 sources: a list or nested structure of Tensors or Variables. `target` 1165 will be differentiated against elements in `sources`. 1166 unconnected_gradients: a value which can either hold 'none' or 'zero' and 1167 alters the value which will be returned if the target and sources are 1168 unconnected. The possible values and effects are detailed in 1169 'UnconnectedGradients' and it defaults to 'none'. 1170 parallel_iterations: A knob to control how many iterations are dispatched 1171 in parallel. This knob can be used to control the total memory usage. 1172 experimental_use_pfor: If true, vectorizes the jacobian computation. Else 1173 falls back to a sequential while_loop. Vectorization can sometimes fail 1174 or lead to excessive memory usage. This option can be used to disable 1175 vectorization in such cases. 1176 1177 Returns: 1178 A list or nested structure of Tensors (or None), one for each element in 1179 `sources`. Returned structure is the same as the structure of `sources`. 1180 Note if any gradient is sparse (IndexedSlices), jacobian function 1181 currently makes it dense and returns a Tensor instead. This may change in 1182 the future. 1183 1184 1185 Raises: 1186 RuntimeError: If called on a used, non-persistent tape. 1187 RuntimeError: If called on a non-persistent tape with eager execution 1188 enabled and without enabling experimental_use_pfor. 1189 ValueError: If vectorization of jacobian computation fails. 1190 """ 1191 if self._tape is None: 1192 raise RuntimeError("A non-persistent GradientTape can only be used to " 1193 "compute one set of gradients (or jacobians)") 1194 1195 flat_sources = nest.flatten(sources) 1196 target_static_shape = target.shape 1197 target_shape = array_ops.shape(target) 1198 # Note that we push and pop the tape here and below. This is needed since we 1199 # need gradients through the enclosed operations. 1200 with self._ensure_recording(): 1201 target = array_ops.reshape(target, [-1]) 1202 1203 def loop_fn(i): 1204 with self._ensure_recording(): 1205 y = array_ops.gather(target, i) 1206 return self.gradient(y, flat_sources, 1207 unconnected_gradients=unconnected_gradients) 1208 1209 try: 1210 target_size = int(target.shape[0]) 1211 except TypeError: 1212 target_size = array_ops.shape(target)[0] 1213 1214 if experimental_use_pfor: 1215 try: 1216 output = pfor_ops.pfor(loop_fn, target_size, 1217 parallel_iterations=parallel_iterations) 1218 except ValueError as err: 1219 raise ValueError( 1220 "Encountered an exception while vectorizing the " 1221 "jacobian computation. Vectorization can be disabled by setting" 1222 " experimental_use_pfor to False.") from err 1223 else: 1224 if context.executing_eagerly() and not self._persistent: 1225 raise RuntimeError( 1226 "GradientTape must be created with persistent=True" 1227 " to compute the jacobian with eager execution enabled and with " 1228 " experimental_use_pfor set to False.") 1229 output = pfor_ops.for_loop( 1230 loop_fn, [target.dtype] * len(flat_sources), target_size, 1231 parallel_iterations=parallel_iterations) 1232 1233 for i, out in enumerate(output): 1234 if out is not None: 1235 new_shape = array_ops.concat( 1236 [target_shape, array_ops.shape(out)[1:]], axis=0) 1237 out = array_ops.reshape(out, new_shape) 1238 if context.executing_eagerly(): 1239 out.set_shape(target_static_shape.concatenate(flat_sources[i].shape)) 1240 output[i] = out 1241 1242 return nest.pack_sequence_as(sources, output) 1243 1244 def batch_jacobian(self, 1245 target, 1246 source, 1247 unconnected_gradients=UnconnectedGradients.NONE, 1248 parallel_iterations=None, 1249 experimental_use_pfor=True): 1250 """Computes and stacks per-example jacobians. 1251 1252 See [wikipedia article](http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant) 1253 for the definition of a Jacobian. This function is essentially an efficient 1254 implementation of the following: 1255 1256 `tf.stack([self.jacobian(y[i], x[i]) for i in range(x.shape[0])])`. 1257 1258 Note that compared to `GradientTape.jacobian` which computes gradient of 1259 each output value w.r.t each input value, this function is useful when 1260 `target[i,...]` is independent of `source[j,...]` for `j != i`. This 1261 assumption allows more efficient computation as compared to 1262 `GradientTape.jacobian`. The output, as well as intermediate activations, 1263 are lower dimensional and avoid a bunch of redundant zeros which would 1264 result in the jacobian computation given the independence assumption. 1265 1266 Note: Unless you set `persistent=True` a GradientTape can only be used to 1267 compute one set of gradients (or jacobians). 1268 1269 Note: By default the batch_jacobian implementation uses parallel for (pfor), 1270 which creates a tf.function under the hood for each batch_jacobian call. 1271 For better performance, and to avoid recompilation and vectorization 1272 rewrites on each call, enclose GradientTape code in @tf.function. 1273 1274 1275 Example usage: 1276 1277 ```python 1278 with tf.GradientTape() as g: 1279 x = tf.constant([[1., 2.], [3., 4.]], dtype=tf.float32) 1280 g.watch(x) 1281 y = x * x 1282 batch_jacobian = g.batch_jacobian(y, x) 1283 # batch_jacobian is [[[2, 0], [0, 4]], [[6, 0], [0, 8]]] 1284 ``` 1285 1286 Args: 1287 target: A tensor with rank 2 or higher and with shape [b, y1, ..., y_n]. 1288 `target[i,...]` should only depend on `source[i,...]`. 1289 source: A tensor with rank 2 or higher and with shape [b, x1, ..., x_m]. 1290 unconnected_gradients: a value which can either hold 'none' or 'zero' and 1291 alters the value which will be returned if the target and sources are 1292 unconnected. The possible values and effects are detailed in 1293 'UnconnectedGradients' and it defaults to 'none'. 1294 parallel_iterations: A knob to control how many iterations are dispatched 1295 in parallel. This knob can be used to control the total memory usage. 1296 experimental_use_pfor: If true, uses pfor for computing the Jacobian. Else 1297 uses a tf.while_loop. 1298 1299 Returns: 1300 A tensor `t` with shape [b, y_1, ..., y_n, x1, ..., x_m] where `t[i, ...]` 1301 is the jacobian of `target[i, ...]` w.r.t. `source[i, ...]`, i.e. stacked 1302 per-example jacobians. 1303 1304 Raises: 1305 RuntimeError: If called on a used, non-persistent tape. 1306 RuntimeError: If called on a non-persistent tape with eager execution 1307 enabled and without enabling experimental_use_pfor. 1308 ValueError: If vectorization of jacobian computation fails or if first 1309 dimension of `target` and `source` do not match. 1310 """ 1311 if self._tape is None: 1312 raise RuntimeError("A non-persistent GradientTape can only be used to" 1313 "compute one set of gradients (or jacobians)") 1314 target_shape = target.shape 1315 if target_shape.rank is None: 1316 dim = tensor_shape.Dimension(None) 1317 else: 1318 dim = target_shape.dims[0] 1319 if not (target_shape.with_rank_at_least(2) and 1320 source.shape.with_rank_at_least(2) and 1321 dim.is_compatible_with(source.shape[0])): 1322 raise ValueError( 1323 "Need first dimension of target shape (%s) and " 1324 "source shape (%s) to match." % (target.shape, source.shape)) 1325 if target_shape.is_fully_defined(): 1326 batch_size = int(target_shape[0]) 1327 target_row_size = target_shape.num_elements() // batch_size 1328 else: 1329 target_shape = array_ops.shape(target) 1330 batch_size = target_shape[0] 1331 target_row_size = array_ops.size(target) // batch_size 1332 source_shape = array_ops.shape(source) 1333 # Flatten target to 2-D. 1334 # Note that we push and pop the tape here and below. This is needed since we 1335 # need gradients through the enclosed operations. 1336 with self._ensure_recording(): 1337 with ops.control_dependencies( 1338 [check_ops.assert_equal(batch_size, source_shape[0])]): 1339 target = array_ops.reshape(target, [batch_size, target_row_size]) 1340 1341 run_once = False 1342 1343 def loop_fn(i): 1344 nonlocal run_once 1345 if run_once and not self._persistent: 1346 if parallel_iterations is not None: 1347 raise RuntimeError( 1348 "GradientTape must be created with persistent=True" 1349 " to compute the batch_jacobian with parallel_iterations.") 1350 else: 1351 raise RuntimeError( 1352 "GradientTape must be created with persistent=True" 1353 " to compute the batch_jacobian.") 1354 run_once = True 1355 1356 with self._ensure_recording(): 1357 y = array_ops.gather(target, i, axis=1) 1358 return self.gradient(y, source, 1359 unconnected_gradients=unconnected_gradients) 1360 1361 if experimental_use_pfor: 1362 try: 1363 output = pfor_ops.pfor(loop_fn, target_row_size, 1364 parallel_iterations=parallel_iterations) 1365 except ValueError as err: 1366 raise ValueError( 1367 "Encountered an exception while vectorizing the " 1368 "batch_jacobian computation. Vectorization can be disabled by " 1369 "setting experimental_use_pfor to False.") from err 1370 else: 1371 if context.executing_eagerly() and not self._persistent: 1372 raise RuntimeError( 1373 "GradientTape must be created with persistent=True" 1374 " to compute the batch_jacobian with eager execution enabled and " 1375 " with experimental_use_pfor set to False.") 1376 output = pfor_ops.for_loop(loop_fn, target.dtype, target_row_size, 1377 parallel_iterations=parallel_iterations) 1378 new_shape = array_ops.concat([target_shape, source_shape[1:]], axis=0) 1379 if output is None: 1380 # Note that this block is returning zeros when it could use `None` to 1381 # represent unconnected gradients. This is to maintain compatibility with 1382 # the previous behavior, which ignored `unconnected_gradients`. 1383 output = array_ops.zeros(new_shape, target.dtype) 1384 return output 1385 else: 1386 output = array_ops.reshape(output, 1387 [target_row_size, batch_size, -1]) 1388 output = array_ops.transpose(output, [1, 0, 2]) 1389 1390 output = array_ops.reshape(output, new_shape) 1391 return output 1392