xref: /aosp_15_r20/external/tensorflow/tensorflow/python/eager/backprop_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import functools
16
17from absl.testing import parameterized
18import numpy as np
19
20from tensorflow.python import pywrap_tfe
21from tensorflow.python.eager import backprop
22from tensorflow.python.eager import context
23from tensorflow.python.eager import def_function
24from tensorflow.python.eager import function
25from tensorflow.python.eager import tape as tape_lib
26from tensorflow.python.eager import test
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import errors_impl
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import sparse_tensor
32from tensorflow.python.framework import tensor_shape
33from tensorflow.python.framework import tensor_util
34from tensorflow.python.framework import test_util
35from tensorflow.python.framework.memory_checker import MemoryChecker
36from tensorflow.python.layers.pooling import max_pooling3d
37from tensorflow.python.ops import array_ops
38from tensorflow.python.ops import control_flow_ops
39from tensorflow.python.ops import custom_gradient
40from tensorflow.python.ops import embedding_ops
41from tensorflow.python.ops import functional_ops
42from tensorflow.python.ops import gradient_checker_v2
43from tensorflow.python.ops import gradients
44from tensorflow.python.ops import math_ops
45from tensorflow.python.ops import nn
46from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
47from tensorflow.python.ops import nn_ops
48from tensorflow.python.ops import random_ops
49from tensorflow.python.ops import resource_variable_ops
50from tensorflow.python.ops import sparse_ops
51from tensorflow.python.ops import variables
52from tensorflow.python.training import training
53
54
55class BackpropTest(test.TestCase, parameterized.TestCase):
56
57  @test_util.run_in_graph_and_eager_modes
58  def testAggregateGradients(self):
59
60    def fn(x):
61      ind1 = constant_op.constant(np.array([0, 1]))
62      ind2 = constant_op.constant(np.array([2, 3]))
63      ind3 = constant_op.constant(np.array([1, 3]))
64      g1 = embedding_ops.embedding_lookup(x, ind1)
65      g2 = embedding_ops.embedding_lookup(x, ind2)
66      g3 = embedding_ops.embedding_lookup(x, ind3)
67      return g1 * g2 * g3
68
69    var_np = np.random.rand(4, 2).astype(np.float32)
70    var = constant_op.constant(var_np)
71    grad = backprop.gradients_function(fn, [0])(var)[0]
72    grad = self.evaluate(ops.convert_to_tensor(grad))
73
74    if not context.executing_eagerly():
75      tf_var = array_ops.constant(var_np, dtypes.float32)
76      tf_ind1 = array_ops.constant([0, 1])
77      tf_ind2 = array_ops.constant([2, 3])
78      tf_ind3 = array_ops.constant([1, 3])
79      tf_g1 = embedding_ops.embedding_lookup(tf_var, tf_ind1)
80      tf_g2 = embedding_ops.embedding_lookup(tf_var, tf_ind2)
81      tf_g3 = embedding_ops.embedding_lookup(tf_var, tf_ind3)
82      tf_y = tf_g1 * tf_g2 * tf_g3
83      tf_grad = gradients.gradients(tf_y, [tf_var])[0]
84
85      tf_dense_grad = math_ops.unsorted_segment_sum(tf_grad.values,
86                                                    tf_grad.indices,
87                                                    tf_grad.dense_shape[0])
88
89      self.assertAllClose(grad, self.evaluate(tf_dense_grad))
90
91  @test_util.run_in_graph_and_eager_modes
92  def testAggregateGradientsWithTensor(self):
93
94    def fn(x):
95      ind1 = constant_op.constant(np.array([0, 1]))
96      # A mixture of IndexedSlices and dense tensor to aggregate.
97      g1 = embedding_ops.embedding_lookup(x, ind1)
98      g2 = math_ops.reduce_sum(x * constant_op.constant(2.0))
99      return g1 * g2
100
101    var_np = np.random.rand(4, 2).astype(np.float32)
102    var = constant_op.constant(var_np)
103    grad = backprop.gradients_function(fn, [0])(var)[0]
104    grad = self.evaluate(ops.convert_to_tensor(grad))
105
106    if not context.executing_eagerly():
107      tf_var = array_ops.constant(var_np, dtypes.float32)
108      tf_ind1 = array_ops.constant([0, 1])
109      tf_g1 = embedding_ops.embedding_lookup(tf_var, tf_ind1)
110      tf_g2 = math_ops.reduce_sum(tf_var * 2.0, axis=(0, 1))
111      tf_y = tf_g1 * tf_g2
112      tf_grad = gradients.gradients(tf_y, [tf_var])[0]
113
114      self.assertAllClose(grad, tf_grad)
115
116  def testImplicitGradWithResourceVariable(self):
117    x = resource_variable_ops.ResourceVariable(
118        initial_value=constant_op.constant(1.0), name='x')
119
120    def fn():
121      b = constant_op.constant(2.0)
122      c = math_ops.add(x.value(), b)
123      return math_ops.add(c, constant_op.constant(3.0))
124
125    grads_and_vars = backprop.implicit_grad(fn)()
126    self.assertAllEqual(grads_and_vars[0][0], 1.0)
127    self.assertAllEqual(id(grads_and_vars[0][1]), id(x))
128
129  @parameterized.named_parameters([('Function', def_function.function),
130                                   ('NoFunction', lambda f: f)])
131  def testNoOpBehaviorConsistent(self, decorator):
132
133    @decorator
134    def f(x):
135      # Test all different types of no-ops
136      x1 = array_ops.identity(x)
137      x2 = math_ops.add_v2(x, 0)
138      x3 = math_ops.subtract(x, 0)
139      x4 = math_ops.multiply(x, 1)
140      with backprop.GradientTape() as t:
141        t.watch(x)
142        t.watch(x1)
143        t.watch(x2)
144        t.watch(x3)
145        t.watch(x4)
146        y1 = x * 2.
147        y2 = x1 * 3.
148        y3 = x2 * 3.
149        y4 = x3 * 3.
150        y5 = x4 * 3.
151        loss = y1 + y2 + y3 + y4 + y5
152      return t.gradient(loss, [x, x1, x2, x3, x4])
153
154    self.assertAllClose([2., 3., 3., 3., 3.], f(constant_op.constant(10.)))
155
156  def testResourceHandleOutputWithoutHandleData(self):
157    # This is a bit of a weird thing to test since we try to maintain handle
158    # data. But users do create their own resources, and those often do not have
159    # any handle data.
160    h = resource_variable_ops.var_handle_op(
161        shape=[], dtype=dtypes.float32, shared_name='abc')
162
163    with backprop.GradientTape() as tape:
164      x = constant_op.constant(1.)
165      tape.watch(x)
166      tape.watch(h)
167      y, h = array_ops.identity_n([x, h])
168
169    self.assertAllClose(1., tape.gradient(y, x))
170
171  def testGradientInsideLoop(self):
172    with ops.Graph().as_default():
173      v = resource_variable_ops.ResourceVariable(1.0)
174
175      def body(_):
176        _ = v + 1.0  # This reads the variable inside the loop context
177        with backprop.GradientTape() as t:
178          result = v * 2
179        self.assertIsNotNone(t.gradient(result, v))
180        return 1.0
181
182      control_flow_ops.while_loop(lambda i: False, body, [1.0])
183
184  def testWhereGradient(self):
185    # Note: where is special because only some of its arguments are of
186    # differentiable dtypes.
187
188    def f(x):
189      return array_ops.where(x < 10, x, x * x)
190
191    g = backprop.gradients_function(f)
192
193    self.assertAllEqual(g(5.)[0], 1.0)
194    self.assertAllEqual(g(50.)[0], 100.0)
195
196  def testTwoTargets(self):
197    with backprop.GradientTape() as t:
198      x = constant_op.constant(3.0)
199      y = constant_op.constant(2.0)
200      t.watch([x, y])
201      xx = 2 * x
202      yy = 3 * y
203    dx, dy = t.gradient([xx, yy], [x, y])
204    self.assertAllEqual(dx, 2.0)
205    self.assertAllEqual(dy, 3.0)
206
207  def testCustomGradientEmptyError(self):
208
209    @custom_gradient.custom_gradient
210    def identity(x):
211
212      def grad(_):
213        return []  # This return value is wrong!
214
215      return x, grad
216
217    x = variables.Variable(1.0)
218    with backprop.GradientTape() as t:
219      y = identity(x)
220    with self.assertRaises(ValueError):
221      t.gradient(y, [x])
222
223  def test_stop_gradient_hides_downstream_ops(self):
224
225    @custom_gradient.custom_gradient
226    def _backward_pass_error(x):
227
228      def _grad(_):
229        raise AssertionError(
230            'Unexpectedly ran the backward function. This probably means that '
231            'tf.GradientTape is not properly ignoring tensors downstream of '
232            'tf.stop_gradient.')
233
234      return x, _grad
235
236    @def_function.function
237    def f(x):
238      return _backward_pass_error(x)
239
240    x = constant_op.constant(1.)
241    with backprop.GradientTape() as tape:
242      tape.watch(x)
243      y = f(array_ops.stop_gradient(x))
244
245    self.assertIsNone(tape.gradient(y, x))
246
247  def testOutputGradUsedInComputation(self):
248    with backprop.GradientTape() as t:
249      x = constant_op.constant(3.0)
250      y = constant_op.constant(2.0)
251      t.watch([x, y])
252      loss = x * y
253    dx, = t.gradient([loss, x], [x], output_gradients=[1.0, 2.0])
254    self.assertAllEqual(dx, 4.0)
255
256  def testDy(self):
257
258    def f(x):
259      return x
260
261    grad_fn = backprop.gradients_function(f)
262    self.assertAllEqual(2., grad_fn(1., dy=2.)[0])
263
264  def testGradientInteger(self):
265
266    def f(x):
267      return x + x
268
269    int_tensor = constant_op.constant(1)
270    self.assertEqual(backprop.gradients_function(f)(int_tensor)[0], None)
271
272  def testErrors(self):
273
274    @custom_gradient.custom_gradient
275    def f(x):
276
277      def grad(_):
278        raise RuntimeError('x')
279
280      return x, grad
281
282    # TODO(apassos) raise the right error here
283    with self.assertRaises(RuntimeError):
284      backprop.gradients_function(f)(constant_op.constant(1.0))
285
286  def testGradientsFunctionInCustomGradient(self):
287
288    @custom_gradient.custom_gradient
289    def f(x):
290      (y,) = backprop.gradients_function(lambda x: x * x)(x)
291
292      def grad(dy):
293        return [2 * dy]
294
295      return y, grad
296
297    self.assertAllEqual(f(1.0), 2.0)
298
299  def testImplicitGradOverEmbeddingLookup(self):
300    batch_size = 8
301    embedding_size = 512
302    vocab_size = 1000
303    lrn_rate = 0.1
304    random_init = random_ops.random_uniform([vocab_size, embedding_size])
305
306    x = array_ops.ones((batch_size), dtypes.int64)
307    embedding = resource_variable_ops.ResourceVariable(
308        initial_value=random_init, dtype=dtypes.float32, name='embedding')
309
310    def f():
311      embedded_x = embedding_ops.embedding_lookup(embedding, x)
312      return constant_op.constant(1.0, dtypes.float32) - embedded_x
313
314    grad = backprop.implicit_grad(f)()[0][0]
315    opt = training.GradientDescentOptimizer(lrn_rate)
316
317    with ops.Graph().as_default(), self.cached_session():
318      tf_x = array_ops.ones((batch_size), dtypes.int64)
319      # TODO(ashankar,apassos): Change to ResourceVariable.
320      tf_embedding = variables.Variable(
321          random_init.numpy(), name='tf_embedding')
322      tf_embedded_x = embedding_ops.embedding_lookup(tf_embedding, tf_x)
323      tf_y = 1.0 - tf_embedded_x
324      tf_grad = gradients.gradients(tf_y, [tf_embedding])[0]
325      tf_opt = training.GradientDescentOptimizer(0.1)
326      tf_embedding.initializer.run()
327
328      self.assertAllClose(tf_grad.indices, grad.indices)
329      self.assertAllClose(tf_grad.values, grad.values)
330
331      tf_opt.apply_gradients([(tf_grad, tf_embedding)]).run()
332      expected = self.evaluate(tf_embedding)
333    opt.apply_gradients([(grad, embedding)])
334    self.assertAllClose(expected, embedding.read_value())
335
336  def testImplicitGradOrdering(self):
337    v0 = resource_variable_ops.ResourceVariable(1.0)
338    v1 = resource_variable_ops.ResourceVariable(2.0)
339
340    def f():
341      x = v1 * v1
342      y = v0 * v0
343      return x + y
344
345    grads = backprop.implicit_grad(f)()
346    ordered_variables = [x[1] for x in grads]
347    self.assertIs(ordered_variables[0], v0)
348    self.assertIs(ordered_variables[1], v1)
349
350  def testTapeNoOpGradient(self):
351    x = constant_op.constant(3.0)
352    with backprop.GradientTape() as t:
353      t.watch(x)
354      y = x
355    self.assertEqual(t.gradient(y, x).numpy(), 1.0)
356
357  def testTapeIdentityGradientIsIdentity(self):
358    x = constant_op.constant(3.0)
359    with backprop.GradientTape() as t:
360      t.watch(x)
361      y = array_ops.identity(x)
362    self.assertEqual(t.gradient(y, x).numpy(), 1.0)
363
364  def testFunctionIndexedSlicesGradient(self):
365
366    @def_function.function
367    def f(x):
368      return x + 1
369
370    with backprop.GradientTape() as t:
371      x = constant_op.constant([1.0])
372      t.watch(x)
373      y = f(x)
374      y = array_ops.gather(y, [0])
375    self.assertAllEqual(t.gradient(y, x), [1.0])
376
377  def testTapeGradientMultiTargetOneIsSource(self):
378    x = constant_op.constant(2.0)
379    with backprop.GradientTape() as t:
380      t.watch(x)
381      y = x * x
382    self.assertEqual(t.gradient([x, y], x).numpy(), 5.0)
383
384  def testTapeNoOpGradientWithMultiTargetAllSource(self):
385    x = constant_op.constant(3.0)
386    with backprop.GradientTape() as t:
387      t.watch(x)
388      y = x
389    self.assertEqual(t.gradient([y, y], x).numpy(), 2.0)
390
391  def testTapeNoOpGradientWithMultiTargetMultiSource(self):
392    x = constant_op.constant(3.0)
393    y = constant_op.constant(5.0)
394    with backprop.GradientTape() as t:
395      t.watch(x)
396      t.watch(y)
397      z = y * y
398    self.assertAllEqual(t.gradient([x, y, z], [x, y]), [1.0, 11.0])
399
400  def testTapeGradientStringTarget(self):
401    s = constant_op.constant('unknown', dtype=dtypes.string)
402    x = constant_op.constant(3.0)
403
404    with backprop.GradientTape() as t:
405      t.watch(x)
406      t.watch(s)
407    grads = t.gradient(s, x)
408    self.assertEqual(grads, None)
409
410  def testTapeNoOpGradientStringSourceAndTarget(self):
411    s = constant_op.constant('unknown', dtype=dtypes.string)
412
413    with backprop.GradientTape() as t:
414      t.watch(s)
415    grads = t.gradient(s, s)
416    self.assertEqual(grads, None)
417
418  def testTapeNoOpGradientWithMultiTargetMultiSourceIncludeString(self):
419    x = constant_op.constant(3.0)
420    y = constant_op.constant(5.0)
421    s = constant_op.constant('unknown', dtype=dtypes.string)
422
423    with backprop.GradientTape() as t:
424      t.watch(x)
425      t.watch(y)
426      t.watch(s)
427      z = y * y
428    grads = t.gradient([x, y, z, s], [x, y, s])
429    self.assertAllEqual(grads[:2], [1.0, 11.0])
430    self.assertEqual(grads[2], None)
431
432  def testTapeNoOpOnVariableIsIdentity(self):
433    v0 = resource_variable_ops.ResourceVariable(1.0)
434    with backprop.GradientTape() as t:
435      y = v0.read_value()
436    self.assertEqual(t.gradient(y, v0).numpy(), 1.0)
437
438  @test_util.assert_no_new_tensors
439  @test_util.assert_no_garbage_created
440  def testTapeNoOpGradient2By2(self):
441    a_2_by_2 = constant_op.constant(2.0, shape=[2, 2])
442    with backprop.GradientTape(persistent=True) as tape:
443      tape.watch(a_2_by_2)
444    dy_dy = tape.gradient(a_2_by_2, [a_2_by_2])[0]
445    self.assertAllEqual(dy_dy.numpy(),
446                        constant_op.constant(1.0, shape=[2, 2]).numpy())
447
448  @test_util.assert_no_new_pyobjects_executing_eagerly
449  def testTapeNoOpGradientMultiTarget2By2(self):
450    a_2_by_2 = constant_op.constant(2.0, shape=[2, 2])
451    with backprop.GradientTape(persistent=True) as tape:
452      tape.watch(a_2_by_2)
453    dy_dy = tape.gradient([a_2_by_2, a_2_by_2], [a_2_by_2])[0]
454    self.assertAllEqual(dy_dy.numpy(),
455                        constant_op.constant(2.0, shape=[2, 2]).numpy())
456
457  def testTapeStopRecording(self):
458    with backprop.GradientTape() as t:
459      x = resource_variable_ops.ResourceVariable(1.0)
460      with t.stop_recording():
461        y = x * x
462    self.assertEqual(t.gradient(y, x), None)
463
464  def testTapeStopStartRecording(self):
465    with backprop.GradientTape(persistent=True) as t:
466      x = resource_variable_ops.ResourceVariable(1.0)
467      x2 = x * 2  # This should be differentiated through.
468      with t.stop_recording():
469        y = x2 * x2
470      z = x2 * x2
471    self.assertEqual(t.gradient(y, x2), None)
472
473    # If the x*2 was not differentiated through, this would be 2.0, not 4.0
474    self.assertEqual(t.gradient(z, x2).numpy(), 4.0)
475
476  def testTapeReset(self):
477    with backprop.GradientTape() as t:
478      v = resource_variable_ops.ResourceVariable(1.0)
479      loss = v * v
480      t.reset()
481      loss += v * v
482    self.assertAllEqual(t.gradient(loss, v), 2.0)
483
484  def testPythonMax(self):
485    x = [
486        resource_variable_ops.ResourceVariable(2.),
487        resource_variable_ops.ResourceVariable(3.),
488        resource_variable_ops.ResourceVariable(5.)
489    ]
490    with backprop.GradientTape() as t:
491      f = max(x)
492    grad = t.gradient(f, x)
493    self.assertAllEqual(self.evaluate(f), 5.)
494    self.assertAllEqual(self.evaluate(grad), [None, None, 1.0])
495
496  def testAutomaticWatchedVariables(self):
497    with backprop.GradientTape() as t:
498      self.assertEqual(0, len(t.watched_variables()))
499      v = resource_variable_ops.ResourceVariable(1.0)
500      loss = v * v
501      self.assertAllEqual([v], t.watched_variables())
502
503      t.reset()
504      self.assertEqual(0, len(t.watched_variables()))
505      loss += v * v
506      self.assertAllEqual([v], t.watched_variables())
507
508  def testExplicitWatchedVariables(self):
509    with backprop.GradientTape() as t:
510      self.assertEqual(0, len(t.watched_variables()))
511      v = resource_variable_ops.ResourceVariable(1.0)
512      t.watch(v)
513      self.assertAllEqual([v], t.watched_variables())
514
515      t.reset()
516      self.assertEqual(0, len(t.watched_variables()))
517      t.watch(v)
518      self.assertAllEqual([v], t.watched_variables())
519
520  @test_util.assert_no_new_tensors
521  def testGradientNone(self):
522
523    def loss(x, l):
524      return math_ops.reduce_mean(
525          nn_ops.softmax_cross_entropy_with_logits(logits=x, labels=l),
526          constant_op.constant([0]))
527
528    logits = constant_op.constant([[0.0, 0.0]])
529    labels = constant_op.constant([[1.0, 0.0]])
530    # softmax_cross_entropy_with_logits returns two outputs and in this case the
531    # gradient wrt the second is None.
532    g, = backprop.gradients_function(loss, [0])(logits, labels)
533    self.assertAllEqual(g.numpy(), [[-0.5, 0.5]])
534
535  @test_util.run_in_graph_and_eager_modes
536  def testGradientWithinTapeBlock(self):
537    v1 = resource_variable_ops.ResourceVariable(1.)
538    self.evaluate(v1.initializer)
539    with backprop.GradientTape() as t:
540      loss = 2 * v1
541      grad = t.gradient(loss, v1)
542    self.assertAllEqual(self.evaluate(grad), 2.0)
543
544    with backprop.GradientTape(persistent=True) as t:
545      loss = 2 * v1
546      grad = t.gradient(loss, v1)
547    self.assertAllEqual(self.evaluate(grad), 2.0)
548
549  @test_util.run_in_graph_and_eager_modes
550  def testNestedSelfContexts(self):
551    v1 = resource_variable_ops.ResourceVariable(1.)
552    self.evaluate(v1.initializer)
553    with backprop.GradientTape() as t:
554      with self.assertRaises(ValueError):
555        with t:
556          pass
557
558  @test_util.assert_no_new_tensors
559  def testSecondGrad(self):
560
561    def first(x):
562      l = constant_op.constant([[0.0]])
563      x = nn_ops.softmax_cross_entropy_with_logits(labels=l, logits=x)
564      x = math_ops.reduce_sum(x, constant_op.constant([0]))
565      return x
566
567    def second(x):
568      grad = backprop.gradients_function(first, [0])(x)[0]
569      return math_ops.reduce_sum(grad, constant_op.constant([0]))
570
571    f = constant_op.constant([[0.1]])
572    grad = backprop.gradients_function(second, [0])(f)[0]
573    self.assertAllEqual([[0.0]], grad)
574
575  @test_util.run_in_graph_and_eager_modes
576  def testWatchingIsTapeLocal(self):
577    x1 = resource_variable_ops.ResourceVariable(2.0, trainable=False)
578    x2 = resource_variable_ops.ResourceVariable(2.0, trainable=False)
579
580    with backprop.GradientTape() as tape1:
581      with backprop.GradientTape() as tape2:
582        tape1.watch(x1)
583        tape2.watch([x1, x2])
584        y = x1**3
585        z = x2**2
586        dy, dz = tape2.gradient([y, z], [x1, x2])
587      d2y, d2z = tape1.gradient([dy, dz], [x1, x2])
588
589    self.evaluate([x1.initializer, x2.initializer])
590    self.assertEqual(self.evaluate(d2y), 12.0)
591    self.assertIsNone(d2z)
592
593  @test_util.assert_no_new_tensors
594  def testMakeVJP(self):
595
596    def f(x):
597      return x * x
598
599    wrapped_fn = backprop.make_vjp(f, persistent=False)
600    result, vjp = wrapped_fn(constant_op.constant(3.0))
601    self.assertAllEqual(result, 9.0)
602    self.assertAllEqual(vjp(2.0)[0], 12.0)
603
604  def testPersistentMakeVJP(self):
605
606    def f(x):
607      return x * x
608
609    wrapped_fn = backprop.make_vjp(f, persistent=True)
610    _, vjp = wrapped_fn(constant_op.constant(3.0))
611    vjp_result1 = vjp(2.0)[0]
612    vjp_result2 = vjp(2.0)[0]
613    self.assertAllEqual(vjp_result1, vjp_result2, 12.0)
614
615  @test_util.assert_no_new_tensors
616  def testGradGrad(self):
617
618    def sq(x):
619      return x * x
620
621    def grad(x):
622      value = backprop.gradients_function(sq, [0])(x)[0]
623      return value
624
625    gradgrad = backprop.gradients_function(grad, [0])
626
627    self.assertAllEqual(gradgrad(constant_op.constant(3.0))[0], 2.0)
628
629  @test_util.assert_no_new_tensors
630  def testGradGradExp(self):
631
632    def grad(x):
633      value = backprop.gradients_function(math_ops.exp, [0])(x)[0]
634      return value
635
636    gradgrad = backprop.gradients_function(grad, [0])
637
638    self.assertAllEqual(gradgrad(constant_op.constant(0.0))[0], 1.0)
639
640  @test_util.assert_no_new_tensors
641  def testStopGradient(self):
642    grad = backprop.gradients_function(
643        lambda x: array_ops.stop_gradient(math_ops.argmax(x)))
644    self.assertAllEqual(grad([0.0])[0], None)
645
646  @test_util.assert_no_new_tensors
647  def testArgmax(self):
648
649    def argmax(x):
650      i = math_ops.argmax(x)
651      return array_ops.stop_gradient(i)
652
653    grad = backprop.gradients_function(argmax)
654    self.assertAllEqual(grad([0.0])[0], None)
655
656  @test_util.run_gpu_only
657  @test_util.assert_no_new_tensors
658  def testGPU(self):
659
660    def fn(x):
661      with context.device('/gpu:0'):
662        b = constant_op.constant(2.0)
663        c = math_ops.add(x.gpu(), b)
664        # TODO(apassos): remove cpu below by making TensorVSPace aware
665        # of devices.
666        return math_ops.add(c, constant_op.constant(3.0)).cpu()
667
668    grad = backprop.gradients_function(fn, [0])(constant_op.constant(1.0))[0]
669    self.assertAllEqual(grad, 1.0)
670
671  @test_util.run_gpu_only
672  @test_util.assert_no_new_tensors
673  def testGPUImplicitGrad(self):
674    with context.device('gpu:0'):
675      v = resource_variable_ops.ResourceVariable(
676          constant_op.constant(1.0), name='v')
677
678    def f():
679      with context.device('gpu:0'):
680        return v.read_value()
681
682    self.assertEqual(backprop.implicit_grad(f)()[0][0].cpu().numpy(), 1.0)
683
684  @test_util.assert_no_new_tensors
685  def testCPU(self):
686
687    def fn(x):
688      b = constant_op.constant(2.0)
689      c = math_ops.add(x, b)
690      return math_ops.add(c, constant_op.constant(3.0))
691
692    grad = backprop.gradients_function(fn, [0])(constant_op.constant(1.0))[0]
693    self.assertAllEqual(grad, 1.0)
694
695  @test_util.run_gpu_only
696  @test_util.assert_no_new_tensors
697  def testTensorCopyGPU2CPU2GPU(self):
698
699    def f(a, b):
700      return a.cpu() + b.cpu()
701
702    with context.device('/gpu:0'):
703      a = constant_op.constant(1.0)
704      b = constant_op.constant(2.0)
705
706    grad = backprop.gradients_function(f, [0])(a, b)[0]
707    self.assertAllEqual(grad, 1.0)
708
709  @test_util.assert_no_new_tensors
710  def testEmptyParams(self):
711
712    def fn(a, b):
713      return a * b
714
715    x = constant_op.constant(1.0)
716    y = constant_op.constant(2.0)
717    dx, dy = backprop.gradients_function(fn)(x, y)
718    self.assertAllEqual(dx, y.numpy())
719    self.assertAllEqual(dy, x.numpy())
720
721  @test_util.assert_no_new_tensors
722  def testUnconnectedNone(self):
723    v = resource_variable_ops.ResourceVariable(1.0, name='testUnconnectedNone')
724
725    def f():
726      v.read_value()
727      return constant_op.constant(1.0)
728
729    self.assertEqual(backprop.implicit_grad(f)()[0][0], None)
730
731  @test_util.assert_no_new_tensors
732  def testGradientTapeReEnterContext(self):
733    g = backprop.GradientTape()
734    with g:
735      x = constant_op.constant(3.0)
736      g.watch(x)
737      y = 2 * x
738    with g:
739      z = 2 * y
740    grad = g.gradient(target=z, sources=[x])
741    self.assertEqual(self.evaluate(grad), [4.0])
742
743  @test_util.assert_no_new_tensors
744  @test_util.run_in_graph_and_eager_modes
745  def testGradientTapeRepeatedSource(self):
746    with backprop.GradientTape(persistent=False) as g:
747      x = constant_op.constant(3.0)
748      g.watch(x)
749      y = 2 * x
750    grad = g.gradient(target=y, sources=[x, x])
751    self.assertEqual(self.evaluate(grad), [2.0, 2.0])
752
753  @test_util.assert_no_new_tensors
754  @test_util.run_in_graph_and_eager_modes
755  def testPersistentGradientTapeRepeatedSource(self):
756    with backprop.GradientTape(persistent=True) as g:
757      x = constant_op.constant(3.0)
758      y = constant_op.constant(5.0)
759      g.watch(x)
760      g.watch(y)
761      z = x * x + x * y
762    grad = g.gradient(target=z, sources=[x, x])
763    self.assertEqual(self.evaluate(grad), [11.0, 11.0])
764    grad = g.gradient(target=z, sources=[y, x])
765    self.assertEqual(self.evaluate(grad), [3.0, 11.0])
766
767  @test_util.assert_no_new_tensors
768  @test_util.run_in_graph_and_eager_modes
769  def testGradientTapeStructure(self):
770    with backprop.GradientTape(persistent=True) as g:
771      # Using different constant values because constant tensors are
772      # cached, leading to a different gradient then what one might expect.
773      x1 = constant_op.constant(3.0)
774      x2 = constant_op.constant(3.1)
775      x3 = constant_op.constant(3.2)
776      g.watch(x1)
777      g.watch(x2)
778      g.watch(x3)
779      y = x1 + 2 * x2 + 3 * x3
780    self.assertEqual(self.evaluate(g.gradient(y, x1)), [1.0])
781    self.assertEqual(self.evaluate(g.gradient(y, (x1,))), (1.0,))
782    self.assertEqual(self.evaluate(g.gradient(y, (x1, x2))), (1.0, 2.0))
783    self.assertEqual(
784        self.evaluate(g.gradient(y, [(x1, x2), (x2, x3)])), [(1.0, 2.0),
785                                                             (2.0, 3.0)])
786    self.assertEqual(
787        self.evaluate(g.gradient(y, (x1, x2, [x1, x3]))),
788        (1.0, 2.0, [1.0, 3.0]))
789    self.assertEqual(
790        self.evaluate(g.gradient(y, [x1, {
791            'x2': x2,
792            'x3': x3
793        }])), [1.0, {
794            'x2': 2.0,
795            'x3': 3.0
796        }])
797
798  @test_util.assert_no_new_tensors
799  @test_util.run_in_graph_and_eager_modes
800  def testGradientTape(self):
801    with backprop.GradientTape() as g:
802      x = constant_op.constant(3.0)
803      g.watch(x)
804      y = x * x
805      with backprop.GradientTape() as gg:
806        gg.watch(y)
807        z = 2 * y
808      inner_grad = gg.gradient(z, [y])[0]
809      self.assertEqual(self.evaluate(inner_grad), 2.0)
810      y += inner_grad
811    grad = g.gradient(y, [x])[0]
812    self.assertEqual(self.evaluate(grad), 6.0)
813
814  @test_util.assert_no_new_tensors
815  @test_util.run_in_graph_and_eager_modes
816  def testGadientTapeCalledOnConstantTarget(self):
817    with backprop.GradientTape() as g:
818      x = variables.Variable([3.0])
819      y = variables.Variable([2.0])
820    grad = g.gradient(x, y)
821    self.assertAllEqual(grad, None)
822
823  @test_util.run_in_graph_and_eager_modes
824  @test_util.run_v1_only('b/120545219')
825  def testGradientTapeWithCond(self):
826    x = constant_op.constant(3.0)
827
828    def true_fn():
829      return x
830
831    def false_fn():
832      return x * x
833
834    with backprop.GradientTape() as g:
835      g.watch(x)
836      y = control_flow_ops.cond(x < x, true_fn, false_fn)
837
838    if not context.executing_eagerly():
839      with self.assertRaisesRegex(NotImplementedError, 'tf.gradients'):
840        dy = g.gradient(y, [x])[0]
841    else:
842      dy = g.gradient(y, [x])[0]
843      self.assertEqual(self.evaluate(dy), 6.0)
844
845  @test_util.run_in_graph_and_eager_modes
846  @test_util.run_v1_only('b/120545219')
847  def testGradientTapeWithWhileLoop(self):
848    i = constant_op.constant(1)
849    x = constant_op.constant(2.)
850
851    def cond(i, _):
852      return i < 3
853
854    def body(i, x):
855      return i + 1, x * 2
856
857    with backprop.GradientTape() as g:
858      g.watch([x])
859      _, y = control_flow_ops.while_loop(cond, body, [i, x])
860
861    if not context.executing_eagerly():
862      with self.assertRaisesRegex(NotImplementedError, 'tf.gradients'):
863        dy = g.gradient(y, [x])[0]
864    else:
865      dy = g.gradient(y, [x])[0]
866      self.assertEqual(self.evaluate(dy), 4.0)
867
868  @test_util.assert_no_new_tensors
869  def testGradientTapeGradientCalledMultipleTimes(self):
870    with backprop.GradientTape() as g:
871      x = constant_op.constant(3.0)
872      g.watch(x)
873      y = x * x
874      z = y * y
875    g.gradient(z, [x])
876    with self.assertRaisesRegex(
877        RuntimeError, 'A non-persistent GradientTape can only'):
878      g.gradient(y, [x])
879
880  @test_util.assert_no_new_tensors
881  def testGradientTapeJacobianCalledMultipleTimes(self):
882    with backprop.GradientTape() as g:
883      x = constant_op.constant(3.0)
884      g.watch(x)
885      y = x * x
886      z = y * y
887    g.jacobian(z, [x])
888    with self.assertRaisesRegex(
889        RuntimeError, 'A non-persistent GradientTape can only'):
890      g.jacobian(y, [x])
891
892  @test_util.assert_no_new_tensors
893  def testJacobianInsideGradientTapeScope(self):
894    with backprop.GradientTape() as g:
895      x = constant_op.constant(3.0)
896      g.watch(x)
897      y = x * x
898      z = y * y
899      self.assertAllClose(4. * 3. ** 3., g.jacobian(z, x))
900
901  @test_util.assert_no_new_tensors
902  def testBatchJacobianInsideGradientTapeScope(self):
903    with backprop.GradientTape(persistent=True) as g:
904      x = constant_op.constant([[3.0]])
905      g.watch(x)
906      y = x * x
907      z = y * y
908      self.assertAllClose([[[4. * 3. ** 3.]]], g.batch_jacobian(z, x))
909
910  def testBatchJacobianParallelIterations(self):
911    @def_function.function
912    def f(persistent):
913      with backprop.GradientTape(persistent=persistent) as t:
914        x = constant_op.constant([[3.0]])
915        t.watch(x)
916        y = x * x
917        z = array_ops.tile(y * y, [1, 16])
918      return t.batch_jacobian(z, x, parallel_iterations=8)
919    with self.assertRaisesRegex(RuntimeError,
920                                'persistent=True.*parallel_iterations'):
921      f(persistent=False)
922    self.assertAllClose([[[4. * 3. ** 3.]] * 16], f(persistent=True))
923
924  @test_util.assert_no_new_tensors
925  def testGradientTapeBatchJacobianCalledMultipleTimes(self):
926    with backprop.GradientTape() as g:
927      x = constant_op.constant([[3.0]])
928      g.watch(x)
929      y = x * x
930      z = y * y
931    g.batch_jacobian(z, x)
932    with self.assertRaisesRegex(
933        RuntimeError, 'A non-persistent GradientTape can only'):
934      g.batch_jacobian(y, [x])
935
936  @test_util.assert_no_new_tensors
937  @test_util.run_in_graph_and_eager_modes
938  @test_util.run_v1_only('b/120545219')
939  def testPersistentTape(self):
940    with backprop.GradientTape(persistent=True) as g:
941      x = constant_op.constant(3.0)
942      g.watch(x)
943      y = x * x
944      z = y * y
945    dz_dx = g.gradient(z, [x])[0]
946    self.assertEqual(self.evaluate(dz_dx), 4 * 3 * 3 * 3)
947    dy_dx = g.gradient(y, [x])[0]
948    self.assertEqual(self.evaluate(dy_dx), 2 * 3)
949    del g
950
951  @test_util.assert_no_new_tensors
952  @test_util.run_in_graph_and_eager_modes
953  def testHigherOrderGradient(self):
954    with backprop.GradientTape(persistent=True) as g:
955      x = constant_op.constant(3.0)
956      g.watch(x)
957      y = x**3  # y       := x^3
958      dy_dx = g.gradient(y, x)  # dy/dx   := 3x^2
959      d2y_dx2 = g.gradient(dy_dx, x)  # d2y/dx2 := 6x
960    d3y_dx3 = g.gradient(d2y_dx2, x)  # d3y/dx3 := 6
961    x = 3
962    self.assertEqual(self.evaluate(y), x**3)
963    self.assertEqual(self.evaluate(dy_dx), 3 * x**2)
964    self.assertEqual(self.evaluate(d2y_dx2), 6 * x)
965    self.assertEqual(self.evaluate(d3y_dx3), 6)
966    del g
967
968  @test_util.assert_no_new_tensors
969  @test_util.run_in_graph_and_eager_modes
970  def testPersistentNestedTape(self):
971    with backprop.GradientTape(persistent=True) as g:
972      x = constant_op.constant(3.0)
973      g.watch(x)
974      y = x * x
975      with backprop.GradientTape(persistent=True) as gg:
976        gg.watch(y)
977        z = 2 * y
978      for _ in range(2):
979        inner_grad = gg.gradient(z, [y])[0]
980        self.assertEqual(self.evaluate(inner_grad), 2.0)
981      y += inner_grad
982      del gg
983    grad = g.gradient(y, [x])[0]
984    self.assertEqual(self.evaluate(grad), 6.0)
985    grad = g.gradient(z, [x])[0]
986    self.assertEqual(self.evaluate(grad), 12.0)
987    del g
988
989  @test_util.assert_no_new_tensors
990  @test_util.run_in_graph_and_eager_modes
991  def testGradientTapeVariable(self):
992    v = resource_variable_ops.ResourceVariable(1.0, name='v')
993    self.evaluate(v.initializer)
994    with backprop.GradientTape() as g:
995      y = v * v
996    grad = g.gradient(y, [v])[0]
997    self.assertAllEqual(self.evaluate(grad), 2.0)
998
999  @test_util.assert_no_new_tensors
1000  @test_util.run_in_graph_and_eager_modes
1001  def testNestedGradients(self):
1002    x = constant_op.constant(3.0)
1003    with backprop.GradientTape() as g:
1004      g.watch(x)
1005      y = x * x
1006      z = y * y
1007    dz_dx, dz_dy = g.gradient(z, [x, y])
1008    self.assertEqual(self.evaluate(dz_dx), 108.0)
1009    self.assertEqual(self.evaluate(dz_dy), 18.0)
1010
1011  @test_util.assert_no_new_tensors
1012  @test_util.run_in_graph_and_eager_modes
1013  def testUnconnectedGradientsDefault(self):
1014    x = constant_op.constant(1.0)
1015    y = constant_op.constant(3.0)
1016    with backprop.GradientTape() as g:
1017      g.watch([x, y])
1018      z = y * 2
1019    dz_dx = g.gradient(z, x)
1020    self.assertEqual(dz_dx, None)
1021
1022  @test_util.assert_no_new_tensors
1023  @test_util.run_in_graph_and_eager_modes
1024  def testUnconnectedGradientsZeros(self):
1025    x = constant_op.constant(1.0, shape=[2, 2])
1026    y = constant_op.constant(3.0)
1027    with backprop.GradientTape() as g:
1028      g.watch([x, y])
1029      z = y * 2
1030    dz_dx = g.gradient(z, x, unconnected_gradients='zero')
1031    self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(dz_dx))
1032
1033  @test_util.assert_no_new_tensors
1034  @test_util.run_in_graph_and_eager_modes
1035  def testUnconnectedGradientsVariablesZeros(self):
1036    x = resource_variable_ops.ResourceVariable(
1037        constant_op.constant(1., shape=[2, 2]))
1038    self.evaluate(x.initializer)
1039    y = resource_variable_ops.ResourceVariable(constant_op.constant(3.))
1040    self.evaluate(y.initializer)
1041    with backprop.GradientTape() as g:
1042      g.watch([x, y])
1043      z = y * 2
1044    dz_dx = g.gradient(z, x, unconnected_gradients='zero')
1045    self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(dz_dx))
1046
1047  @test_util.run_in_graph_and_eager_modes
1048  def testUnknownUnconnectedGradientsValueGiven(self):
1049    x = constant_op.constant(1.0)
1050    y = constant_op.constant(1.0)
1051    with backprop.GradientTape() as g:
1052      g.watch([x, y])
1053      z = y * 2
1054    with self.assertRaisesRegex(
1055        ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
1056      g.gradient(z, x, unconnected_gradients='nonsense')
1057
1058  @test_util.run_in_graph_and_eager_modes
1059  def testUnconnectedGradientsNestedDefunZeros(self):
1060
1061    @function.defun
1062    def f(x):
1063      return x * x
1064
1065    @function.defun
1066    def h(y):
1067      z = f(y)
1068      return array_ops.stop_gradient(z)
1069
1070    x = constant_op.constant(1.0)
1071    with backprop.GradientTape() as g:
1072      g.watch(x)
1073      k = x + 2.
1074      y = h(k)
1075
1076    dy_dx = g.gradient(y, x, unconnected_gradients='zero')
1077    self.assertEqual(0.0, self.evaluate(dy_dx))
1078
1079  def testInvalidRecordOperationMessage(self):
1080    y = constant_op.constant(2.)
1081    x = constant_op.constant(1.)
1082    with backprop.GradientTape() as g:
1083      g.watch(x)
1084      tape_lib.record_operation('InvalidBackprop', [y], [x], lambda dy: [])
1085    with self.assertRaisesRegex(errors_impl.InternalError,
1086                                'InvalidBackprop.*too few gradients'):
1087      g.gradient(y, x)
1088
1089  @test_util.assert_no_new_tensors
1090  def testEmptyParamsForValueAndGradFunction(self):
1091
1092    def fn(a, b):
1093      return a * b
1094
1095    val_and_grads_fn = backprop.val_and_grad_function(fn)
1096
1097    x = 2.0
1098    y = 3.0
1099    val, (dx, dy) = val_and_grads_fn(x, y)
1100    self.assertAllClose(val, x * y)
1101    self.assertAllEqual(dx, y)
1102    self.assertAllEqual(dy, x)
1103
1104  @test_util.assert_no_new_tensors
1105  def testNonEmptyParamsForValueAndGradFunction(self):
1106
1107    def fn(a, b):
1108      return a * b
1109
1110    val_and_grad_fn = backprop.val_and_grad_function(fn, params=[1])
1111
1112    x = 2.0
1113    y = 3.0
1114    val, grads = val_and_grad_fn(x, y)
1115    self.assertAllClose(val, x * y)
1116    self.assertEqual(1, len(grads))
1117    self.assertAllEqual(grads[0], x)
1118
1119  @test_util.run_gpu_only
1120  @test_util.assert_no_new_tensors
1121  def testTensorCopyCPU2GPU2CPU(self):
1122    # forward: a (cpu->gpu) -> add (gpu) -> c (gpu->cpu) -> add (cpu) -> e (cpu)
1123    # back: e (cpu) -> add (cpu) -> c (cpu->gpu) -> add (gpu) -> grad (gpu->cpu)
1124    def f(a, b):
1125      with context.device('/gpu:0'):
1126        c = math_ops.add(a.gpu(0), b.gpu(0))
1127      return math_ops.add(c.cpu(), constant_op.constant(3.0))
1128
1129    with context.device('/cpu:0'):
1130      a = constant_op.constant(1.0)
1131      b = constant_op.constant(2.0)
1132
1133    grad = backprop.gradients_function(f, [0])(a, b)[0]
1134    self.assertAllEqual(grad, 1.0)
1135
1136  def testGetAttrType(self):
1137    typ = backprop.op_attr_type('Add', 'T')
1138    self.assertEqual(typ, int(pywrap_tfe.TF_ATTR_TYPE))
1139
1140  def testGetAttrList(self):
1141    typ = backprop.op_attr_type('MaxPool', 'ksize')
1142    self.assertEqual(typ, [int(pywrap_tfe.TF_ATTR_INT)])
1143
1144  def testMakeAttrType(self):
1145    self.assertEqual(dtypes.float32,
1146                     backprop.make_attr(int(pywrap_tfe.TF_ATTR_TYPE), 1))
1147
1148  def testMakeAttrTypeList(self):
1149    self.assertEqual([dtypes.float32],
1150                     backprop.make_attr([int(pywrap_tfe.TF_ATTR_TYPE)], [1]))
1151
1152  def testMakeAttrString(self):
1153    self.assertEqual(b'a',
1154                     backprop.make_attr(int(pywrap_tfe.TF_ATTR_STRING), 'a'))
1155
1156  def testMakeAttrStringList(self):
1157    self.assertEqual(
1158        [b'a'], backprop.make_attr([int(pywrap_tfe.TF_ATTR_STRING)], ['a']))
1159
1160  def testMulType(self):
1161
1162    def mul(x):
1163      return math_ops._mul_dispatch(x, x)  # pylint: disable=protected-access
1164
1165    self.assertAllEqual(backprop.gradients_function(mul)(3.0)[0].numpy(), 6.0)
1166
1167  def testMakeAttrShape(self):
1168    for s in ([], None, [1, 2, 3], [None, None], [1, None, 3]):
1169      expected = tensor_shape.TensorShape(s).as_proto()
1170      actual = backprop.make_attr(int(pywrap_tfe.TF_ATTR_SHAPE), s)
1171      self.assertEqual(
1172          expected,
1173          actual,
1174          msg=('For shape %r, expected %r != %r actual' %
1175               (s, expected, actual)))
1176
1177  def testMakeAttrShapeList(self):
1178    shape_list = [[], None, [1, 2, 3], [None, None], [1, None, 3]]
1179    self.assertEqual(
1180        [tensor_shape.TensorShape(s).as_proto() for s in shape_list],
1181        backprop.make_attr([int(pywrap_tfe.TF_ATTR_SHAPE)], shape_list))
1182
1183  def testArgsGradientFunction(self):
1184
1185    def f(*args):
1186      return args[0] * args[0]
1187
1188    grad = backprop.gradients_function(f)
1189    self.assertAllEqual(grad(1.0)[0], 2.0)
1190
1191  def testPartial(self):
1192
1193    def f(x, y):
1194      return x * y
1195
1196    part = functools.partial(f, constant_op.constant(2.0))
1197    self.assertAllEqual(
1198        backprop.gradients_function(part)(constant_op.constant(1.0))[0], 2.0)
1199
1200  def testReturnSameThing(self):
1201
1202    def f(x):
1203      return x, 2 * x
1204
1205    self.assertAllEqual(backprop.gradients_function(f)(1.0)[0], 3.0)
1206
1207  @test_util.assert_no_new_tensors
1208  def testExceptionSafety(self):
1209
1210    def f(unused_x):
1211      raise ValueError()
1212
1213    try:
1214      backprop.gradients_function(f)(1.0)
1215    except ValueError:
1216      pass
1217
1218    def real_f(x):
1219      return x * x
1220
1221    self.assertAllEqual(backprop.gradients_function(real_f)(1.0)[0], 2.0)
1222
1223  @test_util.assert_no_new_tensors
1224  def testMultiValueConvertToTensor(self):
1225    x = resource_variable_ops.ResourceVariable(
1226        initial_value=array_ops.constant([1.0]), name='x')
1227
1228    def fn():
1229      a = math_ops.add(x.value(), 1.0)
1230      # Make sure convert_to_tensor works correctly with list of TensorNodes.
1231      b = array_ops.stack([a, a], axis=0)
1232      return math_ops.reduce_mean(b)
1233
1234    grad = backprop.implicit_grad(fn)()[0][0]
1235    self.assertAllEqual([1.0], grad)
1236
1237  def testOutput(self):
1238
1239    def multiout(x):
1240      return x + 2, x * x
1241
1242    x = constant_op.constant([0.0, 1.0, 2.0])
1243
1244    grad = backprop.gradients_function(multiout)(x)[0]
1245    self.assertAllEqual([1.0, 3.0, 5.0], grad)
1246
1247  def testMultiValuePreservesIfNotDiffedAgainst(self):
1248
1249    def tfe_conv2d(timage, tkernel, conv2dstrides):
1250      return nn_ops.conv2d(timage, tkernel, conv2dstrides, 'SAME')
1251
1252    i = constant_op.constant([[[[1.0]]]])
1253    k = constant_op.constant([[[[2.0]]]])
1254    s = [1, 1, 1, 1]
1255
1256    grad = backprop.gradients_function(tfe_conv2d, params=(0,))(i, k, s)[0]
1257    self.assertAllEqual([[[[2.0]]]], grad)
1258
1259  def testSameObjectForMultipleArguments(self):
1260
1261    def f(x, y):
1262      return math_ops.multiply(x, y)
1263
1264    g = backprop.gradients_function(f)
1265
1266    def np_g(x, y):
1267      dx, dy = g(x, y)
1268      return [dx.numpy(), dy.numpy()]
1269
1270    x = constant_op.constant(1.)
1271    self.assertAllEqual([1., 1.], np_g(x, x))
1272    x = 1.
1273    self.assertAllEqual([1., 1.], np_g(x, x))
1274    x = constant_op.constant([[1.]])
1275    self.assertAllEqual([[[1.]], [[1.]]], np_g(x, x))
1276    x = [[1.]]
1277    self.assertAllEqual([[[1.]], [[1.]]], np_g(x, x))
1278
1279    v = resource_variable_ops.ResourceVariable(
1280        initial_value=1., name='testSameObjectForMultipleArguments.Variable')
1281    self.assertAllEqual([1., 1.], np_g(v, v))
1282
1283  @test_util.assert_no_new_tensors
1284  def testImplicitGradientsCustomGradientAndCachedVariableValue(self):
1285
1286    @custom_gradient.custom_gradient
1287    def my_square(x):
1288      result = math_ops.square(x)
1289
1290      def grad(dr):
1291        return 2 * dr * x + 1
1292
1293      return result, grad
1294
1295    x = resource_variable_ops.ResourceVariable(
1296        initial_value=3., name='X.' + self.id())
1297
1298    def f():
1299      return my_square(x)
1300
1301    g = backprop.implicit_grad(f)
1302
1303    grads_and_vars = g()
1304    self.assertEqual(1, len(grads_and_vars))
1305    grad, var = grads_and_vars[0]
1306    self.assertAllEqual(7, grad)
1307    self.assertAllEqual(x, var)
1308
1309  def testJacobianCustomGradient(self):
1310
1311    class MyCallable(object):
1312
1313      def __init__(self):
1314        self.a = variables.Variable(1.)
1315        self.b = variables.Variable(2.)
1316        self.c = variables.Variable(3.)
1317
1318      def __call__(self, x):
1319        return self.a * x * x + self.b * x + self.c
1320
1321    @def_function.function
1322    def call(c, x):
1323
1324      @custom_gradient.custom_gradient
1325      def _call():
1326        y = c(x)
1327
1328        def grad(dy, variables=None):  # pylint: disable=redefined-outer-name
1329          with backprop.GradientTape(persistent=True) as g:
1330            g.watch(variables)
1331            y = c(x)
1332          grad_vars = [
1333              2 * math_ops.reduce_sum(dy * g.jacobian(y, v)) for v in variables
1334          ]
1335          del g
1336          return (), grad_vars
1337
1338        return y, grad
1339
1340      return _call()
1341
1342    c = MyCallable()
1343    x = constant_op.constant([1., 2., 3.])
1344    with backprop.GradientTape(persistent=True) as g:
1345      g.watch([c.a, c.b, c.c])
1346      y = call(c, x)
1347    self.assertAllEqual(g.gradient(y, x), None)
1348
1349  @test_util.assert_no_new_tensors
1350  def testCustomGradient(self):
1351
1352    @custom_gradient.custom_gradient
1353    def my_mul(x, y):
1354      result = x * y
1355
1356      def grad(dr):
1357        return [dr * y, dr * x]
1358
1359      return result, grad
1360
1361    lr = 0.25
1362    x = resource_variable_ops.ResourceVariable(2., name='x')
1363
1364    def loss(x):
1365      return my_mul(2., x.read_value())
1366
1367    loss_grads_fn = backprop.implicit_val_and_grad(loss)
1368
1369    losses = []
1370    for _ in range(5):
1371      loss, grads_and_vars = loss_grads_fn(x)
1372      losses.append(loss.numpy())
1373      for (grad, var) in grads_and_vars:
1374        var.assign_sub(lr * grad)
1375    self.assertAllEqual(losses, [4.0, 3., 2., 1., 0.])
1376
1377  @test_util.assert_no_new_tensors
1378  def testCustomGradientIdentity(self):
1379
1380    @custom_gradient.custom_gradient
1381    def my_identity(x):
1382
1383      def grad(dresult):
1384        return [2 * dresult]
1385
1386      return x, grad
1387
1388    self.assertAllEqual(backprop.gradients_function(my_identity)(1.0)[0], 2.0)
1389
1390  def testDifferentiatingFunctionThatReturnsNone(self):
1391
1392    def fn(x, y):
1393      result = x * y  # pylint: disable=unused-variable
1394
1395    x = constant_op.constant(1)
1396    y = constant_op.constant(2)
1397
1398    loss_grads_fn = backprop.implicit_val_and_grad(fn)
1399    with self.assertRaisesRegex(
1400        ValueError, 'Cannot differentiate a function that returns None; '
1401        'did you forget to return a value from fn?'):
1402      loss_grads_fn(x, y)
1403
1404    val_and_grads_fn = backprop.val_and_grad_function(fn)
1405    with self.assertRaisesRegex(
1406        ValueError, 'Cannot differentiate a function that returns None; '
1407        'did you forget to return a value from fn?'):
1408      val_and_grads_fn(x, y)
1409
1410  def testZerosCacheDoesntLeakAcrossGraphs(self):
1411    with ops.Graph().as_default():
1412
1413      def get_grad():
1414        with ops.Graph().as_default(), self.cached_session():
1415          t = constant_op.constant(1, dtype=dtypes.float32, shape=(10, 4))
1416          x = constant_op.constant(2, dtype=dtypes.float32, shape=(10, 4))
1417          with backprop.GradientTape() as tape:
1418            tape.watch(x)
1419            x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1)
1420            y1 = x1**2
1421            y = array_ops.concat([y1, t], axis=1)
1422          return self.evaluate(tape.gradient(y, x))
1423
1424      grad1 = get_grad()
1425      grad2 = get_grad()
1426
1427      self.assertAllEqual(grad1, grad2)
1428
1429  @test_util.run_in_graph_and_eager_modes
1430  def testSelectivelyWatchVariables(self):
1431    x1 = resource_variable_ops.ResourceVariable(1.0)
1432    x2 = resource_variable_ops.ResourceVariable(1.0)
1433    with backprop.GradientTape(watch_accessed_variables=False) as tape:
1434      tape.watch(x2)
1435      y = x1**2
1436      z = x2**3
1437    self.assertTupleEqual(tape.watched_variables(), (x2,))
1438    dy, dz = tape.gradient([y, z], [x1, x2])
1439    self.evaluate([x1.initializer, x2.initializer])
1440    self.assertIsNone(dy)
1441    self.assertEqual(self.evaluate(dz), 3.0)
1442
1443  @test_util.run_in_graph_and_eager_modes
1444  def testDifferentiatingScalarCache(self):
1445    # In the following test, if x2 = x1 (i.e the objects are the exact same),
1446    # then y is essentially, 2*x1, and dy/dx1 = 2.
1447    # When we had a pure scalar cache in eager, this would be the case. This
1448    # test prevents us from going back to that case.
1449    with backprop.GradientTape(persistent=False) as g:
1450      x1 = constant_op.constant(3.0)
1451      x2 = constant_op.constant(3.0)
1452      g.watch(x1)
1453      g.watch(x2)
1454      y = x1 + x2
1455    grad = g.gradient(target=y, sources=[x1])
1456    self.assertEqual(self.evaluate(grad), [1.0])
1457
1458  def testVariablesAndConstantsProduceTheSameGradients(self):
1459
1460    # In the following test, differentiating [y, z] against [a, b] gives:
1461    # (dy/da + dz/da, dy/db + dz/db).
1462    # If a and b are the same constant, dz/da will not be 0 (which it should
1463    # be).
1464    # This is solved by using variable since doing a read_value on a tensor will
1465    # produce a new tensor and corresponding TensorHandle, and not reuse the
1466    # same tensor (which would happen if we are using a cache and reusing
1467    # EagerTensor objects).
1468    def get_grads(a, b):
1469      with backprop.GradientTape() as tape:
1470        tape.watch([a, b])
1471        y = a**3
1472        z = b**2
1473      return tape.gradient([y, z], [a, b])
1474
1475    gradients_constants = get_grads(
1476        constant_op.constant(2.0), constant_op.constant(2.0))
1477    gradients_variables = get_grads(
1478        resource_variable_ops.ResourceVariable(2.0),
1479        resource_variable_ops.ResourceVariable(2.0))
1480    self.assertAllEqual(gradients_constants, gradients_variables)
1481
1482  def testUnknownShapes(self):
1483    with ops.Graph().as_default():
1484      with backprop.GradientTape() as tape:
1485        a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
1486        tape.watch(a)
1487        b = a**3
1488
1489      db_da = tape.gradient(b, a)
1490
1491      with self.cached_session() as sess:
1492        self.assertEqual((8.0, 12.0), sess.run((b, db_da), feed_dict={a: 2.0}))
1493
1494  @test_util.run_in_graph_and_eager_modes
1495  def testCustomGradientInEagerAndGraph(self):
1496
1497    @custom_gradient.custom_gradient
1498    def f(x):
1499      y = x * x
1500
1501      def grad(dy):
1502        return [4 * dy]
1503
1504      return y, grad
1505
1506    with backprop.GradientTape() as t:
1507      c = constant_op.constant(1.0)
1508      t.watch(c)
1509      g = f(c)
1510    self.assertAllEqual(self.evaluate(t.gradient(g, c)), 4.0)
1511
1512  def testOverrideSecondOrderWithCustomGradient(self):
1513
1514    @custom_gradient.custom_gradient
1515    def f(x):
1516
1517      def first_order_grad(dz):
1518
1519        @custom_gradient.custom_gradient
1520        def first_order_custom(unused_x):
1521
1522          def h(ddz):
1523            return -2.1 * ddz
1524
1525          return -1.1, h
1526
1527        return dz * first_order_custom(x)
1528
1529      return x + 10., first_order_grad
1530
1531    c = constant_op.constant(1.)
1532    with backprop.GradientTape() as outer:
1533      outer.watch(c)
1534      with backprop.GradientTape() as inner:
1535        inner.watch(c)
1536        d = f(c)**4.
1537      dd = inner.gradient(d, c)
1538      self.assertAllClose(4. * f(c)**3. * -1.1, dd)
1539    self.assertAllClose(3. * 4. * f(c)**2. * -1.1 * -1.1 + 4. * f(c)**3. * -2.1,
1540                        outer.gradient(dd, c))
1541
1542  @test_util.run_in_graph_and_eager_modes
1543  def testCustomGradientForwardprop(self):
1544
1545    @custom_gradient.custom_gradient
1546    def f(x):
1547      z = 2. * tensor_util.constant_value(x)
1548
1549      def g(dz):
1550
1551        @custom_gradient.custom_gradient
1552        def first_order(unused_x, unused_dz):
1553
1554          def second_order_and_transpose(unused_ddz):
1555            return 2.2, 3.1
1556
1557          return 2.1, second_order_and_transpose
1558
1559        return first_order(x, dz)
1560
1561      return z, g
1562
1563    with backprop.GradientTape(persistent=True) as t:
1564      with backprop.GradientTape() as tt:
1565        c = constant_op.constant(1.)
1566        t.watch(c)
1567        tt.watch(c)
1568        output_grad = array_ops.ones([])
1569        t.watch(output_grad)
1570        output = f(c)
1571        self.assertAllClose(2., output)
1572      gc = tt.gradient(output, c, output_gradients=output_grad)
1573      self.assertAllClose(2.1, gc)
1574    ggc = t.gradient(gc, c)
1575    self.assertAllClose(2.2, ggc)
1576    # Note that executed eagerly this kind of transpose is not efficient. But
1577    # from a tf.function we could prune out the first-order gradient
1578    # computation.
1579    transpose = t.gradient(gc, output_grad)
1580    self.assertAllClose(3.1, transpose)
1581
1582  @test_util.run_in_graph_and_eager_modes
1583  def testMaxPooling3DGradient(self):
1584
1585    def forward(a):
1586      r = max_pooling3d(a, pool_size=pool_size, strides=strides, padding='SAME')
1587      return r
1588
1589    input_sizes = [1, 3, 2, 4, 1]
1590    pool_size = (2, 2, 1)
1591    strides = (1, 1, 1)
1592
1593    total_size = np.prod(input_sizes)
1594    x = np.arange(1, total_size + 1, dtype=np.float32)
1595    aa = constant_op.constant(x, shape=input_sizes, dtype=dtypes.float32)
1596    da = backprop.gradients_function(forward)(aa)
1597
1598    if not context.executing_eagerly():
1599      tf_aa = constant_op.constant(x, shape=input_sizes, dtype=dtypes.float32)
1600      tf_max = max_pooling3d(
1601          tf_aa, pool_size=pool_size, strides=strides, padding='SAME')
1602      tf_da = gradients.gradients(tf_max, [tf_aa])
1603      self.assertAllEqual(da[0], tf_da[0])
1604
1605  @test_util.run_in_graph_and_eager_modes
1606  def testWatchBadThing(self):
1607    g = backprop.GradientTape()
1608    with self.assertRaisesRegex(ValueError, 'ndarray'):
1609      g.watch(np.array(1.))
1610
1611  def testWatchComposite(self):
1612    """Test that tape.watch expands composites and watches component Tensors."""
1613    with backprop.GradientTape() as t:
1614      values = constant_op.constant([1.0, 2.0], dtypes.float32)
1615      s = sparse_tensor.SparseTensor(
1616          indices=[[0, 0], [1, 2]], values=values, dense_shape=[3, 4])
1617      t.watch(s)
1618      z = sparse_ops.sparse_reduce_sum_v2(s)
1619    result = t.gradient(z, values)
1620    self.assertAllEqual(result, [1.0, 1.0])
1621
1622  def testWatchedVariablesAfterNonPersistentGradientCall(self):
1623    with backprop.GradientTape(persistent=False) as tape:
1624      x = resource_variable_ops.ResourceVariable(1.0)
1625      tape.watch(x)
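    # gradient() consumes a non-persistent tape, but watched_variables()
    # should still report what was recorded.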
1626    tape.gradient(x, x)
1627    self.assertEqual((x,), tape.watched_variables())
1628
1629  def testWatchedVariablesOnlyHasVariablesFromLastTape(self):
1630    with backprop.GradientTape(persistent=False) as tape:
1631      x = resource_variable_ops.ResourceVariable(1.0)
1632      tape.watch(x)
1633    with backprop.GradientTape(persistent=False) as tape:
1634      z = resource_variable_ops.ResourceVariable(2.0)
1635      tape.watch(z)
1636    tape.gradient(z, z)
1637    self.assertEqual((z,), tape.watched_variables())
1638
1639  def testWatchedVariablesRespectReset(self):
1640    with backprop.GradientTape(persistent=False) as tape:
1641      x = resource_variable_ops.ResourceVariable(1.0)
1642      tape.watch(x)
1643      self.assertEqual((x,), tape.watched_variables())
1644      tape.reset()
1645      z = resource_variable_ops.ResourceVariable(2.0)
1646      tape.watch(z)
1647      self.assertEqual((z,), tape.watched_variables())
1648    tape.gradient(z, z)
1649    self.assertEqual((z,), tape.watched_variables())
1650
1651  def testNameScope(self):
1652
1653    def fn(x):
1654      with ops.name_scope('my_scope'):
1655        a = math_ops.cos(x)
1656        b = math_ops.cos(x)
1657        return math_ops.add(a, b)
1658
1659    @function.defun
1660    def grad_fn(x):
1661      return backprop.gradients_function(fn)(x)
1662
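    # The gradient of cos is -sin, so the two cos ops above should emit two
    # Sin ops whose names carry the recorded name scope.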
1663    grad_ops = grad_fn.get_concrete_function(
1664        constant_op.constant(1.0)).graph.get_operations()
1665    num_sin_ops_found = 0
1666    for op in grad_ops:
1667      if op.type == 'Sin':
1668        num_sin_ops_found += 1
1669        self.assertIn('gradient_tape/my_scope/', op.name)
1670    self.assertEqual(num_sin_ops_found, 2)
1671
1672  @test_util.assert_no_new_pyobjects_executing_eagerly
1673  def testRecomputeGradWithDifferentShape(self):
1674
1675    @custom_gradient.recompute_grad
1676    def outer(x):
1677      return [x[0] + 1, x[1] + 1]
1678
1679    x = [
1680        variables.Variable([1.0, 2.0], name='a'),
1681        variables.Variable(1.0, name='b')
1682    ]
1683    with backprop.GradientTape():
1684      y = outer(x)
1685      self.assertAllEqual(y[0], [2.0, 3.0])
1686      self.assertAllEqual(y[1], 2.0)
1687
1688    @custom_gradient.recompute_grad
1689    def outer_dict(x):
1690      for key in x.keys():
1691        x[key] = x[key] + 1
1692      return x
1693
1694    x = {x[0].ref(): x[0], x[1].ref(): x[1]}
1695    with backprop.GradientTape():
1696      y = outer_dict(x)
1697      y = list(y.values())
1698      self.assertAllEqual(y[0], [2.0, 3.0])
1699      self.assertAllEqual(y[1], 2.0)
1700
1701  @test_util.assert_no_new_pyobjects_executing_eagerly
1702  def testRecomputeGradWithNestedFunctionAndWhileLoop(self):
1703
1704    @custom_gradient.recompute_grad
1705    @def_function.function
1706    def outer(x):
1707
1708      @def_function.function
1709      def middle(y):
1710
1711        @def_function.function
1712        def inner(z):
1713          return z + 1
1714
1715        i = constant_op.constant(0.0)
1716        c = lambda y, i: i < 10.
1717        b = lambda y, i: (inner(y), i + 1.0)
1718        y, i = control_flow_ops.while_loop(c, b, [y, i])
1719
1720        return y
1721
1722      return middle(x)
1723
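    # The while loop applies inner() ten times, so y == x + 10 == 11.0.
    # MemoryChecker verifies that repeatedly rebuilding the variable and
    # running outer() under recompute_grad does not leak Python objects.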
1724    with MemoryChecker() as memory_checker:
1725      for _ in range(5):
1726        x = variables.Variable(1.0, name='x')
1727        with backprop.GradientTape():
1728          y = outer(x)
1729          self.assertAllEqual(y, 11.0)
1730
1731    memory_checker.report()
1732    memory_checker.assert_no_leak_if_all_possibly_except_one()
1733
1734
1735class JacobianTest(test.TestCase):
1736
1737  def _jacobian(self, experimental_use_pfor):
1738    persistent = context.executing_eagerly() and not experimental_use_pfor
1739    with backprop.GradientTape(persistent=persistent) as g:
1740      x = constant_op.constant([1., 2.])
1741      y = constant_op.constant([3., 4.])
1742      g.watch(x)
1743      g.watch(y)
1744      z = x * x * y
1745    jacobian = g.jacobian(
1746        z, [x, y], experimental_use_pfor=experimental_use_pfor)
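    # z = x * x * y elementwise, so dz/dx = 2*x*y and dz/dy = x*x; the full
    # Jacobians are the corresponding diagonal matrices.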
1747    answer = [array_ops.diag(2 * x * y), array_ops.diag(x * x)]
1748    return jacobian, answer
1749
1750  @test_util.run_v1_only('b/120545219')
1751  def testPfor(self):
1752    jacobian, answer = self._jacobian(experimental_use_pfor=True)
1753    for j, a in zip(jacobian, answer):
1754      self.assertAllEqual(a, j)
1755
1756  @test_util.run_v1_only('b/120545219')
1757  def testWhileLoop(self):
1758    jacobian, answer = self._jacobian(experimental_use_pfor=False)
1759    for j, a in zip(jacobian, answer):
1760      self.assertAllEqual(a, j)
1761
1762  @test_util.run_v1_only('b/120545219')
1763  def testPforDefun(self):
1764
1765    @function.defun
1766    def _f():
1767      return self._jacobian(experimental_use_pfor=True)
1768
1769    jacobian, answer = _f()
1770    for j, a in zip(jacobian, answer):
1771      self.assertAllEqual(a, j)
1772
1773  @test_util.run_v1_only('b/120545219')
1774  def testWhileLoopDefun(self):
1775
1776    @function.defun
1777    def _f():
1778      return self._jacobian(experimental_use_pfor=False)
1779
1780    jacobian, answer = _f()
1781    for j, a in zip(jacobian, answer):
1782      self.assertAllEqual(a, j)
1783
1784  @test_util.run_v1_only('b/120545219')
1785  def testPersistentTape(self):
1786    if not context.executing_eagerly():
1787      return
1788    with backprop.GradientTape() as g:
1789      x = constant_op.constant([1.0, 2.0])
1790      g.watch(x)
1791      y = x * x
1792    with self.assertRaisesRegex(RuntimeError, 'persistent'):
1793      g.jacobian(y, x, experimental_use_pfor=False)
1794
1795  @test_util.run_v1_only('b/120545219')
1796  def test_parallel_iterations(self):
1797    with backprop.GradientTape(persistent=True) as g:
1798      x = constant_op.constant([[1., 2], [3, 4]])
1799      g.watch(x)
1800      y = math_ops.matmul(x, x)
1801    self.assertAllClose(
1802        g.jacobian(y, x, parallel_iterations=2),
1803        g.jacobian(y, x, parallel_iterations=3))
1804
1805  @test_util.run_in_graph_and_eager_modes
1806  def test_nested_jacobian(self):
1807    if context.executing_eagerly():
1808      # TODO(agarwal): b/128842926
1809      self.skipTest('Conversion of function calls not implemented yet.')
1810    x = array_ops.ones((10, 2))
1811    with backprop.GradientTape(persistent=False) as g:
1812      g.watch(x)
1813      with backprop.GradientTape(persistent=False) as gg:
1814        gg.watch(x)
1815        y = math_ops.reduce_sum(math_ops.square(x))
1816      dy_x = gg.jacobian(y, x)
1817    dy_xx = g.batch_jacobian(dy_x, x)
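    # dy_x = 2*x, so its per-example Jacobian w.r.t. x is 2*I for each of the
    # 10 rows.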
1818    dy_xx_answer = [[[2., 0], [0, 2.]]] * 10
1819    self.assertAllClose(dy_xx_answer, self.evaluate(dy_xx))
1820
1821  def test_nested_batch_jacobian_foldl(self):
1822    def _grad(f):
1823      def _grad_function(primal):
1824        with backprop.GradientTape() as tape:
1825          tape.watch(primal)
1826          primal_out = f(primal)
1827        return tape.batch_jacobian(primal_out, primal)
1828      return _grad_function
1829
1830    def _func(x):
1831      return array_ops.reshape(
1832          functional_ops.foldl_v2(lambda a, b: math_ops.cos(a + b),
1833                                  array_ops.transpose(x)),
1834          [1, 1])
1835
1836    f = _func
1837    x = constant_op.constant([[1., 2.]])
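    # Each pass checks f's gradient against finite differences, then replaces
    # f with its batch Jacobian so the second pass exercises a second-order
    # gradient, both eagerly and under tf.function.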
1838    for _ in range(2):
1839      theoretical, numerical = gradient_checker_v2.compute_gradient(f, [x])
1840      self.assertAllClose(theoretical, numerical, rtol=1e-3)
1841      f = _grad(f)
1842      expected_flat = array_ops.reshape(numerical, [-1])
1843      self.assertAllClose(expected_flat,
1844                          array_ops.reshape(f(x), [-1]),
1845                          rtol=1e-3)
1846      self.assertAllClose(expected_flat,
1847                          array_ops.reshape(def_function.function(f)(x), [-1]),
1848                          rtol=1e-3)
1849
1850  def test_grad_jacobian_conv(self):
1851    def _inner(x):
1852      kernel = array_ops.ones([3, 3, 1, 9])
1853      with backprop.GradientTape() as tape:
1854        tape.watch(x)
1855        y = nn_ops.conv2d(x, kernel, strides=(1, 1), padding='SAME',
1856                          data_format='NHWC')
1857        reduced = math_ops.reduce_sum(y ** 2., axis=[2, 3])
1858      return math_ops.reduce_sum(tape.batch_jacobian(reduced, x))
1859
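    # Check the gradient of the batch-Jacobian-based scalar against finite
    # differences, then against an explicit tape.gradient computed inside a
    # tf.function.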
1860    theoretical, numerical = gradient_checker_v2.compute_gradient(
1861        def_function.function(_inner), [array_ops.ones([10, 4, 4, 1])])
1862    self.assertAllClose(numerical, theoretical, rtol=1e-1)
1863
1864    @def_function.function
1865    def _outer():
1866      with backprop.GradientTape() as tape:
1867        x = array_ops.ones([10, 4, 4, 1])
1868        tape.watch(x)
1869        y = _inner(x)
1870      return tape.gradient(y, x)
1871
1872    self.assertAllClose(array_ops.reshape(numerical, [-1]),
1873                        array_ops.reshape(_outer(), [-1]), rtol=1e-1)
1874
1875  @test_util.run_in_graph_and_eager_modes
1876  def test_indexed_slices(self):
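    # embedding_lookup produces IndexedSlices gradients; the pfor and
    # while_loop jacobian code paths should agree on them.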
1877    with backprop.GradientTape(persistent=True) as g:
1878      inp = random_ops.random_uniform([3, 2])
1879      g.watch(inp)
1880      output = nn.embedding_lookup(inp, [0, 2])
1881    self.assertAllClose(
1882        g.jacobian(output, inp, experimental_use_pfor=True),
1883        g.jacobian(output, inp, experimental_use_pfor=False))
1884
1885  def test_foldl_partial_function(self):
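    # foldl of addition over a length-3 vector is just its sum, so the
    # Jacobian w.r.t. x is all ones for both the pfor and while_loop paths.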
1886    x = array_ops.zeros([3])
1887    with backprop.GradientTape(persistent=True) as tape:
1888      tape.watch(x)
1889      result = def_function.function(
1890          functools.partial(functional_ops.foldl_v2, lambda a, b: a + b))(
1891              x)
1892    self.assertAllClose([1., 1., 1.],
1893                        tape.jacobian(result, x, experimental_use_pfor=True))
1894    self.assertAllClose([1., 1., 1.],
1895                        tape.jacobian(result, x, experimental_use_pfor=False))
1896
1897    # Non-persistent tapes take a different function gradient path, but also
1898    # work with pfor=True.
1899    x = array_ops.zeros([3])
1900    with backprop.GradientTape() as tape:
1901      tape.watch(x)
1902      result = def_function.function(
1903          functools.partial(functional_ops.foldl_v2, lambda a, b: a + b))(
1904              x)
1905    self.assertAllClose([1., 1., 1.],
1906                        tape.jacobian(result, x, experimental_use_pfor=True))
1907
1908  def test_foldl_pure_function(self):
1909
1910    @def_function.function
1911    def compute_jacobian(use_pfor):
1912      x = array_ops.zeros([3])
1913      with backprop.GradientTape(persistent=True) as tape:
1914        tape.watch(x)
1915        result = functools.partial(functional_ops.foldl_v2, lambda a, b: a + b)(
1916            x)
1917      return tape.jacobian(result, x, experimental_use_pfor=use_pfor)
1918
1919    self.assertAllClose(compute_jacobian(use_pfor=True),
1920                        compute_jacobian(use_pfor=False))
1921
1922  def test_cond_func_grad_jacobian(self):
1923
1924    @def_function.function
1925    def f(x):
1926      y = control_flow_ops.cond(x > 0., lambda: x**3., lambda: x**2.)
1927      return y
1928
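    # At x = 1 the x > 0 branch is taken: y = x**3, so dy/dx = 3 and
    # d2y/dx2 = 6 for both jacobian code paths.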
1929    with backprop.GradientTape(persistent=True) as tape:
1930      x = constant_op.constant(1.)
1931      tape.watch(x)
1932      y = f(x)
1933      grad = tape.gradient(y, x)
1934    self.assertAllClose(3., grad)
1935    jacobian = tape.jacobian(grad, x, experimental_use_pfor=False)
1936    self.assertAllClose(6., jacobian)
1937    jacobian_pfor = tape.jacobian(grad, x, experimental_use_pfor=True)
1938    self.assertAllClose(6., jacobian_pfor)
1939
1940
1941@test_util.run_all_in_graph_and_eager_modes
1942class BatchJacobianTest(test.TestCase, parameterized.TestCase):
1943
1944  def _batch_jacobian(self, experimental_use_pfor):
1945    persistent = context.executing_eagerly() and not experimental_use_pfor
1946    with backprop.GradientTape(persistent=persistent) as g:
1947      x = constant_op.constant([[1., 2.], [3., 4.]])
1948      y = constant_op.constant([[3., 4.], [5., 6.]])
1949      g.watch(x)
1950      z = x * x * y
1951    batch_jacobian = g.batch_jacobian(
1952        z, x, experimental_use_pfor=experimental_use_pfor)
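    # z = x * x * y elementwise, so each example's Jacobian w.r.t. x is
    # diag(2 * x * y) for that row.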
1953    answer = array_ops.stack(
1954        [array_ops.diag(2 * x[0] * y[0]),
1955         array_ops.diag(2 * x[1] * y[1])])
1956    return batch_jacobian, answer
1957
1958  def testPfor(self):
1959    batch_jacobian, answer = self._batch_jacobian(experimental_use_pfor=True)
1960    self.assertAllEqual(answer, batch_jacobian)
1961
1962  def testWhileLoop(self):
1963    batch_jacobian, answer = self._batch_jacobian(experimental_use_pfor=False)
1964    self.assertAllEqual(answer, batch_jacobian)
1965
1966  def testPforDefun(self):
1967
1968    @function.defun
1969    def _f():
1970      return self._batch_jacobian(experimental_use_pfor=True)
1971
1972    batch_jacobian, answer = _f()
1973    self.assertAllEqual(answer, batch_jacobian)
1974
1975  def testWhileLoopDefun(self):
1976
1977    @function.defun
1978    def _f():
1979      return self._batch_jacobian(experimental_use_pfor=False)
1980
1981    batch_jacobian, answer = _f()
1982    self.assertAllEqual(answer, batch_jacobian)
1983
1984  def testPersistentTape(self):
1985    if not context.executing_eagerly():
1986      return
1987    with backprop.GradientTape() as g:
1988      x = constant_op.constant([[1.0, 2.0]])
1989      g.watch(x)
1990      y = x * x
1991    with self.assertRaisesRegex(RuntimeError, 'persistent'):
1992      g.batch_jacobian(y, x, experimental_use_pfor=False)
1993
1994  def testBadShape(self):
1995    x = random_ops.random_uniform([2, 3])
1996    with backprop.GradientTape() as g:
1997      y = array_ops.concat([x, x], axis=0)
1998    with self.assertRaisesRegex(ValueError, 'Need first dimension'):
1999      g.batch_jacobian(y, x)
2000
2001  def testBadInputRank(self):
2002    x = random_ops.random_uniform([2])
2003    with backprop.GradientTape() as g:
2004      y = random_ops.random_uniform([2, 2])
2005    with self.assertRaisesRegex(ValueError, 'must have rank at least 2'):
2006      g.batch_jacobian(y, x)
2007
2008  def testBadOutputRank(self):
2009    x = random_ops.random_uniform([2, 2])
2010    with backprop.GradientTape() as g:
2011      y = random_ops.random_uniform([2])
2012    with self.assertRaisesRegex(ValueError, 'must have rank at least 2'):
2013      g.batch_jacobian(y, x)
2014
2015  def test_parallel_iterations(self):
2016    with backprop.GradientTape(persistent=True) as g:
2017      x = constant_op.constant([[1., 2], [3, 4]])
2018      g.watch(x)
2019      w = constant_op.constant([[1., 2, 3, 4], [5, 6, 7, 8]])
2020      y = math_ops.matmul(x, w)
2021    self.assertAllClose(
2022        g.batch_jacobian(y, x, parallel_iterations=2),
2023        g.batch_jacobian(y, x, parallel_iterations=3))
2024
2025  @parameterized.parameters((True, True), (True, False), (False, True),
2026                            (False, False))
2027  def test_degenerate_shape(self, use_function, use_pfor):
2028
2029    def f(x):
2030      with backprop.GradientTape(persistent=True) as tape:
2031        tape.watch(x)
2032        y = x**2
2033      return tape.batch_jacobian(y, x, experimental_use_pfor=use_pfor)
2034
2035    if use_function:
2036      f = def_function.function(f)
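    # A [1, 0] input should produce an empty [1, 0, 0] batch Jacobian rather
    # than raising.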
2037    self.assertAllEqual([1, 0, 0], array_ops.shape(f(array_ops.zeros([1, 0]))))
2038
2039  @parameterized.parameters((True,), (False))
2040  def test_zeros_type_correct(self, use_pfor):
2041    for dtype in [dtypes.float32, dtypes.float64]:
2042      @def_function.function
2043      def f(x):
2044        del x
2045        return constant_op.constant([[1.]], dtype=dtype)  # pylint: disable=cell-var-from-loop
2046
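      # f ignores x, so the Jacobian is all zeros; it should still carry the
      # loop's dtype both with default unconnected_gradients and with 'zero'.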
2047      with backprop.GradientTape(persistent=True) as tape:
2048        x = constant_op.constant([[2.]], dtype=dtype)
2049        tape.watch(x)
2050        y = f(x)
2051      jac = tape.batch_jacobian(y, x, experimental_use_pfor=use_pfor)
2052      self.assertEqual(dtype, jac.dtype)
2053      self.assertAllClose([[[0.]]], jac)
2054
2055      with backprop.GradientTape(persistent=True) as tape:
2056        x = constant_op.constant([[2.]], dtype=dtype)
2057        tape.watch(x)
2058        y = f(x)
2059      jac = tape.batch_jacobian(y, x, unconnected_gradients='zero',
2060                                experimental_use_pfor=use_pfor)
2061      self.assertEqual(dtype, jac.dtype)
2062      self.assertAllClose([[[0.]]], jac)
2063
2064  def test_strided_slice(self):
2065    x = array_ops.ones([2, 4, 2])
2066    length = constant_op.constant([2, 3, 4, 4], dtype=dtypes.int64)
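    # batch_jacobian should handle a strided slice whose bound is computed at
    # runtime (reduce_max of `length`); the test only checks that it runs
    # without raising.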
2067    with backprop.GradientTape() as tape:
2068      tape.watch(x)
2069      y = array_ops.repeat(x, [2], axis=1)
2070      y = y[:, :math_ops.reduce_max(length), :]
2071    tape.batch_jacobian(y, x)
2072
2073
2074class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):
2075
2076  def _assert_indexed_slices_equal(self, left, right):
2077    self.assertAllEqual(
2078        self.evaluate(ops.convert_to_tensor(left)),
2079        self.evaluate(ops.convert_to_tensor(right)))
2080
2081  def testNoGradients(self):
2082    self.assertIsNone(backprop.aggregate_indexed_slices_gradients([]))
2083
2084  def testOneGradient(self):
2085    t = math_ops._as_indexed_slices(
2086        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
2087    result = backprop.aggregate_indexed_slices_gradients([t])
2088    self._assert_indexed_slices_equal(t, result)
2089
2090  def testMultipleGradients(self):
2091    t0 = math_ops._as_indexed_slices(
2092        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
2093    t1 = math_ops._as_indexed_slices(
2094        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
2095    total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
2096    result = backprop.aggregate_indexed_slices_gradients([t0, t1])
2097    self._assert_indexed_slices_equal(total, result)
2098
2099  def testMultipleGradientsWithNones(self):
2100    t0 = math_ops._as_indexed_slices(
2101        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
2102    t1 = math_ops._as_indexed_slices(
2103        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
2104    t3 = None
2105    total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
2106    result = backprop.aggregate_indexed_slices_gradients([t0, t1, t3])
2107    self._assert_indexed_slices_equal(total, result)
2108
2109  def testMixedTensorAndIndexedSlices(self):
2110    t0 = math_ops._as_indexed_slices(
2111        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
2112    t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]])
2113    total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
2114    result = backprop.aggregate_indexed_slices_gradients([t0, t1])
2115    self._assert_indexed_slices_equal(total, result)
2116
2117
2118if __name__ == '__main__':
2119  test.main()
2120