# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for while_v2."""

from absl.testing import parameterized

from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import importer
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_v2
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
from tensorflow.python.platform import test


def random_gamma(shape):  # pylint: disable=invalid-name
  return random_ops.random_gamma(shape, 1.0)


def random_gamma_with_alpha_beta(shape):  # pylint: disable=invalid-name
  return random_ops.random_gamma(
      shape, alpha=[[1.], [3.], [5.], [6.]], beta=[[3., 4.]])


def random_poisson_v2(shape):  # pylint: disable=invalid-name
  return random_ops.random_poisson_v2(shape, 1.0)


def random_poisson_v2_with_lam(shape):  # pylint: disable=invalid-name
  return random_ops.random_poisson_v2(shape, [12.2, 3.3])


def fill(shape):  # pylint: disable=invalid-name
  return array_ops.fill(shape, 1.0)


class WhileV2Test(test.TestCase, parameterized.TestCase):

  @test_util.run_deprecated_v1
  def testSingleLoopVar(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_deprecated_v1
  def testSingleLoopVarBackPropFalse(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8.,
        lambda v: v * v, [x],
        return_same_structure=False,
        back_prop=False)
    grad = gradients_impl.gradients(ret, [x])
    self.assertEqual(grad, [None])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)

  @test_util.run_deprecated_v1
  def testCustomGradient(self):
    x = constant_op.constant(2.)
    n = constant_op.constant(1., name="const-n")
    m = variables.Variable(1.0)
    self.evaluate(variables.global_variables_initializer())

    def body_fn(v):  # pylint: disable=invalid-name

      @custom_gradient.custom_gradient
      def inner_fn(v):  # pylint: disable=invalid-name

        def grad_fn(dy, variables=None):  # pylint: disable=invalid-name, unused-argument, redefined-outer-name
          return dy * 2 * v * n * m, [v * v]

        return v * v * m, grad_fn

      return inner_fn(v)

    ret = while_loop_v2(
        lambda v: v < 8., body_fn, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_v1_only("b/120545219")
  def testReturnSameStructureTrue(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v * v, [x], return_same_structure=True)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session() as sess:
      eval_result = sess.run(ret)
      self.assertIsInstance(eval_result, list)
      self.assertLen(eval_result, 1)
      self.assertEqual(16., eval_result[0])
      self.assertSequenceEqual(sess.run(grad), [32.])

  def testVerifyInputOutputTypesMatch(self):

    @def_function.function
    def BuildWhile():
      x = constant_op.constant(1., dtypes.float32)

      def Body(x):
        return math_ops.cast(x, dtypes.float16) + 1

      while_loop_v2(lambda x: x < 10, Body, [x])

    with self.assertRaisesRegex(
        TypeError,
        r"Loop var Const:0 enters the loop with type <dtype: 'float32'> "
        r"but has type <dtype: 'float16'> after 1 iteration."):
      BuildWhile()

  @parameterized.parameters(dtypes.float32, dtypes.float64)
  def testGradientTapeResourceVariable(self, dtype):
    with context.eager_mode():
      v = variables.Variable(1., dtype=dtype)

      @def_function.function
      def fnWithLoop():  # pylint: disable=invalid-name
        with backprop.GradientTape() as tape:
          _, x = while_loop_v2(
              lambda i, _: i < 2,
              lambda i, x: (i + 1, x * v),
              [0, constant_op.constant(2., dtype=dtype)])
        return tape.gradient(x, v)

      self.assertAllEqual(fnWithLoop(), 4.0)

  def testDeferredCaptures(self):
    with context.eager_mode():
      c = constant_op.constant(10)

      @def_function.function
      def F():

        def Body(_):
          return ops.get_default_graph().capture_call_time_value(
              lambda: c, tensor_spec.TensorSpec([], dtypes.int32))

        x, = while_loop_v2(lambda i: True, Body, [0], maximum_iterations=1)
        return x

      self.assertAllEqual(F(), 10)

  def checkIteratedGradients(self, func):
    with context.eager_mode():

      def _Grad(f):
        def _GradFunction(primal):
          with backprop.GradientTape() as tape:
            tape.watch(primal)
            primal_out = f(primal)
          return tape.gradient(primal_out, primal)
        return _GradFunction

      f = func
      one = constant_op.constant(1.)

      for _ in range(3):
        theoretical, numerical = gradient_checker_v2.compute_gradient(
            def_function.function(f), [one])
        self.assertAllClose(theoretical, numerical, rtol=1e-3)
        f = _Grad(f)
        self.assertAllClose(array_ops.reshape(numerical, []),
                            def_function.function(f)(one),
                            rtol=1e-3)

  def testIteratedGradients(self):

    def _Func(x):
      _, z = while_loop_v2(
          lambda i, _: i < 2,
          lambda i, y: (i + 1, math_ops.cos(y)),
          [0, x])
      return z

    self.checkIteratedGradients(_Func)

  def testIteratedGradientsWithList(self):

    def _Func(x):
      results = list_ops.empty_tensor_list(
          element_shape=[], element_dtype=dtypes.float32)

      def _LoopBody(i, y, handle):
        return (i + 1, math_ops.cos(y),
                list_ops.tensor_list_push_back(handle, y))

      _, z, results = while_loop_v2(
          lambda i, _, h: i < 2, _LoopBody, [0, x, results])
      return z + math_ops.reduce_sum(list_ops.tensor_list_stack(
          results, dtypes.float32))

    self.checkIteratedGradients(_Func)

  def testGradWhileGradWhileWithVariable(self):
    with context.eager_mode():
      v = variables.Variable(1.)

      @def_function.function
      def _Func(x):

        def _Inner(a):
          with backprop.GradientTape() as tape:
            tape.watch(a)
            _, b = while_loop_v2(
                lambda i, _: i < 2,
                lambda i, y: (i + 1, math_ops.cos(v + y)),
                [0, a])
          return tape.gradient(b, a)

        _, z = while_loop_v2(
            lambda i, _: i < 2,
            lambda i, y: (i + 1, _Inner(y)),
            [0, x])
        return z

      with backprop.GradientTape(persistent=True) as tape:
        x = constant_op.constant(1.)
        tape.watch(x)
        y = _Func(x)
      dx, _ = tape.gradient(y, [x, v])
      theoretical, numerical = gradient_checker_v2.compute_gradient(
          _Func, [x])
      self.assertAllClose(numerical, theoretical, rtol=1e-3)
      self.assertAllClose(array_ops.reshape(numerical, []),
                          dx, rtol=1e-3)

  def testThreeNestWithLists(self):
    with context.eager_mode():
      def _WrapInWhile(f):
        def _Wrapped(x):
          results = list_ops.empty_tensor_list(
              element_shape=[], element_dtype=dtypes.float32)

          def _LoopBody(i, y, handle):
            return (i + 1, f(math_ops.cos(y)),
                    list_ops.tensor_list_push_back(handle, y))

          _, z, results = control_flow_ops.while_loop(
              lambda i, _, h: i < 2, _LoopBody, [0, x, results])
          return z + math_ops.reduce_sum(list_ops.tensor_list_stack(
              results, dtypes.float32))
        return _Wrapped

      f = math_ops.sin

      target_function = _WrapInWhile(_WrapInWhile(_WrapInWhile(f)))

      @def_function.function
      def _TapeFromGraphMode(x):
        with backprop.GradientTape(persistent=True) as tape:
          tape.watch(x)
          y = target_function(x)
        return tape.gradient(y, x)

      x = constant_op.constant(1.)
      dx = _TapeFromGraphMode(x)
      theoretical, numerical = gradient_checker_v2.compute_gradient(
          target_function, [x])
      self.assertAllClose(numerical, theoretical, rtol=3e-3)
      self.assertAllClose(array_ops.reshape(numerical, []), dx, rtol=3e-3)

  def testDeviceLabelsInherited(self):
    def _LoopBody(i, y):
      result = math_ops.cos(y)
      self.assertIn("CPU:10", result.device)
      with ops.device("CPU:11"):
        result = array_ops.identity(result)
      self.assertIn("CPU:11", result.device)
      return i + 1, result

    @def_function.function
    def _FunctionWithWhileLoop():
      x = constant_op.constant(1.)
      with ops.device("CPU:10"):
        _, z = while_loop_v2(
            lambda i, _: i < 2,
            _LoopBody,
            [0, x])
      return z
    # The test assertion runs at trace time.
    _FunctionWithWhileLoop.get_concrete_function()

  def testExternalControlDependencies(self):
    with ops.Graph().as_default(), self.test_session():
      v = variables.Variable(1.)
      self.evaluate(v.initializer)
      op = v.assign_add(1.)

      def body_fn(i):  # pylint: disable=invalid-name
        with ops.control_dependencies([op]):
          return i + 1

      loop = while_loop_v2(lambda i: i < 1, body_fn, [0])
      loop[0].op.run()
      self.assertAllEqual(self.evaluate(v), 2.0)

  @test_util.run_deprecated_v1
  def testMultipleLoopVarsBasic(self):
    x = constant_op.constant(5.)
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    ret = while_loop_v2(
        lambda v, _: v < 45.,
        lambda v, w: (v * w, w), [x, y],
        return_same_structure=False)
    # ret = [x*y^2, y]

    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
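    # The body squares v until the condition fails: 2. -> 4. -> 16., so ret is
    # x**4 for x = 2. and d(ret)/dx = 4 * x**3 = 32.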
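    # Worked values: x goes 5. -> 15. -> 45., so ret = [45., 3.] and
    # d(ret[0])/dx = y**2 = 9., matching the assertions below.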
    grad = gradients_impl.gradients(ret, [x])  # [y**2]
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testMultipleLoopNonscalarCond(self):
    x = constant_op.constant([[5.]])
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    ret = while_loop_v2(
        lambda v, _: v < 45.,
        lambda v, w: (v * w, w), [x, y],
        return_same_structure=False)
    # ret == [x*y^2, y]

    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
    grad = gradients_impl.gradients(ret, [x])  # [y**2]
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testMultipleLoopVars(self):
    x = constant_op.constant(5.)
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    #   y = x + y
    ret = while_loop_v2(
        lambda v, _: v < 45.,
        lambda v, w: (v * w, v + w), [x, y],
        return_same_structure=False)
    # ret = [y*x**2 + x*y**2, x*y + x + y]

    gradx_0 = gradients_impl.gradients(ret[0], [x])  # [2*x*y + y**2]
    gradx_1 = gradients_impl.gradients(ret[1], [x])  # [y + 1]
    gradx_2 = gradients_impl.gradients(ret, [x])  # [2*x*y + y**2 + 2*y + 1]
    grady_0 = gradients_impl.gradients(ret[0], [y])  # [2*x*y + x**2]
    grady_1 = gradients_impl.gradients(ret[1], [y])  # [x + 1]
    grady_2 = gradients_impl.gradients(ret, [y])  # [2*x*y + x**2 + x + 1]
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(ret), [120., 23.])
      self.assertSequenceEqual(self.evaluate(gradx_0), [39.])
      self.assertSequenceEqual(self.evaluate(gradx_1), [4.])
      self.assertSequenceEqual(self.evaluate(gradx_2), [43.])
      self.assertSequenceEqual(self.evaluate(grady_0), [55.])
      self.assertSequenceEqual(self.evaluate(grady_1), [6.])
      self.assertSequenceEqual(self.evaluate(grady_2), [61.])

  @test_util.run_deprecated_v1
  def testGradientTape(self):
    with backprop.GradientTape() as t:
      x = constant_op.constant(2.)
      t.watch(x)
      ret = while_loop_v2(
          lambda v: v < 4., lambda v: v * v, [x],
          return_same_structure=False)  # x**2
    grad = t.gradient(ret, x)
    with self.cached_session() as sess:
      self.assertAllEqual(sess.run(grad), 4.0)

  @test_util.run_deprecated_v1
  def testMultipleWhileLoops(self):
    x = constant_op.constant(2.)
    ret1 = while_loop_v2(
        lambda v: v < 4., lambda v: v * v, [x],
        return_same_structure=False)  # x**2
    ret2 = while_loop_v2(
        lambda v: v < 16., lambda v: v * v, [ret1],
        return_same_structure=False)  # x**4
    grad = gradients_impl.gradients(ret2, [x])  # 4x**3
    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(grad), [32.])
      self.assertSequenceEqual(self.evaluate(grad_grad), [48.])

  def testMultipleWhileLoopsWithFunc(self):
    x = constant_op.constant(2.)

    @def_function.function
    def Fn():
      ret1 = while_loop_v2(
          lambda v: v < 4.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_1")  # x**2
      ret2 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_2")  # x**4
      return ret1, ret2

    concrete_fn = Fn.get_concrete_function()
    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
    self.assertEqual(while_1.type, "StatelessWhile")
    self.assertEqual(while_2.type, "StatelessWhile")
    self.assertEmpty(while_1.control_inputs)
    self.assertEmpty(while_2.control_inputs)

  def testMultipleWhileLoopsGradStateless(self):

    @def_function.function
    def Fn():
      x = constant_op.constant(2.)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        ret1 = while_loop_v2(
            lambda v: v < 4.,
            lambda v: v * v, [x],
            return_same_structure=False,
            name="while_1")  # x**2
        ret2 = while_loop_v2(
            lambda v: v < 16.,
            lambda v: v * v, [x],
            return_same_structure=False,
            name="while_2")  # x**4
        loss = ret1 + ret2
      return tape.gradient(loss, x)

    graph = Fn.get_concrete_function().graph
    while_ops = [op for op in graph.get_operations() if "While" in op.type]
    self.assertAllEqual([op.type for op in while_ops], ["StatelessWhile"] * 4,
                        "Must have exactly 4 StatelessWhile ops.")
    for op in while_ops:
      self.assertEmpty(op.control_inputs,
                       "{} should not have any control inputs".format(op.name))

  def testMultipleWhileLoopsWithDeps(self):
    x = variables.Variable(2.)
    c = constant_op.constant(2.)

    @def_function.function
    def Fn():

      def Body1(v):
        x.assign(x)
        return v * x

      ret1 = while_loop_v2(
          lambda v: v < 4.,
          Body1, [c],
          return_same_structure=False,
          name="while_1")  # 2x

      def Body2(v):
        x.assign(x)
        return v * x * x

      ret2 = while_loop_v2(
          lambda v: v < 16.,
          Body2, [c],
          return_same_structure=False,
          name="while_2")  # 4x
      return ret1, ret2

    concrete_fn = Fn.get_concrete_function()
    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
    self.assertEqual(while_1.type, "While")
    self.assertEqual(while_2.type, "While")
    self.assertEmpty(while_1.control_inputs)
    self.assertLen(while_2.control_inputs, 1)
    self.assertIs(while_2.control_inputs[0], while_1)

  def testMultipleWhileLoopsWithVarsDeps(self):
    x1 = variables.Variable(2.)
    x2 = variables.Variable(3.)
    c = constant_op.constant(2.)

    @def_function.function
    def Fn():

      def Body1(v):
        x1.assign(x1)
        return v * x1

      ret1 = while_loop_v2(
          lambda v: v < 4.,
          Body1, [c],
          return_same_structure=False,
          name="while_1")  # 2x

      def Body2(v):
        x1.assign(x1)
        return v * x1 * x1

      ret2 = while_loop_v2(
          lambda v: v < 16.,
          Body2, [c],
          return_same_structure=False,
          name="while_2")  # 4x

      def Body3(v):
        x2.assign(x2)
        return v * x2

      ret3 = while_loop_v2(
          lambda v: v < 4.,
          Body3, [c],
          return_same_structure=False,
          name="while_3")  # 3x

      def Body4(v):
        x2.assign(x2)
        return v * x2 * x2

      ret4 = while_loop_v2(
          lambda v: v < 16.,
          Body4, [c],
          return_same_structure=False,
          name="while_4")  # 9x
      ret5 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [c],
          return_same_structure=False,
          name="while_stateless")  # x**2
      return ret1, ret2, ret3, ret4, ret5

    concrete_fn = Fn.get_concrete_function()
    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
    while_3 = concrete_fn.graph.get_operation_by_name("while_3")
    while_4 = concrete_fn.graph.get_operation_by_name("while_4")
    while_stateless = concrete_fn.graph.get_operation_by_name(
        "while_stateless")
    self.assertEqual(while_1.type, "While")
    self.assertEqual(while_2.type, "While")
    self.assertEqual(while_3.type, "While")
    self.assertEqual(while_4.type, "While")
    self.assertEqual(while_stateless.type, "StatelessWhile")
    self.assertEmpty(while_1.control_inputs)
    self.assertLen(while_2.control_inputs, 1)
    self.assertIs(while_2.control_inputs[0], while_1)
    self.assertEmpty(while_3.control_inputs)
    self.assertLen(while_4.control_inputs, 1)
    self.assertIs(while_4.control_inputs[0], while_3)
    self.assertEmpty(while_stateless.control_inputs)

  @test_util.run_deprecated_v1
  def testDoubleDerivative(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v**2, [x],
        return_same_structure=False)  # x**4
    grad = gradients_impl.gradients(ret, [x])  # 4x**3
    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])
      self.assertSequenceEqual(self.evaluate(grad_grad), [48.])

  @test_util.run_v2_only
  def testMultipleWhileLoopsEager(self):

    @def_function.function
    def Func():
      x = constant_op.constant(2.)
      ret1 = while_loop_v2(
          lambda v: v < 4., lambda v: v * v, [x],
          return_same_structure=False)  # x**2
      ret2 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [ret1],
          return_same_structure=False)  # x**4
      grad = gradients_impl.gradients(ret2, [x])[0]  # 4x**3
      grad_grad = gradients_impl.gradients(grad, [x])[0]  # 12x**2
      return grad, grad_grad

    grad, grad_grad = Func()
    self.assertEqual(grad.numpy(), 32.)
    self.assertEqual(grad_grad.numpy(), 48.)

  @test_util.run_v2_only
  def testDoubleDerivativeEager(self):

    @def_function.function
    def Func():
      x = constant_op.constant(2.)
      ret = while_loop_v2(
          lambda v: v < 8., lambda v: v**2, [x],
          return_same_structure=False)  # x**4
      grad = gradients_impl.gradients(ret, [x])[0]  # 4x**3
      grad_grad = gradients_impl.gradients(grad, [x])[0]  # 12x**2
      return ret, grad, grad_grad

    ret, grad, grad_grad = Func()
    self.assertEqual(ret.numpy(), 16.)
    self.assertEqual(grad.numpy(), 32.)
    self.assertEqual(grad_grad.numpy(), 48.)

  def _testPruning(self):
    x = constant_op.constant(1)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=x.dtype, element_shape=x.shape)

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5

    def Body(x, tl):
      return x + 1, list_ops.tensor_list_push_back(tl, x)

    outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])

    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(outputs[0])

    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
    # away, causing an extra Enter node.
    enter_count = 2 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 1
    self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is pruned out.
    self.assertEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])

    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
    train_op.append(stack)
    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
    # away, causing an extra Enter node.
    enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
    self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is not pruned out.
    self.assertNotEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])

  @test_util.run_deprecated_v1
  def testPruningV1(self):
    self._testPruning()

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  def testPruningV2(self):
    self._testPruning()

  def _testDoNotAccumulateInvariants(self):
    push_op = ("TensorListPushBack"
               if control_flow_v2_toggles.control_flow_v2_enabled() else
               "StackPushV2")

    # Tests that loop invariants, i.e., tensors that are "captured" by the
    # while loop and not passed as loop variables are not accumulated in
    # gradient computation.
    v = constant_op.constant(5.0, name="v")

    r = control_flow_ops.while_loop(
        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)

    output = gradients_impl.gradients(r, v)[0]
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(output)

    g = GetOptimizedGraph()
    # The gradient for v * x requires the value of both v and x. Since v is a
    # loop invariant it is not accumulated so we have just one accumulator for
    # x.
    self.assertLen([n for n in g.node if n.op == push_op], 1)

  @test_util.run_deprecated_v1
  def testDoNotAccumulateInvariantsV1(self):
    self._testDoNotAccumulateInvariants()

  @test_util.run_deprecated_v1
  @test_util.enable_control_flow_v2
  def testDoNotAccumulateInvariantsV2(self):
    self._testDoNotAccumulateInvariants()

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testPruningNested(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    x = constant_op.constant(0)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=x.dtype, element_shape=x.shape)

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 25

    def Body(x, tl):

      def InnerCond(inner_x, unused_outer_x, unused_tl):
        return inner_x < 5

      def InnerBody(inner_x, outer_x, tl):
        return inner_x + 1, outer_x + 1, list_ops.tensor_list_push_back(tl, x)

      inner_x = constant_op.constant(0)
      return control_flow_ops.while_loop(InnerCond, InnerBody,
                                         [inner_x, x, tl])[1:]

    outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])

    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(outputs[0])

    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
    # away, causing an extra Enter node.
    # enter_count = 4 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
    # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is pruned out.
    self.assertEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])
    self.assertEmpty([n for n in g.node if n.op == "_While"])

    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
    train_op.append(stack)
    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
    # away, causing an extra Enter node.
    # enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
    # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is not pruned out.
    self.assertNotEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testPruningNested2(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    v = constant_op.constant(5.0, name="v")

    p = array_ops.placeholder(dtype=dtypes.int32)

    def MidBodyBuilder(iterations):

      def MidBody(i, x):
        r = control_flow_ops.while_loop(
            lambda *_: True,
            lambda i, x: (i + 1, math_ops.multiply(v, x, name="my_mul")),
            (0, x),
            maximum_iterations=iterations,
            name="inner")
        return (i + 1, gradients_impl.gradients(x + r[1], v)[0])

      return MidBody

    def OuterBody(i, x):
      iterations = array_ops.size(p, name="iterations")
      return (i + 1, x + control_flow_ops.while_loop(
          lambda *_: True,
          MidBodyBuilder(iterations), (0, x),
          maximum_iterations=iterations,
          name="mid")[1])

    def CreateWhileLoop():
      with ops.device("/cpu:0"):
        r = control_flow_ops.while_loop(
            lambda *_: True,
            OuterBody, (0, 1.0),
            maximum_iterations=5,
            name="outer")
        return array_ops.identity(r[1])

    output = CreateWhileLoop()
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(output)

    g = GetOptimizedGraph()
    self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testPruningNested3(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    v = constant_op.constant(5.0, name="v")

    def CreateWhileLoop():
      r = control_flow_ops.while_loop(
          lambda _: True,
          lambda x: math_ops.multiply(v, x, name="my_mul"), [1.0],
          maximum_iterations=5,
          name="outer")
      return array_ops.identity(r)

    r = CreateWhileLoop()
    output = gradients_impl.gradients(r, v)[0]
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(output)

    g = GetOptimizedGraph()
    self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)

  def _assertNotAccumulated(self, while_op, index):
    """Asserts that `while_op` input at `index` is not accumulated."""
    body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
    placeholder = body_graph.inputs[index]
    self.assertNotIn("TensorListPushBack",
                     [op.type for op in placeholder.consumers()])

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testDoNotOutputLoopCounterAsIntermediate(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    v = constant_op.constant(5.0, name="v")
    r = control_flow_ops.while_loop(
        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
    # Skip over Identity.
    while_op = r.op.inputs[0].op
    self._assertNotAccumulated(while_op, 0)

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testDoNotOutputLoopInvariantAsIntermediate(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE

    def GetInputIndex(op, tensor):
      for index, inp in enumerate(op.inputs):
        if inp is tensor:
          return index

    v = constant_op.constant(5.0, name="v")
    r = control_flow_ops.while_loop(
        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
    # Skip over Identity.
    while_op = r.op.inputs[0].op
    # We can't directly use while_op.inputs.index() because Tensors are not
    # hashable.
    index = GetInputIndex(while_op, v)
    self._assertNotAccumulated(while_op, index)

  @test_util.run_deprecated_v1
  def testCaptureExternalTensorInCond(self):
    x = constant_op.constant(2.)
    y = constant_op.constant(1.)
    ret = while_loop_v2(
        lambda v: v + y < 9.,
        lambda v: v * 3., [x],
        return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 18.)
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testCaptureExternalTensorInBody(self):
    x = constant_op.constant(2.)
    y = constant_op.constant(3.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v * y, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 18.)
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testLoopWithTensorListPushBack(self):
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      tl = list_ops.tensor_list_push_back(tl, x)
      tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.))
      return x**2., tl

    ret = while_loop_v2(
        Cond, Body, [x, tensor_list], return_same_structure=False)
    grad = gradients_impl.gradients(ret[0], x)
    with self.cached_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_deprecated_v1
  def testDuplicateAccumulator(self):
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      # There is an accumulator in the loop already so we should not add
      # another.
      tl = list_ops.tensor_list_push_back(tl, x)
      return x**2., tl

    ret = while_loop_v2(
        Cond, Body, [x, tensor_list], return_same_structure=False)

    for op in ops.get_default_graph().get_operations():
      if op.type == "While" or op.type == "StatelessWhile":
        while_op = op

    body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
    x_input_index = [i for i, inp in enumerate(while_op.inputs) if inp == x][0]
    x_input_t = body_graph.inputs[x_input_index]
    accumulator_count = len(
        [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
    self.assertEqual(accumulator_count, 1)

    grad = gradients_impl.gradients(ret[0], x)
    with self.cached_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @parameterized.named_parameters(
      ("UnknownShape", None),
      ("PartiallyDefinedShape", [None, 2]),
      ("FullyDefinedShape", [1, 2]),
  )
  @test_util.run_deprecated_v1
  def testAccumulatorElementShape(self, shape):

    def MatchShape(actual_tensor_shape):
      # Compare the shapes, treating None dimensions as equal. We do not
      # directly check actual_tensor_shape and tf.TensorShape(shape) for
      # equality because tf.Dimension.__eq__ returns None if either dimension is
      # None.
      if shape is None:
        self.assertIsNone(actual_tensor_shape.dims)
      else:
        self.assertListEqual(actual_tensor_shape.as_list(), shape)

    def GetAccumulatorForInputAtIndex(while_op, idx):
      body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
      y_input_t = body_graph.inputs[idx]
      push_back_node = [c for c in y_input_t.consumers()
                        if c.type == "TensorListPushBack"][0]
      output_idx = body_graph.outputs.index(push_back_node.outputs[0])
      return while_op.outputs[output_idx]

    x = array_ops.placeholder(dtype=dtypes.float32, shape=shape)
    y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)

    # Forward pass.
    ret = while_loop_v2(lambda v, u: v < 8.,
                        lambda v, u: (math_ops.pow(v, u), u),
                        [x, y],
                        return_same_structure=True)
    while_op = ret[0].op.inputs[0].op
    # Gradient pass.
    grad = gradients_impl.gradients(ret[0], x)
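    # Worked values at x = 5., y = 3. (ret = [120., 23.]):
    #   gradx_0 = 2*5*3 + 3**2 = 39.   grady_0 = 2*5*3 + 5**2 = 55.
    #   gradx_1 = 3 + 1        = 4.    grady_1 = 5 + 1        = 6.
    #   gradx_2 = 39 + 4       = 43.   grady_2 = 55 + 6       = 61.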
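    # At x = 2.: ret = 2.**4 = 16., grad = 4 * 2.**3 = 32. and
    # grad_grad = 12 * 2.**2 = 48.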
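    # With maximum_iterations=5 and an initial loop var of 1.0, r = v**5, so
    # the gradient requested below is d(r)/dv = 5 * v**4.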
    # Note: There is an Identity between grad[0] and the While op.
    grad_while_op = grad[0].op.inputs[0].op

    # Get the TensorList output of While op containing the accumulated values
    # of y.
    x_input_index = [i for i, inp in enumerate(while_op.inputs) if x == inp][0]
    output = GetAccumulatorForInputAtIndex(while_op, x_input_index)
    _, val = list_ops.tensor_list_pop_back(output,
                                           element_dtype=dtypes.float32)
    MatchShape(val.shape)

    # Take second derivative to generate intermediate grad_while_op outputs
    gradients_impl.gradients(grad, x)

    # Get the TensorList output of gradient While op containing the accumulated
    # values of grad_x (note that grad_x is needed by the second derivative).
    # grad_while_op.inputs:
    grad_output_index = grad_while_op.outputs.index(grad[0].op.inputs[0])
    grad_output = GetAccumulatorForInputAtIndex(grad_while_op,
                                                grad_output_index)
    _, val = list_ops.tensor_list_pop_back(grad_output,
                                           element_dtype=dtypes.float32)
    MatchShape(val.shape)

  def _createWhile(self, name):
    """Helper function for testDefaultName."""
    output = while_v2.while_loop(
        lambda i: i < 3,
        lambda i: i + 1, [constant_op.constant(0)],
        return_same_structure=False)
    while_op = output.op.inputs[0].op
    self.assertEqual(while_op.type, "StatelessWhile")
    return while_op

  def testDefaultName(self):
    with ops.Graph().as_default():
      while_op = self._createWhile(None)
      self.assertEqual(while_op.name, "while")
      self.assertRegex(while_op.get_attr("cond").name, r"while_cond_\d*")
      self.assertRegex(while_op.get_attr("body").name, r"while_body_\d*")

    with ops.Graph().as_default():
      with ops.name_scope("foo"):
        while1_op = self._createWhile("")
        self.assertEqual(while1_op.name, "foo/while")
        self.assertRegex(while1_op.get_attr("cond").name, r"foo_while_cond_\d*")
        self.assertRegex(while1_op.get_attr("body").name, r"foo_while_body_\d*")

        while2_op = self._createWhile(None)
        self.assertEqual(while2_op.name, "foo/while_1")
        self.assertRegex(
            while2_op.get_attr("cond").name, r"foo_while_1_cond_\d*")
        self.assertRegex(
            while2_op.get_attr("body").name, r"foo_while_1_body_\d*")

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  def testWhileAndTensorArray(self):
    param = constant_op.constant(2.0)
    y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
    # map_fn uses TensorArray internally.
    r = map_fn.map_fn(lambda x: math_ops.multiply(x, param), y0)
    grad = gradients_impl.gradients(r, param)[0]
    self.assertAllClose([2.0, 4.0, 6.0, 8.0, 10.0, 12.0], self.evaluate(r))
    self.assertAllClose(21.0, self.evaluate(grad))

  @test_util.run_deprecated_v1
  def testNestedWhile(self):
    # Compute sum of geometric progression: n^0 + n^1 + ... + n^m
    # We compute the pow using a while loop.
    n = constant_op.constant(3.)
    m = constant_op.constant(5.)
    sum_of_powers = constant_op.constant(0.)

    def Body(i, previous_sum):
      prod = constant_op.constant(1.)
      return i - 1., previous_sum + while_loop_v2(
          lambda c, _: c > 0,
          lambda c, v: (c - 1., v * n), [i, prod],
          return_same_structure=False)[1]

    result = while_loop_v2(
        lambda i, _: i >= 0,
        Body, [m, sum_of_powers],
        return_same_structure=False)[1]
    grad = gradients_impl.gradients(result, [n])
    self.assertEqual(self.evaluate(result), 364.)
    self.assertSequenceEqual(self.evaluate(grad), [547.])

  @test_util.run_deprecated_v1
  def testNestedWhileWithLegacyDefun(self):
    n = constant_op.constant(3.)
    m = constant_op.constant(5.)
    sum_of_powers = constant_op.constant(0.)

    def Body(i, previous_sum):
      prod = constant_op.constant(1.)

      def InnerBodyWrapper(c, v):

        @function.Defun(dtypes.float32, dtypes.float32)
        def InnerBody(c, v):
          return c - 1., v * n

        results = InnerBody(c, v)
        results[0].set_shape([])
        results[1].set_shape([])
        return results

      return i - 1., previous_sum + while_loop_v2(
          lambda c, _: c > 0,
          InnerBodyWrapper, [i, prod],
          return_same_structure=False)[1]

    result = while_loop_v2(
        lambda i, _: i >= 0,
        Body, [m, sum_of_powers],
        return_same_structure=False)[1]
    grad = gradients_impl.gradients(result, [n])
    self.assertEqual(self.evaluate(result), 364.)
    self.assertSequenceEqual(self.evaluate(grad), [547.])

  @test_util.run_deprecated_v1
  def testIdentityNodeInBody(self):

    def Body(v):
      v = array_ops.identity(v)
      v = array_ops.identity(v)
      return v * v

    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., Body, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    self.assertEqual(self.evaluate(ret), 16.)
    self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_deprecated_v1
  def testForwardPassRewrite(self):
    x = constant_op.constant(1.0, name="x")
    output = while_v2.while_loop(lambda x: x < 10.0,
                                 lambda x: x * 2.0,
                                 [x])[0]
    while_op = output.op.inputs[0].op
    self.assertEqual(while_op.type, "StatelessWhile")
    # outputs = [loop_counter, max_iters, x]
    self.assertLen(while_op.outputs, 3)

    gradients_impl.gradients(output, x)
    # while_op should have been rewritten to output intermediates.
    # outputs = [loop_counter, max_iters, x, x_accumulator]
    self.assertLen(while_op.outputs, 4)

    gradients_impl.gradients(output, x)
    # Computing the gradient again shouldn't rewrite while_op again.
    self.assertLen(while_op.outputs, 4)

  @parameterized.named_parameters(
      ("RandomUniform", random_ops.random_uniform, [5, 3]),
      ("RandomNormal", random_ops.random_normal, [5, 3]),
      ("ParameterizedTruncatedNormal",
       random_ops.parameterized_truncated_normal, [5, 3]),
      ("TruncatedNormal", random_ops.truncated_normal, [5, 3]),
      ("RandomGamma", random_gamma, [5, 3]),
      ("RandomPoissonV2", random_poisson_v2, [5, 3]),
      ("RandomGammaWithAlphaBeta", random_gamma_with_alpha_beta, [5, 3, 4, 2]),
      ("RandomPoissonV2WithLam", random_poisson_v2_with_lam, [5, 3, 2]),
  )
  @test_util.run_deprecated_v1
  def testRandomOpsShape(self, random_fn, expected_shape):
    shape = constant_op.constant([3])

    def Body(i, u):
      shape_extended = array_ops.concat([[5], shape], axis=0)
      u = random_fn(shape_extended)
      assert u.shape.as_list() == expected_shape, str(u.shape.as_list())
      return i + 1, u

    _, _ = while_loop_v2(
        cond=lambda i, _: i < 3,
        body=Body,
        loop_vars=[
            0,
            array_ops.zeros(expected_shape, dtype=dtypes.float32),
        ])

  @test_util.run_deprecated_v1
  def testReshapeShape(self):
    shape = constant_op.constant([3, 4])

    def Body(i, u):
      shape_extended = array_ops.concat([[5], shape], axis=0)
      u = array_ops.reshape(u, [-1])
      assert u.shape.as_list() == [60], str(u.shape.as_list())
      u = array_ops.reshape(u, shape_extended)
      assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list())
      return i + 1, u

    _, _ = while_loop_v2(
        cond=lambda i, _: i < 3,
        body=Body,
        loop_vars=[
            0,
            array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
        ])

  @parameterized.named_parameters(
      ("Zeros", array_ops.zeros),
      ("Ones", array_ops.ones),
      ("Fill", fill),
  )
  @test_util.run_deprecated_v1
  def testFillOpsShape(self, fill_fn):
    shape = constant_op.constant([3, 4])

    def Body(i, u):
      shape_extended = array_ops.concat([[5], shape], axis=0)
      u = fill_fn(shape_extended)
      assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list())
      return i + 1, u

    _, _ = while_loop_v2(
        cond=lambda i, _: i < 3,
        body=Body,
        loop_vars=[
            0,
            array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
        ])

  @test_util.run_deprecated_v1
  def testExternalColocationGrad(self):
    external_t = constant_op.constant(2.)
    v0 = constant_op.constant(2.)

    def Body(v):
      with ops.colocate_with(external_t):
        return v * v

    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
    grad = gradients_impl.gradients(ret, [v0])[0]
    self.assertAllEqual(ret, 16.)
    self.assertAllEqual(grad, 32.)

  @test_util.run_deprecated_v1
  def testDoNotAccumulateConstNodes(self):

    def Body(v):
      return v * 2.0

    v0 = constant_op.constant(2.)
    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
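    # With n = 3. and m = 5. the sum is 1 + 3 + 9 + 27 + 81 + 243 = 364. and
    # d(sum)/dn = 1 + 2*3 + 3*9 + 4*27 + 5*81 = 547., matching the assertions.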
    # Gradient computation has the side effect of updating the forward op,
    # which is what we want to test.
    unused_grad = gradients_impl.gradients(ret, [v0])[0]
    # ret is separated from the `While` op by an `Identity` so we skip over
    # that.
    forward_while_op = ret.op.inputs[0].op
    body_graph = while_v2._get_graph(forward_while_op, "body", "_body_graph")
    push_back_nodes = [
        o for o in body_graph.get_operations() if o.type == "TensorListPushBack"
    ]
    # Gradient of `Mul` requires accumulating both its inputs. But since one
    # of those is a Const (2.0), we should have just one accumulator.
    self.assertLen(push_back_nodes, 1)

  def testDoNotAccumulateForwardTensorsForReductionOps(self):

    @def_function.function
    def Fn():
      with backprop.GradientTape() as tape:
        x = constant_op.constant(2.)
        tape.watch(x)

        def Body(i, x):
          forward_graph = ops.get_default_graph()

          @custom_gradient.custom_gradient
          def SquaredWithZeroGrad(x):

            def Grad(unused_g, variables=None):  # pylint: disable=redefined-outer-name
              del variables
              gradient_graph = ops.get_default_graph()
              shape = gen_array_ops.shape(x)
              assert shape.graph is forward_graph
              rank = gen_array_ops.rank(x)
              assert rank.graph is forward_graph
              size = gen_array_ops.size(x)
              assert size.graph is forward_graph
              zeros = array_ops.zeros(shape)
              assert zeros.graph is gradient_graph
              return zeros

            return x * 2, Grad

          return i + 1, SquaredWithZeroGrad(x)

        _, result = while_loop_v2(lambda i, _: i < 2, Body, [0, x])
      grad = tape.gradient(result, x)
      return grad

    Fn()

  def testDoNotAccumulateForwardTensorsForTensorListReductionOps(self):

    @def_function.function
    def Fn():
      with backprop.GradientTape() as tape:
        e = constant_op.constant(2.)
        x = list_ops.empty_tensor_list(
            element_dtype=dtypes.float32, element_shape=e.shape)
        x = list_ops.tensor_list_push_back(x, e)
        tape.watch(x)

        def Body(i, x):
          forward_graph = ops.get_default_graph()

          @custom_gradient.custom_gradient
          def IdentityWithZeroGrad(x):

            def Grad(unused_g, variables=None):  # pylint: disable=redefined-outer-name
              del variables
              gradient_graph = ops.get_default_graph()
              shape = gen_list_ops.tensor_list_element_shape(
                  x, shape_type=dtypes.int32)
              assert shape.graph is forward_graph
              size = gen_list_ops.tensor_list_length(x)
              assert size.graph is forward_graph
              zeros = gen_list_ops.tensor_list_reserve(shape, size,
                                                       dtypes.float32)
              assert zeros.graph is gradient_graph
              return zeros

            return x, Grad

          return i + 1, IdentityWithZeroGrad(x)

        _, result = while_loop_v2(lambda i, _: i < 2, Body, [0, x])
      ones_like = list_ops.tensor_list_from_tensor(
          array_ops.ones_like(
              list_ops.tensor_list_stack(result, element_dtype=dtypes.float32)),
          element_shape=tensor_shape.TensorShape([]))
      grad = tape.gradient(result, x, output_gradients=[ones_like])
      return grad

    Fn()

  @test_util.run_v2_only
  def testInheritParentNameScope(self):

    @def_function.function
    def F():
      with ops.name_scope("foo"):

        def Cond(unused_i):
          with ops.name_scope("cond"):
            actual_name_scope = ops.get_name_scope()
            expected_name_scope = "foo/while/cond"
            assert actual_name_scope == expected_name_scope, (
                "%s does not match %s" %
                (actual_name_scope, expected_name_scope))
          return False

        def Body(i):
          with ops.name_scope("body"):
            actual_name_scope = ops.get_name_scope()
            expected_name_scope = "foo/while/body"
            assert actual_name_scope == expected_name_scope, (
                "%s does not match %s" %
                (actual_name_scope, expected_name_scope))
          return i

        return while_v2.while_loop(Cond, Body, [0.])

    F()

  @test_util.run_deprecated_v1  # Need to pass RunMetadata.
  def testDisableLowering(self):
    old = control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE
    control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = True
    with self.session() as sess:
      x = constant_op.constant(2.)
      ret = while_loop_v2(
          lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False)

      opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
      run_metadata = config_pb2.RunMetadata()
      self.assertEqual(sess.run(ret, options=opts, run_metadata=run_metadata),
                       16)
      for dev_stat in run_metadata.step_stats.dev_stats:
        for ns in dev_stat.node_stats:
          self.assertNotIn("switch", ns.node_name)
    control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = old

  def _runBasicWithConfig(self, config):
    with ops.device("/cpu:0"):
      x = constant_op.constant(0)
      ret, = while_loop_v2(lambda x: x < 1000, lambda x: x + 1, [x])
    with self.cached_session(config=config):
      self.assertEqual(1000, self.evaluate(ret))

  @test_util.run_deprecated_v1
  def testRunKernelsInline(self):
    config = config_pb2.ConfigProto()
    config.inter_op_parallelism_threads = -1
    self._runBasicWithConfig(config)

  @test_util.run_deprecated_v1
  def testSingleThreadedExecution(self):
    config = config_pb2.ConfigProto()
    config.experimental.executor_type = "SINGLE_THREADED_EXECUTOR"
    self._runBasicWithConfig(config)

  def testIsControlFlowGraph(self):
    x = constant_op.constant(0)

    @def_function.function
    def F(c):

      def Cond(i):
        self.assertTrue(i.graph.is_control_flow_graph)
        return i < 2

      def Body(i):
        i = i + 1
        self.assertTrue(i.graph.is_control_flow_graph)
        return i

      return while_loop_v2(Cond, Body, [c])

    ret, = F(x)
    self.assertEqual(2, self.evaluate(ret))

  def testImportFromSerializedWithFunctionInBody(self):
    serialized = """node {
      name: "Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_FLOAT
            tensor_shape {
            }
            float_val: 1.0
          }
        }
      }
    }
    node {
      name: "while/maximum_iterations"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: -1
          }
        }
      }
    }
    node {
      name: "while/loop_counter"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
    }
    node {
      name: "while"
      op: "StatelessWhile"
      input: "while/loop_counter"
      input: "while/maximum_iterations"
      input: "Const"
      attr {
        key: "T"
        value {
          list {
            type: DT_INT32
            type: DT_INT32
            type: DT_FLOAT
          }
        }
      }
      attr {
        key: "_lower_using_switch_merge"
        value {
          b: true
        }
      }
      attr {
        key: "_num_original_outputs"
        value {
          i: 3
        }
      }
      attr {
        key: "_read_only_resource_inputs"
        value {
          list {
          }
        }
      }
      attr {
        key: "body"
        value {
          func {
            name: "while_body_822"
          }
        }
      }
      attr {
        key: "cond"
        value {
          func {
            name: "while_cond_821"
          }
        }
      }
      attr {
        key: "output_shapes"
        value {
          list {
            shape {
            }
            shape {
            }
            shape {
            }
          }
        }
      }
      attr {
        key: "parallel_iterations"
        value {
          i: 10
        }
      }
    }
    node {
      name: "while/Identity"
      op: "Identity"
      input: "while"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    node {
      name: "while/Identity_1"
      op: "Identity"
      input: "while:1"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    node {
      name: "while/Identity_2"
      op: "Identity"
      input: "while:2"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    library {
      function {
        signature {
          name: "while_body_822"
          input_arg {
            name: "while_loop_counter"
            type: DT_INT32
          }
          input_arg {
            name: "while_maximum_iterations_0"
            type: DT_INT32
          }
          input_arg {
            name: "placeholder"
            type: DT_FLOAT
          }
          output_arg {
            name: "add"
            type: DT_INT32
          }
          output_arg {
            name: "while_maximum_iterations"
            type: DT_INT32
          }
          output_arg {
            name: "partitionedcall"
            type: DT_FLOAT
          }
        }
        node_def {
1664          name: "PartitionedCall"
1665          op: "PartitionedCall"
1666          input: "placeholder"
1667          attr {
1668            key: "Tin"
1669            value {
1670              list {
1671                type: DT_FLOAT
1672              }
1673            }
1674          }
1675          attr {
1676            key: "Tout"
1677            value {
1678              list {
1679                type: DT_FLOAT
1680              }
1681            }
1682          }
1683          attr {
1684            key: "_collective_manager_ids"
1685            value {
1686              list {
1687              }
1688            }
1689          }
1690          attr {
1691            key: "_read_only_resource_inputs"
1692            value {
1693              list {
1694              }
1695            }
1696          }
1697          attr {
1698            key: "config"
1699            value {
1700              s: ""
1701            }
1702          }
1703          attr {
1704            key: "config_proto"
1705            value {
1706              s: ""
1707            }
1708          }
1709          attr {
1710            key: "executor_type"
1711            value {
1712              s: ""
1713            }
1714          }
1715          attr {
1716            key: "f"
1717            value {
1718              func {
1719                name: "__inference_f_841"
1720              }
1721            }
1722          }
1723          experimental_debug_info {
1724            original_node_names: "PartitionedCall"
1725          }
1726        }
1727        node_def {
1728          name: "add/y"
1729          op: "Const"
1730          attr {
1731            key: "dtype"
1732            value {
1733              type: DT_INT32
1734            }
1735          }
1736          attr {
1737            key: "value"
1738            value {
1739              tensor {
1740                dtype: DT_INT32
1741                tensor_shape {
1742                }
1743                int_val: 1
1744              }
1745            }
1746          }
1747          experimental_debug_info {
1748            original_node_names: "add/y"
1749          }
1750        }
1751        node_def {
1752          name: "add_0"
1753          op: "AddV2"
1754          input: "while_loop_counter"
1755          input: "add/y:output:0"
1756          attr {
1757            key: "T"
1758            value {
1759              type: DT_INT32
1760            }
1761          }
1762          experimental_debug_info {
1763            original_node_names: "add"
1764          }
1765        }
1766        ret {
1767          key: "add"
1768          value: "add_0:z:0"
1769        }
1770        ret {
1771          key: "partitionedcall"
1772          value: "PartitionedCall:output:0"
1773        }
1774        ret {
1775          key: "while_maximum_iterations"
1776          value: "while_maximum_iterations_0"
1777        }
1778        arg_attr {
1779          key: 0
1780          value {
1781            attr {
1782              key: "_output_shapes"
1783              value {
1784                list {
1785                  shape {
1786                  }
1787                }
1788              }
1789            }
1790          }
1791        }
1792        arg_attr {
1793          key: 1
1794          value {
1795            attr {
1796              key: "_output_shapes"
1797              value {
1798                list {
1799                  shape {
1800                  }
1801                }
1802              }
1803            }
1804          }
1805        }
1806        arg_attr {
1807          key: 2
1808          value {
1809            attr {
1810              key: "_output_shapes"
1811              value {
1812                list {
1813                  shape {
1814                  }
1815                }
1816              }
1817            }
1818          }
1819        }
1820      }
1821      function {
1822        signature {
1823          name: "while_cond_821"
1824          input_arg {
1825            name: "while_loop_counter"
1826            type: DT_INT32
1827          }
1828          input_arg {
1829            name: "while_maximum_iterations"
1830            type: DT_INT32
1831          }
1832          input_arg {
1833            name: "placeholder"
1834            type: DT_FLOAT
1835          }
1836          output_arg {
1837            name: "less"
1838            type: DT_BOOL
1839          }
1840        }
1841        node_def {
1842          name: "Less/y"
1843          op: "Const"
1844          attr {
1845            key: "dtype"
1846            value {
1847              type: DT_FLOAT
1848            }
1849          }
1850          attr {
1851            key: "value"
1852            value {
1853              tensor {
1854                dtype: DT_FLOAT
1855                tensor_shape {
1856                }
1857                float_val: 5.0
1858              }
1859            }
1860          }
1861          experimental_debug_info {
1862            original_node_names: "Less/y"
1863          }
1864        }
1865        node_def {
1866          name: "Less"
1867          op: "Less"
1868          input: "placeholder"
1869          input: "Less/y:output:0"
1870          attr {
1871            key: "T"
1872            value {
1873              type: DT_FLOAT
1874            }
1875          }
1876          experimental_debug_info {
1877            original_node_names: "Less"
1878          }
1879        }
1880        ret {
1881          key: "less"
1882          value: "Less:z:0"
1883        }
1884        arg_attr {
1885          key: 0
1886          value {
1887            attr {
1888              key: "_output_shapes"
1889              value {
1890                list {
1891                  shape {
1892                  }
1893                }
1894              }
1895            }
1896          }
1897        }
1898        arg_attr {
1899          key: 1
1900          value {
1901            attr {
1902              key: "_output_shapes"
1903              value {
1904                list {
1905                  shape {
1906                  }
1907                }
1908              }
1909            }
1910          }
1911        }
1912        arg_attr {
1913          key: 2
1914          value {
1915            attr {
1916              key: "_output_shapes"
1917              value {
1918                list {
1919                  shape {
1920                  }
1921                }
1922              }
1923            }
1924          }
1925        }
1926      }
1927      function {
1928        signature {
1929          name: "__inference_f_841"
1930          input_arg {
1931            name: "mul_placeholder"
1932            type: DT_FLOAT
1933          }
1934          output_arg {
1935            name: "identity"
1936            type: DT_FLOAT
1937          }
1938        }
1939        node_def {
1940          name: "mul/y"
1941          op: "Const"
1942          attr {
1943            key: "dtype"
1944            value {
1945              type: DT_FLOAT
1946            }
1947          }
1948          attr {
1949            key: "value"
1950            value {
1951              tensor {
1952                dtype: DT_FLOAT
1953                tensor_shape {
1954                }
1955                float_val: 2.0
1956              }
1957            }
1958          }
1959          experimental_debug_info {
1960            original_node_names: "mul/y"
1961          }
1962        }
1963        node_def {
1964          name: "mul"
1965          op: "Mul"
1966          input: "mul_placeholder"
1967          input: "mul/y:output:0"
1968          attr {
1969            key: "T"
1970            value {
1971              type: DT_FLOAT
1972            }
1973          }
1974          experimental_debug_info {
1975            original_node_names: "mul"
1976          }
1977        }
1978        node_def {
1979          name: "Identity"
1980          op: "Identity"
1981          input: "mul:z:0"
1982          attr {
1983            key: "T"
1984            value {
1985              type: DT_FLOAT
1986            }
1987          }
1988          experimental_debug_info {
1989            original_node_names: "Identity"
1990          }
1991        }
1992        ret {
1993          key: "identity"
1994          value: "Identity:output:0"
1995        }
1996        arg_attr {
1997          key: 0
1998          value {
1999            attr {
2000              key: "_output_shapes"
2001              value {
2002                list {
2003                  shape {
2004                  }
2005                }
2006              }
2007            }
2008          }
2009        }
2010      }
2011    }
2012    versions {
2013      producer: 399
2014      min_consumer: 12
2015    }
2016    """
2017    # Code for generating the above graph:
2018    #
2019    # def Body(i):
2020    #   @tf.function
2021    #   def f():
2022    #     return i * 2
2023    #   return f()
2024    # tf.while_loop(lambda i: i < 5., Body, [tf.constant(1.)])
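    #
    # (A rough reconstruction, assuming `import tensorflow as tf` with
    # while_v2 enabled; the inner `f` captures `i` by closure, which is why
    # the serialized body function contains a PartitionedCall into
    # __inference_f_841.)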
2025    graph_def = graph_pb2.GraphDef()
2026    text_format.Parse(serialized, graph_def)
2027    @def_function.function
2028    def F():
2029      x, y = importer.import_graph_def(
2030          graph_def, return_elements=["Const:0", "while:2"])
2031      grad_out, = gradients_impl.gradients(y, x)
2032      return grad_out
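    # Each iteration doubles the float loop variable (1. -> 2. -> 4. -> 8.),
    # so "while:2" is effectively 8 * "Const" and the expected gradient is 8.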
2033    self.assertAllEqual(F(), 8.0)
2034
2035  def testIndexedSlicesInIncomingGrads(self):
2036    @def_function.function
2037    def F():
2038      x = constant_op.constant([2.])
2039      # Computes x^4
2040      ret = while_loop_v2(
2041          lambda _: True, lambda v: v * v, [x], return_same_structure=False,
2042          maximum_iterations=2)
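      # The gradient of gather is an IndexedSlices, so the incoming gradient
      # for the loop output is an IndexedSlices rather than a dense Tensor.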
2043      v = array_ops.gather(ret, [0])
2044      return gradients_impl.gradients(v, [x])[0]  # 4*x^3
2045    self.assertAllEqual(self.evaluate(F()), [32.])
2046
2047  def testShapeInvariantsRaggedTensor(self):
2048
2049    @def_function.function
2050    def TestFn(x):
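      # Concatenating y with itself grows the ragged leading dimension each
      # iteration, so the loop is given an explicit RaggedTensorSpec shape
      # invariant with unknown dimensions.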
2051      _, ret = while_loop_v2(
2052          lambda i, _: i < 1,
2053          lambda i, y: (i + 1, array_ops.concat([y, y], axis=0)),
2054          [0, x],
2055          shape_invariants=[
2056              tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32),
2057              ragged_tensor.RaggedTensorSpec(shape=[None, None])],
2058      )
2059      return ret
2060
2061    x = ragged_factory_ops.constant([[1., 2.], [3.]])
2062    result = TestFn(x)
2063    expected_result = [[1., 2.], [3.], [1., 2.], [3.]]
2064    self.assertAllEqual(result, expected_result)
2065
2066
2067def ScalarShape():
2068  return ops.convert_to_tensor([], dtype=dtypes.int32)
2069
2070
2071def GetOptimizedGraph():
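  """Runs the default graph through Grappler with constant folding disabled."""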
2072  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
2073  config = config_pb2.ConfigProto()
2074  config.graph_options.rewrite_options.CopyFrom(
2075      rewriter_config_pb2.RewriterConfig(
2076          constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
2077          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
2078  return tf_optimizer.OptimizeGraph(config, mg)
2079
2080
2081if __name__ == "__main__":
2082  test.main()
2083