1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16# pylint: disable=g-long-lambda
17"""Tests for tensorflow.ops.control_flow_ops."""
18
19import collections
20import math
21import re
22import sys
23import time
24
25from absl.testing import parameterized
26import numpy as np
27
28from tensorflow.core.protobuf import config_pb2
29from tensorflow.core.protobuf import rewriter_config_pb2
30from tensorflow.python import tf2
31from tensorflow.python.client import device_lib
32from tensorflow.python.client import session
33from tensorflow.python.data.experimental.ops import cardinality
34from tensorflow.python.data.ops import dataset_ops
35from tensorflow.python.eager import context
36from tensorflow.python.eager import def_function
37from tensorflow.python.eager import function as eager_function
38from tensorflow.python.eager import wrap_function
39from tensorflow.python.framework import constant_op
40from tensorflow.python.framework import dtypes
41from tensorflow.python.framework import errors_impl
42from tensorflow.python.framework import function
43from tensorflow.python.framework import indexed_slices
44from tensorflow.python.framework import ops
45from tensorflow.python.framework import sparse_tensor
46from tensorflow.python.framework import tensor_shape
47from tensorflow.python.framework import tensor_spec
48from tensorflow.python.framework import test_util
49from tensorflow.python.ops import array_ops
50from tensorflow.python.ops import control_flow_ops
51from tensorflow.python.ops import control_flow_util
52from tensorflow.python.ops import data_flow_ops
53from tensorflow.python.ops import functional_ops
54from tensorflow.python.ops import gen_array_ops
55from tensorflow.python.ops import gen_control_flow_ops
56from tensorflow.python.ops import gen_data_flow_ops
57from tensorflow.python.ops import gen_logging_ops
58from tensorflow.python.ops import gen_state_ops
59from tensorflow.python.ops import gradient_checker_v2
60from tensorflow.python.ops import gradients_impl
61from tensorflow.python.ops import init_ops
62from tensorflow.python.ops import linalg_ops
63from tensorflow.python.ops import logging_ops
64from tensorflow.python.ops import map_fn
65from tensorflow.python.ops import math_ops
66from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
67from tensorflow.python.ops import nn_ops
68from tensorflow.python.ops import random_ops
69from tensorflow.python.ops import resource_variable_ops
70from tensorflow.python.ops import script_ops
71from tensorflow.python.ops import sparse_ops
72from tensorflow.python.ops import state_ops
73from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
74from tensorflow.python.ops import tensor_array_ops
75from tensorflow.python.ops import variable_scope
76from tensorflow.python.ops import variables
77from tensorflow.python.ops import while_v2  # pylint: disable=unused-import
78# pylint: disable=unused-import
79from tensorflow.python.ops.ragged import ragged_factory_ops
80from tensorflow.python.ops.ragged import ragged_tensor
81import tensorflow.python.ops.tensor_array_grad
82# pylint: enable=unused-import
83from tensorflow.python.platform import test
84from tensorflow.python.training import adam
85from tensorflow.python.training import gradient_descent
86from tensorflow.python.util import nest
87
88
89def check_consumers(graph):
90  """Sanity check on the consumer list of the tensors."""
91
92  consumer_count = {}
93  for op in graph.get_operations():
94    for v in op.inputs:
95      cnt = consumer_count.get(v, 0)
96      consumer_count[v] = cnt + 1
97  for k, v in consumer_count.items():
98    if len(k.consumers()) != v:
99      return False
100  return True
101
102
103def all_fetchables():
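  """Returns the names of all fetchable tensors in the default graph."""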
104  tensor_names = []
105  graph = ops.get_default_graph()
106  for op in graph.get_operations():
107    for t in op.outputs:
108      if graph.is_fetchable(t):
109        tensor_names.append(t.name)
110  return tensor_names
111
112
113def all_feedables():
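  """Returns all feedable tensors that are inputs of ops in the default graph."""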
114  feedable_tensors = []
115  graph = ops.get_default_graph()
116  for op in graph.get_operations():
117    for t in op.inputs:
118      if graph.is_feedable(t):
119        feedable_tensors.append(t)
120  return feedable_tensors
121
122
123def opt_cfg(do_constant_folding=True):
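  """Returns a ConfigProto with L1 optimization and function inlining on."""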
124  return config_pb2.ConfigProto(
125      allow_soft_placement=True,
126      graph_options=config_pb2.GraphOptions(
127          optimizer_options=config_pb2.OptimizerOptions(
128              opt_level=config_pb2.OptimizerOptions.L1,
129              do_function_inlining=True,
130              do_constant_folding=do_constant_folding)))
131
132
133def isum(s, maximum_iterations=None):
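  """Returns s plus the sum of i for i in [0, 10), computed via while_loop."""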
134  i = constant_op.constant(0, name="i")
135  c = lambda i, s: math_ops.less(i, 10)
136  b = lambda i, s: [math_ops.add(i, 1), math_ops.add(i, s)]
137  _, r_s = control_flow_ops.while_loop(
138      c, b, [i, s], maximum_iterations=maximum_iterations)
139  return r_s
140
141
142def enqueue_print_op(s):
143  """Enqueues an op that prints a message to be captured in the test."""
144  return logging_ops.print_v2("ControlFlowOpsTest: " + s)
145
146
147def filter_test_messages(s):
148  """Returns a list of messages printed by enqueue_print_op."""
149  prefix = "ControlFlowOpsTest: "
150  return [l[len(prefix):] for l in s.split("\n") if l.startswith(prefix)]
151
152
153def tf_function_in_tf2(f):
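  """Wraps `f` in a tf.function when TF2 behavior is enabled."""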
154  if tf2.enabled():
    # In TF2, wrap `f` with tf.function. In TF1 we do not wrap, so that the v1
    # control flow code path is exercised.
157    return def_function.function(f)
158  return f
159
160
161@test_util.with_eager_op_as_function
162@test_util.with_control_flow_v2
163class ControlFlowTest(test.TestCase, parameterized.TestCase):
164
165  @test_util.run_v1_only("b/120545219")
166  def testRefIdentity(self):
167    with self.cached_session():
168      v = variables.VariableV1(7)
169
170      v = control_flow_ops._Identity(v)
171      op = state_ops.assign(v, 9)
172      v2 = control_flow_ops.with_dependencies([op], v)
173
174      self.assertTrue(isinstance(v2, ops.Tensor))
175      self.evaluate(variables.global_variables_initializer())
176      self.assertEqual(9, self.evaluate(v2))
177
178  @test_util.run_v1_only("b/120545219")
179  def testRefEnter(self):
180    with self.cached_session():
181      v = variables.VariableV1(7)
182
183      enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True)
184      nine = constant_op.constant(9)
185      enter_nine = gen_control_flow_ops.enter(nine, "foo_1")
186      op = state_ops.assign(enter_v, enter_nine)
187      v2 = control_flow_ops.with_dependencies([op], enter_v)
188      v3 = control_flow_ops.exit(v2)
189      self.evaluate(variables.global_variables_initializer())
190      self.assertEqual(9, self.evaluate(v3))
191
192  @test_util.run_v1_only("b/120545219")
193  def testRefSwitch(self):
194    with self.cached_session():
195      v = variables.VariableV1(7)
196
197      p = constant_op.constant(True)
198      v1 = control_flow_ops._SwitchRefOrTensor(v._ref(), p)  # pylint: disable=protected-access
199      v2 = state_ops.assign(v1[1], 9)
200      self.evaluate(variables.global_variables_initializer())
201      self.assertEqual(9, self.evaluate(v2))
202
203  def testEnterMulExit(self):
204    with self.cached_session():
205      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
206      enter_data = gen_control_flow_ops.enter(data, "foo_1", False)
207      five = constant_op.constant(5)
208      enter_five = gen_control_flow_ops.enter(five, "foo_1", False)
209      mul_op = math_ops.multiply(enter_data, enter_five)
210      exit_op = control_flow_ops.exit(mul_op)
211
212      result = self.evaluate(exit_op)
213    self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
214
215  @test_util.run_deprecated_v1
216  def testEnterShapePropagation(self):
217    with self.cached_session():
218      v = variables.Variable([0.0, 0.0], dtype=dtypes.float32)
219
220      # If is_constant=True, the shape information should be propagated.
221      enter_v_constant = gen_control_flow_ops.enter(
222          v, "frame1", is_constant=True)
223      self.assertEqual(enter_v_constant.shape, [2])
224
225      # Otherwise, the shape should be unknown.
226      enter_v_non_constant = gen_control_flow_ops.enter(
227          v, "frame2", is_constant=False)
228      self.assertEqual(enter_v_non_constant.shape, None)
229
230  @test_util.run_v1_only("b/120545219")
231  def testSwitchMergeIndexedSlices(self):
232    with self.cached_session():
233      values = constant_op.constant([1, 2, 3, 4, 5, 6])
234      indices = constant_op.constant([0, 2, 4, 6, 8, 10])
235      data = indexed_slices.IndexedSlices(values, indices)
236      pred = ops.convert_to_tensor(True)
237      switch_op = control_flow_ops.switch(data, pred)
238      merge_op = control_flow_ops.merge(switch_op)[0]
239
240      val = merge_op.values
241      ind = merge_op.indices
242    self.assertAllEqual(np.arange(1, 7), val)
243    self.assertAllEqual(np.arange(0, 12, 2), ind)
244
245  @test_util.run_v1_only("b/120545219")
246  def testSwitchDeadBranch(self):
247    with self.cached_session():
248      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
249      ports = ops.convert_to_tensor(True, name="ports")
250      switch_op = control_flow_ops.switch(data, ports)
251      dead_branch = array_ops.identity(switch_op[0])
252
253      with self.assertRaisesWithPredicateMatch(
254          errors_impl.InvalidArgumentError,
255          lambda e: "Retval[0] does not have value" in str(e)):
256        self.evaluate(dead_branch)
257
258  @test_util.run_v1_only("b/120545219")
259  def testSwitchMergeLess(self):
260    with self.cached_session():
261      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
262      zero = ops.convert_to_tensor(0)
263      one = ops.convert_to_tensor(1)
264      less_op = math_ops.less(zero, one)
265      switch_op = control_flow_ops.switch(data, less_op)
266      merge_op = control_flow_ops.merge(switch_op)[0]
267
268      result = self.evaluate(merge_op)
269    self.assertAllEqual(np.arange(1, 7), result)
270
271  @test_util.run_v1_only("b/120545219")
272  def testSwitchMergeAddIdentity(self):
273    with self.cached_session():
274      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
275      ports = ops.convert_to_tensor(False, name="ports")
276      switch_op = control_flow_ops.switch(data, ports)
277      one = constant_op.constant(1)
278      add_op = math_ops.add(switch_op[0], one)
279      id_op = array_ops.identity(switch_op[1])
280      merge_op = control_flow_ops.merge([add_op, id_op])[0]
281
282      result = self.evaluate(merge_op)
283    self.assertAllEqual(np.array([x + 1 for x in [1, 2, 3, 4, 5, 6]]), result)
284
285  @test_util.run_v1_only("b/120545219")
286  def testSwitchMergeAddMul(self):
287    with self.cached_session():
288      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
289      ports = ops.convert_to_tensor(True, name="ports")
290      switch_op = control_flow_ops.switch(data, ports)
291      one = constant_op.constant(1)
292      add_op = math_ops.add(switch_op[0], one)
293      five = constant_op.constant(5)
294      mul_op = math_ops.multiply(switch_op[1], five)
295      merge_op = control_flow_ops.merge([add_op, mul_op])[0]
296
297      result = self.evaluate(merge_op)
298    self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
299
300  @test_util.run_v1_only("b/120545219")
301  def testLoop_false(self):
302    with self.cached_session():
303      false = ops.convert_to_tensor(False)
304      n = constant_op.constant(10)
305
306      enter_false = gen_control_flow_ops.enter(false, "foo_1", False)
307      enter_n = gen_control_flow_ops.enter(n, "foo_1", False)
308
309      merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0]
310      switch_n = control_flow_ops.switch(merge_n, enter_false)
311      exit_n = control_flow_ops.exit(switch_n[0])
312      next_n = control_flow_ops.next_iteration(switch_n[0])
313      merge_n.op._update_input(1, next_n)
314
315      result = self.evaluate(exit_n)
316    self.assertAllEqual(10, result)
317
318  @test_util.run_deprecated_v1
319  def testLoop_1(self):
320    with self.cached_session():
321      zero = constant_op.constant(0)
322      one = constant_op.constant(1)
323      n = constant_op.constant(10)
324
325      enter_i = gen_control_flow_ops.enter(zero, "foo", False)
326      enter_one = gen_control_flow_ops.enter(one, "foo", True)
327      enter_n = gen_control_flow_ops.enter(n, "foo", True)
328
329      with ops.device(test.gpu_device_name()):
330        merge_i = control_flow_ops.merge([enter_i, enter_i])[0]
331
332      less_op = math_ops.less(merge_i, enter_n)
333      cond_op = control_flow_ops.loop_cond(less_op)
334      switch_i = control_flow_ops.switch(merge_i, cond_op)
335
336      add_i = math_ops.add(switch_i[1], enter_one)
337
338      next_i = control_flow_ops.next_iteration(add_i)
339      merge_i.op._update_input(1, next_i)
340
341      exit_i = control_flow_ops.exit(switch_i[0])
342      result = self.evaluate(exit_i)
343    self.assertAllEqual(10, result)
344
345  @test_util.run_v1_only("b/120545219")
346  def testLoop_2(self):
347    with self.cached_session():
348      zero = constant_op.constant(0)
349      one = constant_op.constant(1)
350      n = constant_op.constant(10)
351
352      enter_i = gen_control_flow_ops.enter(zero, "foo", False)
353      enter_one = gen_control_flow_ops.enter(one, "foo", True)
354      enter_n = gen_control_flow_ops.enter(n, "foo", True)
355
356      merge_i = control_flow_ops.merge([enter_i, enter_i])[0]
357
358      less_op = math_ops.less(merge_i, enter_n)
359      cond_op = control_flow_ops.loop_cond(less_op)
360      switch_i = control_flow_ops.switch(merge_i, cond_op)
361
362      add_i = math_ops.add(switch_i[1], enter_one)
363
364      with ops.device(test.gpu_device_name()):
365        next_i = control_flow_ops.next_iteration(add_i)
366      merge_i.op._update_input(1, next_i)
367
368      exit_i = control_flow_ops.exit(switch_i[0])
369      result = self.evaluate(exit_i)
370    self.assertAllEqual(10, result)
371
372  @test_util.run_v1_only("b/120545219")
373  def testDifferentFrame(self):
374    with self.cached_session():
375      data = array_ops.placeholder(dtypes.float32, shape=[])
376      enter_1 = gen_control_flow_ops.enter(data, "foo_1", False)
377      enter_2 = gen_control_flow_ops.enter(data, "foo_2", False)
378      res = math_ops.add(enter_1, enter_2)
379      with self.assertRaisesOpError("has inputs from different frames"):
380        res.eval(feed_dict={data: 1.0})
381
382  @test_util.run_deprecated_v1
383  def testCondBool(self):
384    values = constant_op.constant(10)
385    fn1 = lambda: math_ops.add(values, 1)
386    fn2 = lambda: math_ops.subtract(values, 1)
387    with self.assertRaisesRegex(TypeError, "must not be a Python bool"):
388      _ = control_flow_ops.cond(False, fn1, fn2)
389
390  @test_util.run_deprecated_v1
391  def testCondInt(self):
392    p = array_ops.placeholder(dtypes.bool, shape=[])
393    v = constant_op.constant(10)
394    fn1 = lambda: math_ops.add(v, 1)
395    fn2 = lambda: math_ops.subtract(v, 1)
396    y = control_flow_ops.cond(p, fn1, fn2)
397    grad = gradients_impl.gradients(y, [v])
398    self.assertAllEqual([None], grad)
399
400  def testCondOutputShape(self):
401    x = constant_op.constant(1.0)
402    b = control_flow_ops.cond(
403        constant_op.constant(True), lambda: math_ops.square(x),
404        lambda: math_ops.subtract(x, 1.))
405    self.assertEqual(b.shape, tensor_shape.TensorShape([]))
406
407  @test_util.run_v1_only("b/120545219")
408  def testFetchable(self):
409    with self.cached_session() as sess:
410      x = array_ops.placeholder(dtypes.float32)
411      control_flow_ops.cond(
412          constant_op.constant(True), lambda: x + 2, lambda: x + 0)
413      graph = ops.get_default_graph()
414      for op in graph.get_operations():
415        for t in op.inputs:
416          if graph.is_fetchable(t.op):
417            sess.run(t, feed_dict={x: 3})
418          else:
419            with self.assertRaisesRegex(ValueError,
420                                        "has been marked as not fetchable"):
421              sess.run(t, feed_dict={x: 3})
422
423  @test_util.disable_control_flow_v2("Not relevant")
424  @test_util.run_v1_only("b/120545219")
425  def testFeedable(self):
426    with self.cached_session() as sess:
427      c = constant_op.constant(2)
428      i0 = constant_op.constant(0)
429      r = control_flow_ops.while_loop(lambda i: i < 1000,
430                                      lambda i: math_ops.square(c) + i, [i0])
431      self.assertEqual(1000, r.eval(feed_dict={i0: 0}))
432      feedable_tensors = all_feedables()
433      for t in feedable_tensors:
434        sess.run(r, feed_dict={t: 3})
435      graph = ops.get_default_graph()
436      for op in graph.get_operations():
437        for t in op.inputs:
438          if t not in feedable_tensors and t.dtype is dtypes.int32:
439            with self.assertRaisesRegex(ValueError, "may not be fed"):
440              sess.run(r, feed_dict={t: 3})
441
442  @test_util.run_v1_only("b/120545219")
443  def testCondIndexedSlices(self):
444    with self.cached_session():
445      values = constant_op.constant([10])
446      indices = constant_op.constant([0])
447      x = indexed_slices.IndexedSlices(values, indices)
448      pred = math_ops.less(1, 2)
449      fn1 = lambda: indexed_slices.IndexedSlices(
450          math_ops.add(x.values, 1), indices)
451      fn2 = lambda: indexed_slices.IndexedSlices(
452          math_ops.subtract(x.values, 1), indices)
453      r = control_flow_ops.cond(pred, fn1, fn2)
454
455      val = r.values
456      ind = r.indices
457    self.assertAllEqual([11], val)
458    self.assertAllEqual([0], ind)
459
460  def testCondMismatchedIndexedSlices(self):
461    @def_function.function
462    def foo():
463      values = constant_op.constant([10])
464      indices = constant_op.constant([0])
465      x = indexed_slices.IndexedSlices(values, indices)
466      with self.assertRaisesRegex(TypeError,
467                                  "Cannot reconcile tf.cond 0-th outputs"):
468        control_flow_ops.cond(
469            constant_op.constant(True), lambda: indexed_slices.IndexedSlices(
470                math_ops.add(x.values, 1), indices),
            lambda: math_ops.add(x.values, 1))
472    foo()
473
474  def testCondSparseTensor(self):
475    values = constant_op.constant([2.0, 4.0], name="values")
476    indices = constant_op.constant([[0], [3]],
477                                   dtype=dtypes.int64,
478                                   name="indices")
479    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
480    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
481    pred = math_ops.less(1, 2)
482    fn1 = lambda: sparse_tensor.SparseTensor(
483        indices + 1, x.values + 1, dense_shape=shape)
484    fn2 = lambda: sparse_tensor.SparseTensor(
485        indices, x.values - 1, dense_shape=shape)
486    r = control_flow_ops.cond(pred, fn1, fn2)
487    self.assertAllEqual([3.0, 5.0], r.values)
488    self.assertAllEqual([[1], [4]], r.indices)
489    self.assertAllEqual(r.values.get_shape(), (2,))
490
491  def testCondRaggedTensor(self):
492    rt = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
493    pred = math_ops.less(1, 2)
494    fn1 = lambda: array_ops.concat([rt + 2, [[100]]], axis=0)
495    fn2 = lambda: rt[:2] - 2
496    result = control_flow_ops.cond(pred, fn1, fn2)
497    self.assertAllEqual([3, 4, 5, 6, 7, 8, 100], result.values)
498    self.assertAllEqual([0, 2, 3, 6, 7], result.row_splits)
499
500  @test_util.run_v1_only("b/120545219")
501  def testCondResource(self):
502
503    with self.cached_session():
504      rv = resource_variable_ops.ResourceVariable(True)
505      self.evaluate(variables.global_variables_initializer())
506      t = ops.convert_to_tensor(1.0)
507
508      def case():
509        assign = resource_variable_ops.assign_variable_op(rv.handle, False)
510        with ops.control_dependencies([assign]):
511          return array_ops.identity(t)
512
513      self.assertEqual(
514          1.0, self.evaluate(control_flow_ops.cond(rv, case, lambda: t)))
515
516  @test_util.run_deprecated_v1
517  def testCondResourceGradShape(self):
518    rv1 = resource_variable_ops.ResourceVariable([1.0, 2.0])
519    rv2 = resource_variable_ops.ResourceVariable([3.0, 4.0])
520    pred = constant_op.constant(True)
521    result = control_flow_ops.cond(pred, lambda: rv1, lambda: rv2)
522    grads = gradients_impl.gradients(result, [rv1, rv2])
523    self.assertAllEqual(grads[0].shape.as_list(), [2])
524    self.assertAllEqual(grads[1].shape.as_list(), [2])
525
526  @test_util.run_v1_only("b/120545219")
527  def testCondWithTensorArrayGrad(self):
528    with self.cached_session() as sess:
529      with ops.device(test.gpu_device_name()):
530        pred = array_ops.placeholder(dtypes.bool, [])
531        x = constant_op.constant([1.0, 2.0, 3.0])
532        y = control_flow_ops.cond(
533            pred, lambda: map_fn.map_fn(lambda z: z * 2.0, x),
534            lambda: constant_op.constant([1.0, 1.0, 1.0]))
535        g = gradients_impl.gradients(y, x)[0]
536
537      self.assertAllEqual(sess.run(g, {pred: True}), [2.0, 2.0, 2.0])
538      self.assertAllEqual(sess.run(g, {pred: False}), [0.0, 0.0, 0.0])
539
540  @test_util.run_v1_only("b/120545219")
541  def testCondIndexedSlicesDifferentTypes(self):
542    with self.cached_session():
543      values = constant_op.constant([10])
544      i_32 = ops.convert_to_tensor([0], name="one", dtype=dtypes.int32)
545      i_64 = ops.convert_to_tensor([0], name="one", dtype=dtypes.int64)
546      x = indexed_slices.IndexedSlices(values, i_32)
547      pred = math_ops.less(1, 2)
548      fn1 = lambda: indexed_slices.IndexedSlices(
549          math_ops.add(x.values, 1), i_32)
550      fn2 = lambda: indexed_slices.IndexedSlices(
551          math_ops.subtract(x.values, 1), i_64)
552      r = control_flow_ops.cond(pred, fn1, fn2)
553
554      val = r.values
555      ind = r.indices
556    self.assertAllEqual([11], val)
557    self.assertAllEqual([0], ind)
558    self.assertTrue(ind.dtype == np.int64)
559
560  @test_util.run_v1_only("b/120545219")
561  def testCondColocation(self):
562    with self.session():
563      with ops.device("/cpu:0"):
564        v = variables.Variable(7.0)
565
566      x = constant_op.constant(10.0)
567      pred = math_ops.less(1.0, 2.0)
568      fn1 = lambda: math_ops.add(v, 1.0)
569      fn2 = lambda: math_ops.subtract(x, 1.0)
570      r = control_flow_ops.cond(pred, fn1, fn2)
571
572      for op in x.graph.get_operations():
573        if op.name == "cond/Add/Switch":
574          self.assertDeviceEqual(op.device, "/cpu:0")
575
576  def _testCond_1(self, use_gpu):
577    with self.cached_session(use_gpu=use_gpu):
578      x = constant_op.constant(10)
579      pred = math_ops.less(1, 2)
580      fn1 = lambda: math_ops.add(x, 1)
581      fn2 = lambda: math_ops.subtract(x, 1)
582      r = control_flow_ops.cond(pred, fn1, fn2)
583
584      result = self.evaluate(r)
585    self.assertAllEqual(11, result)
586
587  def testCond_1(self):
588
589    self._testCond_1(use_gpu=False)
590    # TODO(b/116526896): Enable GPU tests.
591    # self._testCond_1(use_gpu=True)
592
593  def testCond_2(self):
594
595    with self.cached_session():
596      x = constant_op.constant(10)
597      r = control_flow_ops.cond(
598          math_ops.less(1, 0), lambda: math_ops.add(x, 1),
599          lambda: math_ops.subtract(x, 1))
600      result = self.evaluate(r)
601    self.assertAllEqual(9, result)
602
603  def testCond_3(self):
604
605    with self.cached_session():
606      x = constant_op.constant(10)
607      pred = math_ops.less(1, 2)
608      fn1 = lambda: math_ops.add(x, 1)
609      fn2 = lambda: math_ops.subtract(x, 1)
610      fn3 = lambda: math_ops.add(control_flow_ops.cond(pred, fn1, fn2), 1)
611      r = control_flow_ops.cond(pred, fn3, fn2)
612
613      result = self.evaluate(r)
614    self.assertAllEqual(12, result)
615
616  @test_util.run_in_graph_and_eager_modes
617  def testCondPruning(self):
618    v1 = variables.Variable(7)
619    v2 = variables.Variable(7)
620    v3 = variables.Variable(7)
621
622    def f():
623      age = constant_op.constant(3)
624      max_age = constant_op.constant(2)
625      pred = math_ops.greater(age, max_age)
626      fn1 = lambda: [state_ops.assign(v1, 1).op, state_ops.assign(v2, 2).op]
627      fn2 = lambda: [state_ops.assign(v3, 3).op, constant_op.constant(10).op]
628      r = control_flow_ops.cond(pred, fn1, fn2)
629      self.assertEqual(len(r), 2)
630      return r[1]
631
632    f_defun = eager_function.defun(f)
633
634    if not context.executing_eagerly():
635      with self.cached_session():
636        self.evaluate(variables.global_variables_initializer())
637        result = self.evaluate(f())
638        self.assertEqual(True, result)
639        # Only second cond result was fetched, so v1 assign shouldn't run.
640        self.assertEqual(7, self.evaluate(v1))
641        self.assertEqual(2, self.evaluate(v2))
642        self.assertEqual(7, self.evaluate(v3))
643
644    result = f_defun()
645    self.assertEqual(True, self.evaluate(result))
646    # Both v1 and v2 branch assignments should be run in defun.
647    self.assertEqual(1, self.evaluate(v1))
648    self.assertEqual(2, self.evaluate(v2))
649    self.assertEqual(7, self.evaluate(v3))
650
651  def testCond_5(self):
652    with self.cached_session():
653      alive = constant_op.constant(True, name="alive")
654      count = constant_op.constant(0, name="count")
655
656      def body(i):
657        return control_flow_ops.cond(
658            alive, lambda: [math_ops.less(i, 3), math_ops.add(count, 1)],
659            lambda: [alive, count])
660
661      for i in range(10):
662        alive, count = body(i)
663      self.assertAllEqual(4, self.evaluate(count))
664
665  @test_util.run_v1_only("b/120545219")
666  def testCond_6(self):
667    with self.cached_session():
668      v1 = variables.Variable([7])
669
670      age = constant_op.constant(3)
671      pred = math_ops.greater(age, 4)
672      fn1 = lambda: age
673      fn2 = lambda: v1
674      r = control_flow_ops.cond(pred, fn1, fn2)
675
676      self.evaluate(variables.global_variables_initializer())
677      result = self.evaluate(r)
678      self.assertAllEqual(np.array([7]), result)
679
680  def testCond_7(self):
681    with self.cached_session() as sess:
682      x = constant_op.constant(10)
683      y = constant_op.constant(200)
684      pred = math_ops.less(1, 2)
685      fn1 = lambda: [math_ops.add(x, 1), math_ops.add(x, 2)]
686      fn2 = lambda: [y, y]
687      r = control_flow_ops.cond(pred, fn1, fn2)
688      self.assertAllEqual([11, 12], self.evaluate(r))
689
690  @parameterized.parameters(dtypes.float32, dtypes.float64)
691  @test_util.run_v1_only("Uses tf.gradients")
692  def testCondResourceGrad(self, dtype):
693    init = constant_op.constant([7.], dtype=dtype)
694    v1 = variables.Variable(init)
695
696    age = constant_op.constant(3., dtype=dtype)
697    pred = math_ops.greater(age, 4.)
698    fn1 = lambda: age
699    fn2 = lambda: v1
700    r = control_flow_ops.cond(pred, fn1, fn2)
701
702    grad = gradients_impl.gradients(r, v1)[0]
703    self.evaluate(variables.global_variables_initializer())
704    self.assertAllEqual(grad, [1.])
705
706  @test_util.run_gpu_only
707  @test_util.run_deprecated_v1
708  def testCond_Device(self):
709    x = constant_op.constant(-10.)
710
711    # True branch function defined outside of device scope
712    def true_fn():
713      return math_ops.exp(x)
714
715    with ops.device("CPU:0"):
716      r = control_flow_ops.cond(
717          constant_op.constant(True), true_fn, lambda: 0.)
718      self.assertIn("cpu", r.device.lower())
719
720    with session.Session() as sess:
721      options = config_pb2.RunOptions(output_partition_graphs=True)
722      run_metadata = config_pb2.RunMetadata()
723      sess.run(r, options=options, run_metadata=run_metadata)
724      # We expect that everything runs on CPU, even if GPU is available.
725      self.assertEqual(len(run_metadata.partition_graphs), 1)
726
727  def _count_matching_switch_nodes_on_device(self, run_metadata, device_str,
728                                             dtype):
729    # Returns the number of Switch nodes with type dtype placed on
730    # `device_str`.
731    device_graphs = [
732        g for g in run_metadata.partition_graphs
733        if device_str in g.node[0].device
734    ]
735    self.assertLen(device_graphs, 1)
736    switch_nodes = [
737        n for n in device_graphs[0].node
738        if n.op == "Switch" and n.attr["T"].type == dtype.as_datatype_enum
739    ]
740    return len(switch_nodes)
741
742  @test_util.run_gpu_only
743  @test_util.run_deprecated_v1
744  def testCondSwitchColocatedWithInputWhenInputExplicitlyPlacedOnCPU(self):
745    x = array_ops.placeholder(dtypes.float32)
746
747    # `arg` is used in the cond then branch so a Switch node is created for it.
748    # We test that the Switch node gets placed on the same device as `arg`.
749    # We force `arg` to be on CPU here.
750    with ops.device("CPU:0"):
751      arg = x + 10.
752
753    def true_fn():
754      with ops.device("CPU:0"):
755        return arg + 1
756
757    r = control_flow_ops.cond(constant_op.constant(True), true_fn, lambda: 0.)
758
759    # Disable Loop_optimizer grappler pass for this test because it replaces
760    # Switch with Identity when it's part of a dead branch.
761    config = config_pb2.ConfigProto()
762    config.graph_options.rewrite_options.loop_optimization = (
763        rewriter_config_pb2.RewriterConfig.OFF)
764
765    with self.session(config=config) as sess:
766      run_metadata = config_pb2.RunMetadata()
767      options = config_pb2.RunOptions(output_partition_graphs=True)
768      sess.run(
769          r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
770      self.assertLen(run_metadata.partition_graphs, 2)
771      # Check that the Switch for `arg` gets placed on CPU.
772      self.assertEqual(
773          self._count_matching_switch_nodes_on_device(run_metadata, "CPU",
774                                                      dtypes.float32), 1)
775      self.assertEqual(
776          self._count_matching_switch_nodes_on_device(run_metadata, "GPU",
777                                                      dtypes.float32), 0)
778
779  @test_util.run_gpu_only
780  @test_util.run_deprecated_v1
781  def testCondSwitchColocatedWithInputWhenInputPlacedOnCPU(self):
782    x = array_ops.placeholder(dtypes.float32)
783
784    # `arg` is used in the cond then branch so a Switch node is created for it.
785    # We test that the Switch node gets placed on the same device as `arg`.
786    # Since arg is a dataset (and only has a CPU kernel), it gets placed on CPU
787    # by placer.
788    arg = dataset_ops.Dataset.range(8)
789
790    def true_fn():
791      return cardinality.cardinality(arg)
792
793    r = control_flow_ops.cond(
794        constant_op.constant(True), true_fn,
795        lambda: constant_op.constant(0, dtypes.int64))
796
797    # Disable Loop_optimizer grappler pass for this test because it replaces
798    # Switch with Identity when it's part of a dead branch.
799    config = config_pb2.ConfigProto()
800    config.graph_options.rewrite_options.loop_optimization = (
801        rewriter_config_pb2.RewriterConfig.OFF)
802
803    with session.Session(config=config) as sess:
804      run_metadata = config_pb2.RunMetadata()
805      options = config_pb2.RunOptions(output_partition_graphs=True)
806      sess.run(
807          r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
808      self.assertLen(run_metadata.partition_graphs, 2)
809      # Check that the Switch for `arg` gets placed on CPU.
810      self.assertEqual(
811          self._count_matching_switch_nodes_on_device(run_metadata, "CPU",
812                                                      dtypes.variant), 1)
813      self.assertEqual(
814          self._count_matching_switch_nodes_on_device(run_metadata, "GPU",
815                                                      dtypes.variant), 0)
816
817  @test_util.run_gpu_only
818  @test_util.run_deprecated_v1
819  def testCondSwitchColocatedWithInputWhenInputOnGPU(self):
820    x = array_ops.placeholder(dtypes.float32)
821
822    # `arg` is used in the cond then branch so a Switch node is created for it.
823    # We test that the Switch node gets placed on the same device as `arg`.
824    # Note: `arg` gets placed on GPU by default by the placer.
825    arg = x + 10.
826
827    def true_fn():
828      with ops.device("CPU:0"):
829        return arg + 1
830
831    r = control_flow_ops.cond(constant_op.constant(True), true_fn, lambda: 0.)
832
833    # Disable Loop_optimizer grappler pass for this test because it replaces
834    # Switch with Identity when it's part of a dead branch.
835    config = config_pb2.ConfigProto()
836    config.graph_options.rewrite_options.loop_optimization = (
837        rewriter_config_pb2.RewriterConfig.OFF)
838
839    with session.Session(config=config) as sess:
840      run_metadata = config_pb2.RunMetadata()
841      options = config_pb2.RunOptions(output_partition_graphs=True)
842      sess.run(
843          r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
844      self.assertEqual(len(run_metadata.partition_graphs), 2)
845      # Check that the Switch for `arg` gets placed on GPU.
846      self.assertEqual(
847          self._count_matching_switch_nodes_on_device(run_metadata, "CPU",
848                                                      dtypes.float32), 0)
849      self.assertEqual(
850          self._count_matching_switch_nodes_on_device(run_metadata, "GPU",
851                                                      dtypes.float32), 1)
852
853  def testCondAccessTrueBranchTensorInFalseBranchRaises(self):
854
855    @def_function.function
856    def f():
857      c = constant_op.constant(1.)
858      inputs = {"c": c}
859
860      def true_fn(inputs):
861        inputs["c"] = array_ops.identity(inputs["c"], name="true_branch")
862        return inputs["c"]
863
864      def false_fn(inputs):
865        return array_ops.identity(inputs["c"])
866
867      pred = constant_op.constant(True)
868      return control_flow_ops.cond(
869          pred, lambda: true_fn(inputs), lambda: false_fn(inputs))
870
871    # This was needed for backwards compatibility with TF2 Estimators which
872    # rely on variable names.
873    prefix = "cond/" if context.executing_eagerly() else ""
874
875    with self.assertRaisesRegex(
876        ValueError,
877        "Tensor %strue_branch:0 in true_fn is accessed from false_fn." %
878        prefix):
879      f()
880
881  def testSwitchCaseAccessBranch1TensorInBranch4Raises(self):
882
883    @def_function.function
884    def f():
885      c = constant_op.constant(1.)
886      inputs = {"c": c}
887
888      def br1_fn(inputs):
889        inputs["c"] = array_ops.identity(inputs["c"], name="br1_identity")
890        return inputs["c"]
891
892      def br4_fn(inputs):
893        return array_ops.identity(inputs["c"])
894
895      def other_fn():
896        return array_ops.identity(c)
897
898      return control_flow_ops.switch_case(
899          constant_op.constant(2),
900          [other_fn, lambda: br1_fn(inputs), other_fn, other_fn,
901           lambda: br4_fn(inputs)])
902
903    # This was needed for backwards compatibility with TF2 Estimators which
904    # rely on variable names.
905    prefix = "switch_case/indexed_case/" if context.executing_eagerly() else ""
906    with self.assertRaisesRegex(
907        ValueError, "Tensor %sbr1_identity:0 in branch 1 is "
908        "accessed from branch 4." % prefix):
909      f()
910
911  def testCondListOutput(self):
912    with self.cached_session() as sess:
913      x = constant_op.constant(10)
914      y = constant_op.constant(200)
915      pred = math_ops.less(1, 2)
916      fn1 = lambda: [math_ops.add(x, y), math_ops.add(x, y)]
917      fn2 = lambda: [y, y]
918      r = control_flow_ops.cond(pred, fn1, fn2)
919      test_result = self.evaluate(r)
920      self.assertListEqual([210, 210], test_result)
921
922  def testTupleOutput(self):
923    with self.cached_session() as sess:
924      x = constant_op.constant(10)
925      y = constant_op.constant(200)
926      pred = math_ops.less(1, 2)
927      fn1 = lambda: (math_ops.add(x, y), math_ops.add(x, y))
928      fn2 = lambda: (y, y)
929      r = control_flow_ops.cond(pred, fn1, fn2)
930      test_result = self.evaluate(r)
931      self.assertTupleEqual((210, 210), test_result)
932
933  def testDictOutput(self):
934    with self.cached_session() as sess:
935      x = constant_op.constant(10)
936      y = constant_op.constant(200)
937      pred = math_ops.less(1, 2)
938      fn1 = lambda: {"a": math_ops.add(x, y), "b": math_ops.add(x, y)}
939      fn2 = lambda: {"a": y, "b": y}
940      r = control_flow_ops.cond(pred, fn1, fn2)
941      test_result = self.evaluate(r)
942      self.assertDictEqual({"a": 210, "b": 210}, test_result)
943
944  def testEmbeddedListOutput(self):
945    x = constant_op.constant(10)
946    y = constant_op.constant(200)
947    pred = math_ops.less(1, 2)
948    fn1 = lambda: [[math_ops.add(x, y), math_ops.add(x, y)]]
949    fn2 = lambda: [[y, y]]
950    # Pass strict=True flag as cond_v2 allows for tensors to be
951    # in nested output structures as singletons
952    r = control_flow_ops.cond(pred, fn1, fn2, strict=True)
953    test_result = self.evaluate(r)
954    self.assertListEqual([[210, 210]], test_result)
955
956  def testEmbeddedTupleOutput(self):
957    with self.cached_session() as sess:
958      x = constant_op.constant(10)
959      y = constant_op.constant(200)
960      pred = math_ops.less(1, 2)
961      fn1 = lambda: ((math_ops.add(x, y), math_ops.add(x, y)))
962      fn2 = lambda: ((y, y))
963      r = control_flow_ops.cond(pred, fn1, fn2)
964      test_result = self.evaluate(r)
965      self.assertTupleEqual(((210, 210)), test_result)
966
967  def testEmbeddedDictOutput(self):
968    with self.cached_session() as sess:
969      x = constant_op.constant(10)
970      y = constant_op.constant(200)
971      pred = math_ops.less(1, 2)
972      fn1 = lambda: {"a": {"c": math_ops.add(x, y)},
973                     "b": {"d": math_ops.add(x, y)}}
974      fn2 = lambda: {"a": {"c": y},
975                     "b": {"d": y}}
976      r = control_flow_ops.cond(pred, fn1, fn2)
977      test_result = self.evaluate(r)
978      self.assertDictEqual({"a": {"c": 210}, "b": {"d": 210}}, test_result)
979
980  @test_util.run_v1_only("b/120545219")
981  def testCheckNestedOutputStruct(self):
982    with self.cached_session() as sess:
983      x = constant_op.constant(10)
984      y = constant_op.constant(200)
985      pred = math_ops.less(1, 2)
986      fn1 = lambda: {"a": math_ops.add(x, y), "b": math_ops.add(x, y)}
987      fn2 = lambda: {"c": y, "d": y}
988      v1_msg = "The two structures don't have the same nested structure"
989      v2_msg = ("true_fn and false_fn arguments to tf.cond must have the same "
990                "number, type, and overall structure of return values.")
991      with self.assertRaisesRegex(
992          TypeError if control_flow_util.ENABLE_CONTROL_FLOW_V2 else ValueError,
993          v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
994        control_flow_ops.cond(pred, fn1, fn2)
995
996  @test_util.run_v1_only("b/120545219")
997  def testCondWithControl(self):
998    with self.cached_session() as sess:
999      control_holder = array_ops.placeholder(dtypes.float32, shape=())
1000      a = constant_op.constant(3)
1001
1002      def true_branch():
1003        with ops.control_dependencies([control_holder]):
1004          _ = a + 1
1005        return a + 2
1006
1007      r = control_flow_ops.cond(
1008          constant_op.constant(True), true_branch,
1009          lambda: constant_op.constant(1))
1010      result = sess.run(r, feed_dict={control_holder: 5.})
1011      self.assertEqual(5, result)
1012
1013  @test_util.run_v1_only("b/120545219")
1014  def testUninitializedRefIdentity(self):
1015    with self.cached_session() as sess:
1016      v = gen_state_ops.variable(
1017          shape=[1],
1018          dtype=dtypes.float32,
1019          name="v",
1020          container="",
1021          shared_name="")
1022      inited = state_ops.is_variable_initialized(v)
1023      v_f, v_t = control_flow_ops.ref_switch(v, inited)
1024      # Both v_f and v_t are uninitialized references. However, an actual use
1025      # of the reference in the 'true' branch in the 'tf.identity' op will
1026      # not 'fire' when v is uninitialized, so this is a valid construction.
1027      # This test tests that ref_identity allows uninitialized ref as input
1028      # so that this construction is allowed.
1029      v_f_op = gen_array_ops.ref_identity(v_f)
1030      v_t_op = gen_array_ops.ref_identity(v_t)
1031      with ops.control_dependencies([v_f_op]):
1032        assign_v = state_ops.assign(v, [1.0])
1033      with ops.control_dependencies([v_t_op]):
1034        orig_v = array_ops.identity(v)
1035      merged_op = control_flow_ops.merge([assign_v, orig_v])
1036      self.assertAllEqual([1.0], self.evaluate(merged_op.output))
1037
1038  def testCondSwitchIdentity(self):
    # Make sure the switch identity is not removed by optimization.
1040    with session.Session(config=opt_cfg()) as sess:
1041      pred = constant_op.constant(True)
1042
1043      def fn1():
1044        return control_flow_ops.no_op()
1045
1046      def fn2():
1047        return control_flow_ops.Assert(False, ["Wrong branch!!!"])
1048
1049      r = control_flow_ops.cond(pred, fn1, fn2)
1050      self.evaluate(r)
1051
1052  def testCondRecvIdentity(self):
    # Make sure the recv identity is not removed by optimization.
1054    with session.Session(config=opt_cfg()) as sess:
1055      with ops.device(test.gpu_device_name()):
1056        pred = constant_op.constant(True)
1057
1058      def fn1():
1059        return control_flow_ops.no_op()
1060
1061      def fn2():
1062        with ops.device("/cpu:0"):
1063          return control_flow_ops.Assert(False, ["Wrong branch!!!"])
1064
1065      r = control_flow_ops.cond(pred, fn1, fn2)
1066      self.evaluate(r)
1067
1068  @test_util.run_deprecated_v1
1069  @test_util.enable_control_flow_v2
1070  def testDisableLoweringSwitchMerge(self):
1071    if test_util.is_gpu_available():
1072      self.skipTest(
1073          "Single threaded executor doesn't support partitioned graphs.  "
1074          "Skipping GPU test.")
1075    # Make pred feedable to ensure we don't constant-fold it out.
1076    run_opts = config_pb2.RunOptions(
1077        trace_level=config_pb2.RunOptions.FULL_TRACE)
1078    run_metadata_no_lowering = config_pb2.RunMetadata()
1079    run_metadata_with_lowering = config_pb2.RunMetadata()
1080
1081    config = opt_cfg(do_constant_folding=False)
1082
1083    pred = array_ops.placeholder_with_default(
1084        constant_op.constant(True), shape=())
1085    r = control_flow_ops.cond(pred, lambda: True, lambda: False)
1086
1087    with session.Session(config=config) as sess:
1088      r_value = sess.run(
1089          r, options=run_opts, run_metadata=run_metadata_with_lowering)
1090      self.assertEqual(r_value, True)
1091
1092    # Use the single threaded executor, which disables control flow lowering.
1093    config.experimental.executor_type = "SINGLE_THREADED_EXECUTOR"
1094    with session.Session(config=config) as sess:
1095      r_value = sess.run(
1096          r, options=run_opts, run_metadata=run_metadata_no_lowering)
1097      self.assertEqual(r_value, True)
1098
1099    self.assertTrue(  # pylint: disable=g-complex-comprehension
1100        any("switch" in ns.node_name
1101            for dev_stat in run_metadata_with_lowering.step_stats.dev_stats
1102            for ns in dev_stat.node_stats))
1103
1104    self.assertTrue(  # pylint: disable=g-complex-comprehension
1105        all("switch" not in ns.node_name
1106            for dev_stat in run_metadata_no_lowering.step_stats.dev_stats
1107            for ns in dev_stat.node_stats))
1108
1109  @test_util.run_v1_only("b/120545219")
1110  def testCondGrad_1(self):
1111    with self.cached_session():
1112      x = constant_op.constant(10.0, name="x")
1113      pred = math_ops.less(1, 2)
1114      fn1 = lambda: array_ops.identity(x)
1115      fn2 = lambda: array_ops.identity(x)
1116      r = control_flow_ops.cond(pred, fn1, fn2)
1117
1118      grad = gradients_impl.gradients(r, [x])[0]
1119      self.assertAllEqual(1.0, self.evaluate(grad))
1120
1121  @test_util.run_deprecated_v1
1122  @test_util.enable_control_flow_v2
1123  def testCondComputeGradAfterSessRunFails(self):
1124    with self.cached_session():
1125      x = constant_op.constant(10.0, name="x")
1126      pred = math_ops.less(1, 2)
1127
1128      def true_fn():
1129        a = x * x
1130        return a * a
1131
1132      def false_fn():
1133        return x * x
1134
1135      r = control_flow_ops.cond(pred, true_fn, false_fn)
1136
1137      self.assertAllEqual(r, 10000.)
1138      grad = gradients_impl.gradients(r, [x])[0]
1139      with self.assertRaisesRegex(
1140          errors_impl.InvalidArgumentError,
1141          r"Connecting to invalid output 1 of source node cond which has 1 "
1142          r"outputs. Try using "
1143          "tf.compat.v1.experimental.output_all_intermediates\(True\)."):
1144        self.evaluate(grad)
1145
1146  @test_util.run_deprecated_v1
1147  @test_util.enable_output_all_intermediates
1148  def testCondComputeGradAfterSessRun(self):
1149    with self.cached_session():
1150      x = constant_op.constant(10.0, name="x")
1151      pred = math_ops.less(1, 2)
1152
1153      def true_fn():
1154        a = x * x
1155        return a * a
1156
1157      def false_fn():
1158        return x * x
1159
1160      r = control_flow_ops.cond(pred, true_fn, false_fn)
1161
1162      self.assertAllEqual(r, 10000.)
1163      grad = gradients_impl.gradients(r, [x])[0]
1164      self.assertAllEqual(grad, 4000.)
1165
1166  @test_util.run_deprecated_v1
1167  @test_util.enable_output_all_intermediates
1168  def testNestedCondComputeGradAfterSessRun(self):
1169    with self.cached_session():
1170      x = constant_op.constant(10.0, name="x")
1171      pred = math_ops.less(1, 2)
1172
1173      def true_fn():
1174
1175        def inner_true_fn():
1176          a = x * x
1177          return a * a
1178
1179        def inner_false_fn():
1180          return x * x
1181
1182        return control_flow_ops.cond(
1183            constant_op.constant(True), inner_true_fn, inner_false_fn)
1184
1185      def false_fn():
1186        return x * x
1187
1188      r = control_flow_ops.cond(pred, true_fn, false_fn)
1189
1190      self.assertAllEqual(r, 10000.)
1191      grad = gradients_impl.gradients(r, [x])[0]
1192      self.assertAllEqual(grad, 4000.)
1193
1194  @test_util.run_deprecated_v1
1195  def testCondGrad_2(self):
1196    with self.cached_session():
1197      c = array_ops.placeholder(dtypes.int32, shape=[])
1198      x = constant_op.constant(10.0)
1199      pred = math_ops.less(c, 2)
1200      fn1 = lambda: math_ops.multiply(x, 42.0)
1201      fn2 = lambda: math_ops.multiply(x, 3.0)
1202      r = control_flow_ops.cond(pred, fn1, fn2)
1203
1204      grad = gradients_impl.gradients(r, [x])[0]
1205      self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1}))
1206      self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))
1207
1208  @test_util.disable_control_flow_v2(
1209      "b/110550782 (gradient w.r.t external variable)")
1210  @test_util.run_deprecated_v1
1211  def testCondGrad_3(self):
1212    with self.cached_session():
1213      c = array_ops.placeholder(dtypes.int32, shape=[])
1214      ox = constant_op.constant(10.0)
1215      pred = math_ops.less(c, 2)
1216
1217      def fn1(x):
1218        m = x * x
1219        return gradients_impl.gradients(m, [ox])[0]
1220
1221      fn2 = lambda: math_ops.multiply(ox, 3.0)
1222      y = math_ops.multiply(7.0, ox)
1223      r = control_flow_ops.cond(pred, lambda: fn1(y), fn2)
1224
1225      self.assertAllEqual(980.0, r.eval(feed_dict={c: 1}))
1226      self.assertAllEqual(30.0, r.eval(feed_dict={c: 3}))
1227
1228  @test_util.run_deprecated_v1
1229  def testCondGradMultiDevice(self):
1230    config = config_pb2.ConfigProto(device_count={"CPU": 2},
1231                                    allow_soft_placement=True)
1232    with self.cached_session(config=config) as sess:
1233      pred = array_ops.placeholder(dtypes.bool, [])
1234      x = array_ops.placeholder(dtypes.float32)
1235      y = array_ops.placeholder(dtypes.float32)
1236
1237      with ops.device("/cpu:0"):
1238        z = control_flow_ops.cond(pred, lambda: x * y * 2.0, lambda: 2.0)
1239
1240      with ops.device("/cpu:1"):
1241        grad = gradients_impl.gradients(z, x)[0]
1242
1243      with ops.device("/cpu:0"):
1244        grad_grad = gradients_impl.gradients(grad, x)[0]
1245
1246      self.assertEqual(sess.run(grad, {pred: True, x: 1.0, y: 2.0}), 4.0)
1247      self.assertEqual(sess.run(grad, {pred: False, x: 1.0, y: 2.0}), 0.0)
1248
1249      # v1 control flow gets None second derivative for some reason.
1250      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
1251        self.assertIsNone(grad_grad)
1252        return
1253
1254      self.assertEqual(sess.run(grad_grad, {pred: True, x: 1.0, y: 2.0}), 0.0)
1255      self.assertEqual(sess.run(grad_grad, {pred: False, x: 1.0, y: 2.0}), 0.0)
1256
1257  @test_util.run_v1_only("b/120545219")
1258  def testNestedCond_Simple(self):
1259    with self.cached_session():
1260      x = constant_op.constant(0., name="X")
1261      y = control_flow_ops.cond(
1262          constant_op.constant(True), lambda: x,
1263          lambda: control_flow_ops.cond(x < 1., lambda: x, lambda: x))
1264      result = gradients_impl.gradients(y, x)[0]
1265      self.assertEqual(1.0, self.evaluate(result))
1266
1267      z = control_flow_ops.cond(
1268          constant_op.constant(False), lambda: x,
1269          lambda: control_flow_ops.cond(x < 1., lambda: x, lambda: x))
1270      result = gradients_impl.gradients(z, x)[0]
1271      self.assertEqual(1.0, self.evaluate(result))
1272
1273  @test_util.run_v1_only("b/120545219")
1274  def testCondGrad_Gather(self):
1275    with self.cached_session() as sess:
1276      v1 = variables.Variable([1.0, 42.0])
1277      c = array_ops.placeholder(dtypes.int32, shape=[])
1278      pred = math_ops.less(c, 2)
1279      fn1 = lambda: array_ops.identity(v1)
1280      fn2 = lambda: array_ops.gather(v1, [1, 1])
1281      r = control_flow_ops.cond(pred, fn1, fn2)
1282      # The following `grad` is a Tensor since it is the aggregation of an
1283      # IndexedSlice and a Tensor. It is an `IndexedSlices` with control flow
1284      # v2.
1285      grad = gradients_impl.gradients(r, [v1])[0]
1286      self.evaluate(variables.global_variables_initializer())
1287
1288      if control_flow_util.ENABLE_CONTROL_FLOW_V2:
1289        self.assertIsInstance(grad, indexed_slices.IndexedSlices)
1290
1291      grad_value = sess.run(grad, feed_dict={c: 1})
1292      self.assertAllEqual(gradient_checker_v2._to_numpy(grad_value), [1.0, 1.0])
1293
1294      grad_value = sess.run(grad, feed_dict={c: 3})
1295      self.assertAllEqual(gradient_checker_v2._to_numpy(grad_value), [0.0, 2.0])
1296
1297  @test_util.run_deprecated_v1
1298  def testCondGrad_ResourceVarSparseRead(self):
1299    # NOTE(skyewm): this test is interesting because the
1300    # ResourceVariable.sparse_read gradient function returns IndexedSlices.
1301    var = resource_variable_ops.ResourceVariable(
1302        np.ones((4, 2), dtype=np.float32))
1303    x = constant_op.constant(1.0)
1304    r = control_flow_ops.cond(
1305        constant_op.constant(True),
1306        lambda: x * math_ops.reduce_sum(var.sparse_read([1, 2])),
1307        lambda: constant_op.constant(np.zeros((2, 3)),
1308                                     dtype=dtypes.float32))
1309    grad = gradients_impl.gradients(r, var)[0]
1310
1311    self.evaluate(variables.global_variables_initializer())
1312    grad_val = self.evaluate(grad)
1313    self.assertIsInstance(grad_val, indexed_slices.IndexedSlicesValue)
1314    self.assertAllEqual(gradient_checker_v2._to_numpy(grad_val), [[0., 0.],
1315                                                                  [1., 1.],
1316                                                                  [1., 1.],
1317                                                                  [0., 0.]])
1318
1319  def testCondGrad_MultiGather(self):
    # NOTE(skyewm): this test is interesting because the array_ops.gather and
    # ResourceVariable.sparse_read gradient functions return IndexedSlices.
1322    var = resource_variable_ops.ResourceVariable(
1323        np.ones((4, 2), dtype=np.float32))
1324    x1 = constant_op.constant(np.ones((3, 3), dtype=np.float32))
1325    x2 = constant_op.constant(2.0)
1326
1327    def true_fn():
1328      y1 = var.sparse_read([1, 2])
1329      y2 = array_ops.gather(x1, [2]) * x2
1330      y3 = x2 * [1., 1., 1.]
1331      return y1, y2, y3
1332
1333    def false_fn():
1334      y1 = np.zeros((2, 2), dtype=np.float32)
1335      y2 = array_ops.gather(x1, [2]) * x2
1336      y3 = array_ops.gather(x1, [2])
1337      return y1, y2, y3
1338
1339    @def_function.function
1340    def foo():
1341      r = control_flow_ops.cond(constant_op.constant(True), true_fn, false_fn)
1342      return gradients_impl.gradients(r, [var, x1, x2])
1343
1344    grad = foo()
1345    self.evaluate(variables.global_variables_initializer())
1346    var_grad, x1_grad, x2_grad = self.evaluate(grad)
1347    self.assertIsInstance(var_grad, indexed_slices.IndexedSlicesValue)
1348    self.assertAllEqual(gradient_checker_v2._to_numpy(var_grad), [[0., 0.],
1349                                                                  [1., 1.],
1350                                                                  [1., 1.],
1351                                                                  [0., 0]])
1352    self.assertIsInstance(x1_grad, indexed_slices.IndexedSlicesValue)
1353    self.assertAllEqual(gradient_checker_v2._to_numpy(x1_grad), [[0., 0., 0.],
1354                                                                 [0., 0., 0.],
1355                                                                 [2., 2., 2.]])
1356    self.assertIsInstance(x1_grad, indexed_slices.IndexedSlicesValue)
1357    self.assertEqual(gradient_checker_v2._to_numpy(x2_grad), 6.)
1358
1359  @test_util.run_v1_only("b/120545219")
1360  def testCondPredicateTensor(self):
1361    """Regression test for lowering predicate from non-first output of an op."""
1362
1363    @eager_function.defun
1364    def foo():
1365      return constant_op.constant("foo"), constant_op.constant(True)
1366
1367    r = control_flow_ops.cond(foo()[1], lambda: 1.0, lambda: 2.0)
1368    self.assertEqual(self.evaluate(r), 1.0)
1369
1370  @test_util.run_v1_only("Tests Session.run() pruning logic.")
1371  def testCondFeedConstantPredicate(self):
1372    with self.cached_session() as sess:
1373      value = constant_op.constant(37.0)
1374      predicate = constant_op.constant(True)
1375      cond_output = control_flow_ops.cond(
1376          predicate, lambda: constant_op.constant(0.0), lambda: value)
1377      result = array_ops.identity(cond_output)
1378      self.assertEqual(37.0, sess.run(result, feed_dict={predicate: False}))
1379      self.assertEqual(0.0, sess.run(result, feed_dict={predicate: True}))
1380      self.assertEqual(0.0, sess.run(result))
1381
1382  @test_util.run_v1_only("Tests Session.run() pruning logic.")
1383  def testCondFeedPlaceholderWithDefaultPredicate(self):
1384    with self.cached_session() as sess:
1385      value = constant_op.constant(37.0)
1386      predicate = array_ops.placeholder_with_default(
1387          constant_op.constant(True), [])
1388      cond_output = control_flow_ops.cond(
1389          predicate, lambda: constant_op.constant(0.0), lambda: value)
1390      result = array_ops.identity(cond_output)
1391      self.assertAllEqual(37.0, sess.run(result, feed_dict={predicate: False}))
1392      self.assertAllEqual(0.0, sess.run(result, feed_dict={predicate: True}))
1393      self.assertAllEqual(0.0, sess.run(result))
1394
1395  def testCondTensorDeps(self):
1396    t = array_ops.identity(1.)
1397
1398    @def_function.function
1399    def f():
1400      with ops.control_dependencies([t]):
1401        return array_ops.identity(2.)
1402
1403    f.get_concrete_function()
1404
1405  @test_util.run_in_graph_and_eager_modes
1406  def testCondAutoControlDeps(self):
1407    if test_util.is_gpu_available():
1408      self.skipTest("b/128676188 causes OOM on opensource gpu tests")
1409
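    # "A" and "B" have no data dependency on the branch output; only "C" is
    # wired in as a control dependency of the returned constant.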
1412    def branch_fn():
1413      enqueue_print_op("A")
1414      enqueue_print_op("B")
1415      with ops.control_dependencies([enqueue_print_op("C")]):
1416        return constant_op.constant(10)
1417
1418    def build_cond():
1419      return control_flow_ops.cond(
1420          constant_op.constant(True), branch_fn, lambda: 0)
1421
1422    def build_nested_cond():
1423      return control_flow_ops.cond(
1424          constant_op.constant(True), build_cond, lambda: 0)
1425
1426    # In v1 graph mode, pruning should make only "C" print.
1427    if not context.executing_eagerly():
1428      with self.cached_session():
1429        with self.captureWritesToStream(sys.stderr) as printed:
1430          self.assertEqual(self.evaluate(build_cond()), 10)
1431        self.assertEqual(["C"], filter_test_messages(printed.contents()))
1432
1433        with self.captureWritesToStream(sys.stderr) as printed:
1434          self.assertEqual(self.evaluate(build_nested_cond()), 10)
1435        self.assertEqual(["C"], filter_test_messages(printed.contents()))
1436
1437    # In defuns, all prints should execute in program order.
1438    # This doesn't work with legacy control flow.
1439    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
1440
1441      @eager_function.defun
1442      def cond():
1443        return build_cond()
1444
1445      with self.captureWritesToStream(sys.stderr) as printed:
1446        self.assertEqual(self.evaluate(cond()), 10)
1447      self.assertEqual(["A", "B", "C"],
1448                       filter_test_messages(printed.contents()))
1449
1450      @eager_function.defun
1451      def nested_cond():
1452        return build_nested_cond()
1453
1454      with self.captureWritesToStream(sys.stderr) as printed:
1455        self.assertEqual(self.evaluate(nested_cond()), 10)
1456      self.assertEqual(["A", "B", "C"],
1457                       filter_test_messages(printed.contents()))
1458
1459    # wrap_function should prune.
1460    def pruned_cond():
1461      return build_cond()
1462    pruned_cond = wrap_function.wrap_function(pruned_cond, [])
1463
1464    with self.captureWritesToStream(sys.stderr) as printed:
1465      self.assertEqual(self.evaluate(pruned_cond()), 10)
1466    self.assertEqual(["C"], filter_test_messages(printed.contents()))
1467
1468    def pruned_nested_cond():
1469      return build_nested_cond()
1470    pruned_nested_cond = wrap_function.wrap_function(pruned_nested_cond, [])
1471
1472    with self.captureWritesToStream(sys.stderr) as printed:
1473      self.assertEqual(self.evaluate(pruned_nested_cond()), 10)
1474    self.assertEqual(["C"], filter_test_messages(printed.contents()))
1475
1476
1477  @test_util.run_in_graph_and_eager_modes
1478  @test_util.disable_tfrt("b/179459136")
1479  def testWhileAutoControlDeps(self):
1480    # Legacy while_loop fails this test because it produces deprecation notices
1481    # in stderr.
1482    if not control_flow_util.ENABLE_CONTROL_FLOW_V2: return
1483
1484    def cond(i, unused_x):
1485      enqueue_print_op("A")
1486      return i < 2
1487
1488    def body(i, x):
1489      enqueue_print_op("B")
1490      with ops.control_dependencies([enqueue_print_op("C")]):
1491        x = array_ops.identity(x)
1492      with ops.control_dependencies([enqueue_print_op("D")]):
1493        return i + 1, x
1494
1495    def build_while():
1496      return control_flow_ops.while_loop(
1497          cond, body, [constant_op.constant(0), constant_op.constant(0)])
1498
1499    def build_nested_while():
1500      return control_flow_ops.cond(
1501          constant_op.constant(True), build_while, lambda: [0, 0])
1502
1503    # In v1 graph mode, pruning should make only "D" print.
1504    if not context.executing_eagerly():
1505      with self.cached_session():
1506        with self.captureWritesToStream(sys.stderr) as printed:
1507          self.assertEqual(self.evaluate(build_while()[0]), 2)
1508        self.assertEqual(["D", "D"], filter_test_messages(printed.contents()))
1509
1510        with self.captureWritesToStream(sys.stderr) as printed:
1511          self.assertEqual(self.evaluate(build_nested_while()[0]), 2)
1512        self.assertEqual(["D", "D"], filter_test_messages(printed.contents()))
1513
1514    # In defuns, all prints should execute in program order.
1515    @eager_function.defun
1516    def while_loop():
1517      return build_while()[0]
1518
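    # Each of the two iterations prints "A", "B", "C", "D"; the final cond
    # check that ends the loop prints one more "A".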
1519    with self.captureWritesToStream(sys.stderr) as printed:
1520      self.assertEqual(self.evaluate(while_loop()), 2)
1521    self.assertEqual(["A", "B", "C", "D", "A", "B", "C", "D", "A"],
1522                     filter_test_messages(printed.contents()))
1523
1524    @eager_function.defun
1525    def nested_while_loop():
1526      return build_nested_while()[0]
1527
1528    with self.captureWritesToStream(sys.stderr) as printed:
1529      self.assertEqual(self.evaluate(nested_while_loop()), 2)
1530    self.assertEqual(["A", "B", "C", "D", "A", "B", "C", "D", "A"],
1531                     filter_test_messages(printed.contents()))
1532
1533    # wrap_function should prune.
1534    def pruned_while():
1535      return build_while()[0]
1536    pruned_while = wrap_function.wrap_function(pruned_while, [])
1537
1538    with self.captureWritesToStream(sys.stderr) as printed:
1539      self.assertEqual(self.evaluate(pruned_while()), 2)
1540    self.assertEqual(["D", "D"], filter_test_messages(printed.contents()))
1541
1542    def pruned_nested_while():
1543      return build_nested_while()[0]
1544    pruned_nested_while = wrap_function.wrap_function(pruned_nested_while, [])
1545
1546    with self.captureWritesToStream(sys.stderr) as printed:
1547      self.assertEqual(self.evaluate(pruned_nested_while()), 2)
1548    self.assertEqual(["D", "D"], filter_test_messages(printed.contents()))
1549
1550  # Microbenchmark: 256,000 iterations/s.
1551  def testWhile_1(self):
1552    with self.cached_session():
1553      n = constant_op.constant(0)
1554      c = lambda x: math_ops.less(x, 10000)
1555      b = lambda x: math_ops.add(x, 1)
1556      r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
1557      self.assertEqual(10000, self.evaluate(r))
1558
1559  @test_util.run_v1_only("b/120545219")
1560  def testWhileExternalControlDependencies(self):
1561    with self.cached_session():
1562      v = variables.Variable(0.0)
1563      self.evaluate(v.initializer)
1564      increment = v.assign_add(1.0).read_value()
1565
1566      def body_fn(i):
1567        with ops.control_dependencies([increment]):
1568          return i + 1
1569
1570      result = control_flow_ops.while_loop(cond=lambda i: i < 2,
1571                                           body=body_fn, loop_vars=[1])
1572      self.assertAllEqual(result, 2)
1573      self.assertAllEqual(v.read_value(), 1.0)
1574
1575  @test_util.run_v1_only("b/120545219")
1576  def testWhileExternalControlDependenciesNoInput(self):
1577    with self.cached_session():
1578      v = variables.Variable(0.0)
1579      self.evaluate(v.initializer)
1580      # TODO(apassos): figure out why the reading is necessary here.
1581      increment = v.assign_add(1.0).read_value()
1582
1583      def body_fn(unused_i):
1584        with ops.control_dependencies([increment]):
1585          return constant_op.constant(5, name="five")
1586
1587      result = control_flow_ops.while_loop(cond=lambda i: i < 5,
1588                                           body=body_fn, loop_vars=[0])
1589      self.evaluate(result)
1590      self.assertAllEqual(self.evaluate(v), 1.0)
1591
1592  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
1593  @test_util.run_v1_only("b/120545219")
1594  def testWhileWithRefs_1(self):
1595    with self.cached_session() as sess:
1596      x = variables.VariableV1(0)._ref()  # pylint: disable=protected-access
1597      i = constant_op.constant(0)
1598      c = lambda i, x: math_ops.less(i, 100)
1599
1600      self.assertEqual(x.dtype, dtypes.int32_ref)
1601
1602      def b(i, x):
1603        self.assertEqual(x.dtype, dtypes.int32_ref)
1604        return (i + 1, gen_array_ops.ref_identity(x))
1605
1606      r = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=5)
1607
1608      self.evaluate(variables.global_variables_initializer())
1609
1610      self.assertEqual(r[0].dtype, dtypes.int32)
1611      self.assertEqual(r[1].dtype, dtypes.int32_ref)
1612
1613      value_i, value_x = self.evaluate(r)
1614
1615    self.assertEqual(100, value_i)
1616    self.assertEqual(0, value_x)
1617
1618  def testWhile_2(self):
1619    with self.cached_session():
1620      s = constant_op.constant(0)
1621      r = isum(s)
1622      self.assertAllEqual(45, self.evaluate(r))
1623
1624  def testWhileWithMaximumIterations(self):
1625    with self.cached_session():
1626      s = constant_op.constant([1, 2, 3, 4, 5])
1627      r = isum(s, maximum_iterations=3)
1628      self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], self.evaluate(r))
1629
1630  @test_util.run_v1_only("b/120545219")
1631  def testWhileWithMaximumIterationsAndSingleArgument(self):
1632    with self.cached_session():
1633      r = control_flow_ops.while_loop(
1634          lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
1635      self.assertEqual(1, self.evaluate(r))
1636
1637  @test_util.run_v1_only("b/120545219")
1638  def testXLAGradInLoop(self):
    # We have an optimization that moves certain reduction ops; this test
    # makes sure we don't do that for XLA ops.

    # Use dynamic inputs, which triggers the creation of the
    # "BroadcastGradientArgs" and "Shape" ops.
1644    input1 = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
1645    input2 = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
1646    def cond(i1, i2):
1647      return False
1648
1649    def body(i1, i2):
1650      return math_ops.add(i1, i2), math_ops.add(i1, i2)
1651
1652    xla_context = control_flow_ops.XLAControlFlowContext()
1653    xla_context.Enter()
1654
1655    out1, _ = control_flow_ops.while_loop(
1656        cond, body, (input1, input2), maximum_iterations=2)
1657    g = gradients_impl.gradients(out1, [input1])
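    # Building the gradient adds the ops we inspect below; the gradient value
    # itself is never evaluated.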
1658
1659    for op in out1.graph.get_operations():
1660      # Test that the "Shape" is directly passed to BroadcastGradientArgs
1661      # instead of being pushed to the stack.
1662      if op.type == "BroadcastGradientArgs":
1663        self.assertEqual(op.inputs[0].op.type, "Shape")
1664        self.assertEqual(op.inputs[1].op.type, "Shape")
    xla_context.Exit()

1668  @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
1669  @test_util.run_v1_only("b/120545219")
1670  def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
1671    v = constant_op.constant(1.0)
1672
1673    def training_loop_with_gradient(i):
1674      out = control_flow_ops.while_loop(
1675          lambda i_, _: i_ < 3,
1676          lambda i_, j: [i_ + 1, j * v], [0, 1.0],
1677          maximum_iterations=i)
1678      g = gradients_impl.gradients(out, v)
1679      with ops.control_dependencies(g):
1680        return i + 1
1681
1682    xla_context = control_flow_ops.XLAControlFlowContext()
1683    xla_context.Enter()
    # Create the training loop and ensure we can compute gradients of a
    # while_loop inside the training loop.
1686    loop = control_flow_ops.while_loop(lambda i: i < 3,
1687                                       training_loop_with_gradient, [0])
1688    xla_context.Exit()
1689
1690    loop_execute = array_ops.identity(loop)  # Because loop is not fetchable.
1691
1692    # Should execute without issue.
1693    self.assertEqual(3, self.evaluate(loop_execute))
1694
1695  @test_util.run_v1_only("b/120545219")
1696  def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
1697    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
1698      self.skipTest("WhileV2 does lazy evaluation of maximum_iterations")
1699    v = constant_op.constant(1.0)
1700
1701    def inner_body(i, x):
1702      out = control_flow_ops.while_loop(
1703          lambda i, _: i < 3,
1704          lambda i, j: [i + 1, j * v], [0, x],
1705          maximum_iterations=i)
1706      return out
1707
1708    def create_while_loop(maximum_iterations=None):
1709      return control_flow_ops.while_loop(
1710          lambda i, _: i < 3,
1711          inner_body, [0, 1.0],
1712          maximum_iterations=maximum_iterations)
1713
1714    loop_no_xla = create_while_loop(maximum_iterations=5)
1715    # maximum_iterations is fine outside of an XLA scope
1716    gs = gradients_impl.gradients(loop_no_xla, v)
1717    self.evaluate(gs)  # This should execute without error.
1718
1719    xla_context = control_flow_ops.XLAControlFlowContext()
1720    xla_context.Enter()
1721    loop_no_maxiter = create_while_loop()
1722    loop_with_maxiter = create_while_loop(maximum_iterations=2)
1723    xla_context.Exit()
1724
1725    with self.assertRaisesRegex(
1726        ValueError,
1727        r"Cannot create a gradient accumulator for tensor '.+' inside "
1728        r"XLA while_loop because maximum_iterations was not passed to "
1729        r"the tf.while_loop call \('.+'\)."):
1730      _ = gradients_impl.gradients(loop_no_maxiter, v)
1731
1732    with self.assertRaisesRegex(
1733        ValueError,
1734        r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
1735        r"while_loop. maximum_iterations tensor '.+' for while_loop context "
1736        r"'.+' must be statically known \(e.g. a constant value or known "
1737        r"shape dimension\), or be defined at or outside the while loop "
1738        r"context '.*' \(currently defined in '.*'\)"):
1739      _ = gradients_impl.gradients(loop_with_maxiter, v)
1740
1741  @test_util.run_v1_only("b/120545219")
1742  def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
1743    v = constant_op.constant(1.0)
1744
1745    def create_while_loop():
1746      max_iter_holder = []
1747
1748      def create_mi():
1749        max_iter_holder.append(array_ops.placeholder(dtypes.int32, shape=()))
1750        return 1.0
1751
1752      _ = control_flow_ops.cond(
1753          constant_op.constant(True), create_mi, create_mi)
1754
1755      return control_flow_ops.while_loop(
1756          lambda i, _: i < 3,
1757          lambda i, x: (i + 1, v * x), (0, 1.0),
1758          maximum_iterations=max_iter_holder[0])
1759
1760    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
1761      xla_context = control_flow_ops.XLAControlFlowContext()
1762      xla_context.Enter()
1763      with self.assertRaisesRegex(ValueError, r"must be from the same graph.*"):
1764        loop = create_while_loop()
1765      xla_context.Exit()
1766    else:
1767      xla_context = control_flow_ops.XLAControlFlowContext()
1768      xla_context.Enter()
1769      loop = create_while_loop()
1770      xla_context.Exit()
1771      with self.assertRaisesRegex(
1772          ValueError,
1773          r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
1774          r"while_loop. maximum_iterations tensor '.*Placeholder:0' for "
1775          r"while_loop context '.+' must be statically known \(e.g. a constant "
1776          r"value or known shape dimension\), or be defined at or outside the "
1777          r"while loop context '' \(currently defined in 'cond/.+'\)"):
1778        _ = gradients_impl.gradients(loop, v)
1779
1780  @test_util.run_v1_only("b/120545219")
1781  def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self):
1782    if test_util.is_gpu_available():
1783      self.skipTest("b/128646372, b/128645947 fails in opensource build")
1784
1785    v = constant_op.constant(1.0)
1786
1787    p = array_ops.placeholder(dtype=dtypes.int32)
1788
1789    def mid_body_builder(iterations):
1790
1791      def mid_body(i, x):
1792        r = control_flow_ops.while_loop(
1793            lambda *_: True,
1794            lambda i, x: (i + 1, v * x), (0, x),
1795            maximum_iterations=iterations,
1796            name="inner")
1797        return (i + 1, gradients_impl.gradients(x + r[1], v)[0])
1798
1799      return mid_body
1800
1801    def outer_body(i, x):
1802      iterations = array_ops.size(p, name="iterations")
1803      return (i + 1, x + control_flow_ops.while_loop(
1804          lambda *_: True,
1805          mid_body_builder(iterations), (0, x),
1806          maximum_iterations=iterations,
1807          name="mid")[1])
1808
1809    def create_while_loop():
1810      with ops.device("/cpu:0"):
1811        r = control_flow_ops.while_loop(
1812            lambda *_: True,
1813            outer_body, (0, 1.0),
1814            maximum_iterations=5,
1815            name="outer")
1816        return array_ops.identity(r[1])
1817
1818    xla_context = control_flow_ops.XLAControlFlowContext()
1819    xla_context.Enter()
1820    final_with_xla_context = create_while_loop()
1821    xla_context.Exit()
1822
1823    final_without_xla_context = create_while_loop()
1824
1825    with self.session(use_gpu=False) as sess:
1826      opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
1827      run_metadata_without_xla_context = config_pb2.RunMetadata()
1828      run_metadata = config_pb2.RunMetadata()
1829
1830      final_value_without_xla_context = sess.run(
1831          final_without_xla_context,
1832          feed_dict={p: [0, 0, 0]},
1833          options=opts,
1834          run_metadata=run_metadata_without_xla_context)
1835
1836      final_value_with_xla_context = sess.run(
1837          final_with_xla_context,
1838          feed_dict={p: [0, 0, 0]},
1839          options=opts,
1840          run_metadata=run_metadata)
1841
1842      if control_flow_util.ENABLE_CONTROL_FLOW_V2:
        # With while_v2 on XLA, run_metadata only contains the unlowered While
        # op, so node_stats has no statistics for the pushes. As a loose
        # check, we count the pushes in the lowered (non-XLA) version instead.
1846        for dev in run_metadata_without_xla_context.step_stats.dev_stats:
1847          if "/device:CPU" in dev.device:
1848            node_stats = dev.node_stats
1849        stack_push_count = len([
1850            x for x in node_stats
1851            if re.match(r".*TensorListPushBack_?\d*", x.node_name)
1852        ])
1853      else:
1854        for dev in run_metadata.step_stats.dev_stats:
1855          if "/device:CPU" in dev.device:
1856            node_stats = dev.node_stats
        stack_push_op = "StackPushV2"
        stack_push_count = len(
            [x for x in node_stats if x.node_name.endswith(stack_push_op)])
      # Pushes to the stack = product of maximum_iterations values;
      # the last two "3"s come from size(p), where p == [0, 0, 0].
1862      self.assertEqual(stack_push_count, 5 * 3 * 3, str(node_stats))
1863
1864      self.assertAllClose(final_value_with_xla_context,
1865                          final_value_without_xla_context)
1866
1867  # Have more than 10 parallel iterations and hence exercise k-bound
1868  # most of the time.
1869  @test_util.run_deprecated_v1
1870  def testWhile_3(self):
1871    with self.cached_session():
1872
1873      def compute(i, m, c, o):
1874        m, c = [math_ops.add(m, 1), math_ops.add(c, 1)]
1875        o = math_ops.add(o, m)
1876        o = math_ops.add(o, c)
1877        i = math_ops.add(i, 1)
1878        return [i, m, c, o]
1879
1880      i = ops.convert_to_tensor(0)
1881      m = ops.convert_to_tensor(0)
1882      c = ops.convert_to_tensor(0)
1883      o = ops.convert_to_tensor(0)
1884      d = ops.convert_to_tensor(100)
1885      r = control_flow_ops.while_loop(lambda i, m, c, o: math_ops.less(i, d),
1886                                      compute, [i, m, c, o])
1887      result = r[3]
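    # Each of the 100 iterations adds m + c == 2 * (i + 1) to o, so the final
    # value is 2 * (1 + 2 + ... + 100) == 10100.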
1888    self.assertAllEqual(10100, result)
1889
1890  @test_util.run_deprecated_v1
1891  def testWhile_4(self):
1892    with self.cached_session():
1893
1894      def compute(i, m, c, o):
1895        m, c = [array_ops.gather(x, i), array_ops.gather(x, i)]
1896        o = math_ops.add(o, m)
1897        o = math_ops.add(o, c)
1898        i = math_ops.add(i, 1)
1899        return [i, m, c, o]
1900
1901      i = ops.convert_to_tensor(0)
1902      m = ops.convert_to_tensor(0)
1903      c = ops.convert_to_tensor(0)
1904      o = ops.convert_to_tensor(0)
1905      x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
1906      s = array_ops.size(x)
1907      r = control_flow_ops.while_loop(lambda i, m, c, o: math_ops.less(i, s),
1908                                      compute, [i, m, c, o])
1909      result = r[3]
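    # o accumulates 2 * x[i] over all six elements:
    # 2 * (1 + 2 + ... + 6) == 42.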
1910    self.assertAllEqual(42, result)
1911
1912  @test_util.run_v1_only("b/120545219")
1913  def testWhile_5(self):
1914    with self.cached_session():
1915
1916      def compute(i, c, o):
1917        c = array_ops.strided_slice(x, array_ops.expand_dims(i, 0),
1918                                    [1] + array_ops.expand_dims(i, 0))
1919        o = array_ops.concat([o, c], 0)
1920        i = math_ops.add(i, 1)
1921        return [i, c, o]
1922
1923      i = ops.convert_to_tensor(0)
1924      c = ops.convert_to_tensor([0])
1925      o = ops.convert_to_tensor([0])
1926      x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
1927      s = array_ops.size(x)
1928      r = control_flow_ops.while_loop(lambda i, c, o: math_ops.less(i, s),
1929                                      compute, [i, c, o], [
1930                                          i.get_shape(),
1931                                          tensor_shape.unknown_shape(),
1932                                          tensor_shape.unknown_shape()
1933                                      ])
1934      result = r[2]
1935    self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result)
1936
1937  @test_util.run_gpu_only
1938  @test_util.run_deprecated_v1
1939  def testWhile_Device(self):
1940
1941    # Body function defined outside of device scope
1942    def body(x):
1943      return math_ops.exp(x)
1944
1945    with ops.device("CPU:0"):
1946      r = control_flow_ops.while_loop(
1947          lambda x: x < 10, body, [constant_op.constant(-10.)])
1948      self.assertIn("cpu", r.device.lower())
1949
1950    with session.Session() as sess:
1951      options = config_pb2.RunOptions(output_partition_graphs=True)
1952      run_metadata = config_pb2.RunMetadata()
1953      sess.run(r, options=options, run_metadata=run_metadata)
1954      # We expect that everything runs on CPU, even if GPU is available.
1955      self.assertEqual(len(run_metadata.partition_graphs), 1)
1956
1957  @test_util.disable_control_flow_v2("b/116338794 (buffer_reuse)")
1958  @test_util.run_v1_only("b/120545219")
1959  def testBufferForwarding(self):
1960    run_options = config_pb2.RunOptions(
1961        trace_level=config_pb2.RunOptions.FULL_TRACE)
1962    run_metadata = config_pb2.RunMetadata()
1963
1964    with self.cached_session() as sess:
1965      with ops.device("/cpu:0"):
1966        c = constant_op.constant(2)
1967        i0 = constant_op.constant(0)
1968        r = control_flow_ops.while_loop(lambda i: i < 1000,
1969                                        lambda i: math_ops.square(c) + i, [i0])
1970      r_val = sess.run(r, options=run_options, run_metadata=run_metadata)
1971      self.assertEqual(1000, r_val)
1972      self.assertTrue(run_metadata.HasField("step_stats"))
1973      unique_allocs = set()
1974      for node_stat in run_metadata.step_stats.dev_stats[0].node_stats:
1975        for output in node_stat.output:
1976          unique_allocs.add(
1977              output.tensor_description.allocation_description.ptr)
1978      # Prior to cl/147536680, the number of unique allocations was about 1005.
1979      self.assertLess(len(unique_allocs), 756)
1980
1981  def _testWhile_Gpu_1(self, use_gpu):
1982    with self.cached_session(use_gpu=use_gpu):
1983      n = constant_op.constant(1.0)
1984      c = lambda x: math_ops.less(x, 10.0)
1985      b = lambda x: math_ops.add(x, 1.0)
1986      r = control_flow_ops.while_loop(c, b, [n])
1987      self.assertAllClose(10.0, self.evaluate(r))
1988
1989  def testWhile_Gpu_1(self):
1990    self._testWhile_Gpu_1(use_gpu=False)
1991    self._testWhile_Gpu_1(use_gpu=True)
1992
1993  def _testWhile_Gpu_2(self, use_gpu):
1994    with self.cached_session(use_gpu=use_gpu):
1995      n = constant_op.constant(1.0)
1996      c = lambda x: math_ops.less(x, 10.0)
1997
1998      def b(x):
1999        with ops.device("/cpu:0"):
2000          return math_ops.add(x, 1.0)
2001
2002      r = control_flow_ops.while_loop(c, b, [n])
2003      self.assertAllClose(10.0, self.evaluate(r))
2004
2005  def testWhile_Gpu_2(self):
2006    self._testWhile_Gpu_2(use_gpu=False)
2007    self._testWhile_Gpu_2(use_gpu=True)
2008
2009  def testWhileShape(self):
2010    with self.cached_session():
2011      i = constant_op.constant(0)
2012      m = array_ops.ones([2, 2])
2013      c = lambda i, j: math_ops.less(i, 2)
2014
2015      def _b(i, j):
2016        new_i = math_ops.add(i, 1)
2017        new_j = array_ops.tile(j, [2, 2])
2018        return [new_i, new_j]
2019
2020      r = control_flow_ops.while_loop(
2021          c, _b, [i, m],
2022          [i.get_shape(), tensor_shape.unknown_shape()])
2023      r = r[1] * array_ops.ones([8, 8])
2024      self.assertAllEqual(np.ones((8, 8)), self.evaluate(r))
2025
2026  @test_util.disable_control_flow_v2("b/131265085")
2027  @test_util.run_v1_only("b/131265085")
2028  def testWhileBadShape(self):
2029    x = constant_op.constant([2.0, 4.0], name="values")
2030    i = constant_op.constant(0)
2031    c = lambda i, _: math_ops.less(i, 10)
2032    b = lambda i, x: [i + 1, x + 1]
2033    with self.assertRaisesRegex(ValueError, "is not compatible with"):
2034      # Shape of x is [2], but we specify a shape of [5].
2035      control_flow_ops.while_loop(
2036          c, b, [i, x], [i.shape, tensor_shape.TensorShape([5])])
2037
2038  @test_util.run_in_graph_and_eager_modes
2039  def testWhileBadBodyReturn(self):
2040    x = constant_op.constant([2.0, 4.0], name="values")
2041    i = constant_op.constant(0)
2042    c = lambda i, *x: math_ops.less(i, 10)
2043
2044    # body accepts N values and returns N+1 values.
2045    b = lambda i, *x: (i, i) + x
2046
2047    with self.assertRaisesRegex(
2048        ValueError, "The two structures don't have the same nested structure."):
2049      control_flow_ops.while_loop(c, b, [i, x])
2050
2051  @test_util.run_deprecated_v1
2052  def testWhileWithNonTensorInput_Scalar(self):
2053    with self.cached_session():
2054      n = 0
2055      c = lambda x: x < 10000
2056      b = lambda x: x + 1
2057      r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
2058      self.assertEqual(10000, self.evaluate(r))
2059
2060  def testWhileWithNonTensorInput_Vector(self):
2061    with self.cached_session():
      n = np.array([0])  # Note: a plain Python list [0] would not work here.
2063      c = lambda x: x[0] < 10000
2064      b = lambda x: array_ops.stack([x[0] + 1])
2065      r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
2066      self.assertEqual([10000], self.evaluate(r))
2067
2068  def testWhileShapeInference(self):
2069    with self.cached_session():
2070      i = constant_op.constant(0)
2071      m = array_ops.ones([2, 2])
2072      c = lambda i, j: math_ops.less(i, 2)
2073
2074      def b(i, j):
2075        new_i = math_ops.add(i, 1)
2076        new_j = array_ops.concat([j, j], 0)
2077        return [new_i, new_j]
2078
2079      r = control_flow_ops.while_loop(
2080          c, b, [i, m],
2081          [i.get_shape(), tensor_shape.TensorShape([None, 2])])
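      # Two iterations double the row count 2 -> 4 -> 8, while the static
      # shape under the [None, 2] invariant stays (None, 2).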
2082      self.assertTrue(r[1].shape.is_compatible_with([8, 2]))
2083
2084  @test_util.run_v1_only("b/120545219")
2085  def testWhileShapeInferenceBadShape(self):
2086    with self.cached_session():
2087      i = constant_op.constant(0)
2088      m = array_ops.ones([2, 2])
2089      c = lambda i, j: math_ops.less(i, 2)
2090      b = lambda i, j: [i + 1, array_ops.concat([j, j], 0)]
2091      with self.assertRaisesRegex(
2092          ValueError,
2093          r".*\(2, 2\).*\(4, 2\) after one iteration\. To allow the shape to "
2094          r"vary across iterations, use the `shape_invariants` argument of "
2095          r"tf.while_loop to specify a less-specific shape\."):
2096        control_flow_ops.while_loop(c, b, [i, m])
2097
2098  def testWhileShapeInferenceSparseTensor(self):
2099    values = constant_op.constant([2.0, 4.0], name="values")
2100    indices = constant_op.constant([[0], [3]],
2101                                   dtype=dtypes.int64,
2102                                   name="indices")
2103    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
2104    i = constant_op.constant(0)
2105    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
2106
2107    def c(i, _):
2108      return i < 10
2109
2110    def b1(i, x):  # modifies values.  (shape of components is not changed.)
2111      return [
2112          i + 1,
2113          sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape)
2114      ]
2115
2116    def b2(i, x):  # adds new values.  (shape of components is changed.)
2117      return [
2118          i + 1,
2119          sparse_ops.sparse_add(
2120              x,
2121              sparse_tensor.SparseTensor(
2122                  indices=math_ops.cast(
2123                      array_ops.fill([1, 1], i), dtypes.int64),
2124                  values=array_ops.fill([1], 1.0),
2125                  dense_shape=x.dense_shape))
2126      ]
2127
2128    def b3(i, x):  # modifies rank.  (shape of all components is changed.)
2129      return [
2130          i + 1,
2131          sparse_tensor.SparseTensor(
2132              array_ops.concat([x.indices, [[i], [i]]], axis=1), x.values * 2.0,
2133              array_ops.concat([x.dense_shape, [10]], axis=0))
2134      ]
2135
2136    def check_shapes(r, indices, values, dense_shape):
2137      self.assertTrue(r.indices.shape.is_compatible_with(indices))
2138      self.assertTrue(r.values.shape.is_compatible_with(values))
2139      self.assertTrue(r.dense_shape.shape.is_compatible_with(dense_shape))
2140
2141    # Default shape invariant; b1 only modifies values.
2142    _, r = control_flow_ops.while_loop(c, b1, [i, x])
2143    check_shapes(r, indices=[None, 1], values=[None], dense_shape=[1])
2144
2145    # Default shape invariant; b2 adds new values
2146    _, r = control_flow_ops.while_loop(c, b2, [i, x])
2147    check_shapes(r, indices=[None, 1], values=[None], dense_shape=[1])
2148
2149    # Explicit shape invariant, allowing any rank; b1 only modifies values.
2150    _, r = control_flow_ops.while_loop(
2151        c, b1, [i, x],
2152        [i.get_shape(), tensor_shape.TensorShape([None])])
2153    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])
2154
2155    # Explicit shape invariant, allowing any rank; b3 modifies rank.
2156    _, r = control_flow_ops.while_loop(
2157        c, b3, [i, x],
2158        [i.get_shape(), tensor_shape.TensorShape([None])])
2159    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])
2160
2161    # Shape invariant with ndims=None.  Technically, this isn't supported
2162    # according to the docs, but we support it for backwards compatibility.
2163    _, r = control_flow_ops.while_loop(
2164        c, b1, [i, x],
2165        [i.get_shape(), tensor_shape.TensorShape(None)])
2166    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])
2167    _, r = control_flow_ops.while_loop(
2168        c, b3, [i, x],
2169        [i.get_shape(), tensor_shape.TensorShape(None)])
2170    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])
2171
2172  @test_util.disable_control_flow_v2("b/131265085")
2173  @test_util.run_v1_only("b/131265085")
2174  def testWhileBadShapeSparseTensor(self):
2175    values = constant_op.constant([2.0, 4.0], name="values")
2176    indices = constant_op.constant([[0], [3]],
2177                                   dtype=dtypes.int64,
2178                                   name="indices")
2179    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
2180    i = constant_op.constant(0)
2181    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
2182    c = lambda i, _: i < 10
2183    b1 = lambda i, x: [i+1, x]
2184    def b2(i, x):  # modifies rank.  (shape of all components is changed.)
2185      return [
2186          i + 1,
2187          sparse_tensor.SparseTensor(
2188              array_ops.concat([x.indices, [[i], [i]]], axis=1), x.values * 2.0,
2189              array_ops.concat([x.dense_shape, [10]], axis=0))
2190      ]
2191
2192    # Explicit shape invariant, with a specific (incompatible) rank.
2193    with self.assertRaisesRegex(ValueError, "is not compatible with"):
2194      control_flow_ops.while_loop(
2195          c, b1, [i, x],
2196          [i.get_shape(), tensor_shape.TensorShape([5])])
2197
2198    # Default shape invariant, but b2 modifies rank (which is not allowed).
2199    with self.assertRaises(ValueError):
2200      control_flow_ops.while_loop(c, b2, [i, x])
2201
2202  def testWhileShapeInferenceIndexedSlices(self):
2203    with self.cached_session():
2204      values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
2205      indices = constant_op.constant([0, 3], name="indices")
2206      shape = constant_op.constant([10, 2], name="dense_shape")
2207      i = constant_op.constant(0)
2208      x = indexed_slices.IndexedSlices(values, indices, dense_shape=shape)
2209
2210      def c(i, _):
2211        return i < 10
2212
2213      def b(i, x):
2214        return [
2215            i + 1,
2216            indexed_slices.IndexedSlices(x.values * 2.0, x.indices,
2217                                         x.dense_shape)
2218        ]
2219
2220      _, r = control_flow_ops.while_loop(c, b, [i, x])
2221      self.assertEqual(r.dense_shape.get_shape()[0], 2)
2222      self.assertEqual(r.values.get_shape(), tensor_shape.TensorShape([2, 2]))
2223
2224      _, r = control_flow_ops.while_loop(
2225          c, b, [i, x],
2226          [i.get_shape(), tensor_shape.TensorShape([None, 2])])
2227      self.assertEqual(r.dense_shape.get_shape()[0], 2)
2228      self.assertTrue(r.values.get_shape().is_compatible_with([None, 2]))
2229
2230  @test_util.disable_control_flow_v2("b/131265085")
2231  @test_util.run_v1_only("b/131265085")
2232  def testWhileBadShapeIndexedSlices(self):
2233    values = constant_op.constant([2.0, 4.0], name="values")
2234    indices = constant_op.constant([[0], [3]],
2235                                   dtype=dtypes.int64,
2236                                   name="indices")
2237    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
2238    i = constant_op.constant(0)
2239    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
    c = lambda i, _: i < 10
2241    b = lambda i, x: [i+1, x]
2242
2243    # Explicit shape invariant, with a specific (incompatible) rank.
2244    with self.assertRaisesRegex(ValueError, "is not compatible with"):
2245      control_flow_ops.while_loop(
2246          c, b, [i, x],
2247          [i.get_shape(), tensor_shape.TensorShape([5])])
2248
2249  def testWhileShapeInferenceRaggedTensor(self):
2250    i = constant_op.constant(0)
2251    x = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
2252    c = lambda i, _: i < 10
2253
2254    def b1(i, x):  # Adds new values to rows (but doesn't create new rows)
2255      return [
2256          i + 1,
2257          array_ops.concat([x, x], axis=1)
2258      ]
2259
2260    def b2(i, x):  # Adds new rows.
2261      return [
2262          i + 1,
2263          array_ops.concat([x, x], axis=0)
2264      ]
2265
2266    def check_shapes(r, values, splits):
2267      self.assertTrue(r.values.shape.is_compatible_with(values))
2268      self.assertTrue(r.row_splits.shape.is_compatible_with(splits))
2269
2270    # Default shape invariant; b1 adds new values to rows.
2271    _, r = control_flow_ops.while_loop(c, b1, [i, x])
2272    check_shapes(r, values=[None], splits=[4])
2273
2274    # Default shape invariant; b2 adds new rows (not allowed).
2275    if not context.executing_eagerly():
2276      with self.assertRaises(ValueError):
2277        _, r = control_flow_ops.while_loop(c, b2, [i, x])
2278
2279    # Explicit shape invariant; b1 adds new values to rows.
    # (deprecated form: passes a TensorShape instead of a RaggedTensorSpec)
2281    _, r = control_flow_ops.while_loop(
2282        c, b1, [i, x],
2283        [i.get_shape(), tensor_shape.TensorShape([None, None])])
2284    check_shapes(r, values=[None], splits=[None])
2285
2286    # Explicit shape invariant; b1 adds new values to rows.
2287    _, r = control_flow_ops.while_loop(
2288        c, b1, [i, x],
2289        [i.get_shape(), ragged_tensor.RaggedTensorSpec([None, None],
2290                                                       dtypes.int32)])
2291    check_shapes(r, values=[None], splits=[None])
2292
2293    # Explicit shape invariant; b2 adds new rows.
2294    _, r = control_flow_ops.while_loop(
2295        c, b2, [i, x],
2296        [i.get_shape(), ragged_tensor.RaggedTensorSpec([None, None],
2297                                                       dtypes.int32)])
2298    check_shapes(r, values=[None], splits=[None])
2299
2300  def testWhileShapeInferenceRaggedTensorRaggedRank2(self):
2301    i = constant_op.constant(0)
2302    x = ragged_factory_ops.constant([[[1, 2], [3], [4, 5, 6]],
2303                                     [[], [8, 9, 10]]])
2304    c = lambda i, _: i < 10
2305    def b(i, x):
2306      return [
2307          i + 1,
2308          array_ops.concat([x, x[..., i:i+1]], axis=-1)
2309      ]
2310    _, r = control_flow_ops.while_loop(c, b, [i, x])
2311    self.assertEqual(r.row_splits.shape.as_list(), [3])
2312    self.assertIn(r.values.row_splits.shape.as_list(), ([6], [None]))
2313    self.assertIn(r.values.values.shape.as_list(), ([49], [None]))
2314
2315  def testWhileShapeInvariantTensorSpec(self):
2316    i = constant_op.constant(0)
2317    x = constant_op.constant([1])
2318    c = lambda i, _: i < 10
2319    b = lambda i, x: (i + 1, array_ops.stack([x, x]))
2320    shape_invariants = [
2321        tensor_spec.TensorSpec([], dtype=dtypes.int32),
2322        tensor_spec.TensorSpec(None, dtype=dtypes.int32)]
2323    control_flow_ops.while_loop(c, b, [i, x], shape_invariants)
2324
2325  # TODO(b/131265085) Remove this decorator when bug is fixed.
2326  @test_util.build_as_function_and_v1_graph
2327  def testWhileShapeInvariantWrongTypeSpecType(self):
2328    c = lambda i, _: i < 10
2329    b = lambda i, x: (i + 1, x)
2330    i = constant_op.constant(0)
2331    x = sparse_tensor.SparseTensor([[0]], [1.0], [10])
2332    shape_invariants = [
2333        tensor_spec.TensorSpec([], dtype=dtypes.int32),
2334        sparse_tensor.SparseTensorSpec([None])]
2335    control_flow_ops.while_loop(c, b, [i, x], shape_invariants)
2336
2337    x2 = constant_op.constant([1])
2338    with self.assertRaises(TypeError):
2339      control_flow_ops.while_loop(c, b, [i, x2], shape_invariants)
2340
2341    x3 = ragged_factory_ops.constant([[1, 2], [3]])
2342    with self.assertRaises(TypeError):
2343      control_flow_ops.while_loop(c, b, [i, x3], shape_invariants)
2344
2345    i2 = constant_op.constant(0.0)
2346    with self.assertRaises(TypeError):
2347      control_flow_ops.while_loop(c, b, [i2, x], shape_invariants)
2348
2349  # TODO(b/131265085) Remove this decorator when bug is fixed.
2350  @test_util.build_as_function_and_v1_graph
2351  def testWhileShapeInvariantBadType(self):
2352    i = constant_op.constant(0)
2353    x = constant_op.constant([1])
2354    c = lambda i, _: i < 10
2355    b = lambda i, x: (i + 1, x)
2356    with self.assertRaises((ValueError, TypeError)):
2357      control_flow_ops.while_loop(c, b, [i, x], ["foo", "bar"])
2358
2359  def _testNestedWhile_1(self, use_gpu):
2360    with self.cached_session(use_gpu=use_gpu):
2361      n = constant_op.constant(0)
2362
2363      def cpu_sum(s):
2364        c = lambda i, s: math_ops.less(i, 10)
2365
2366        def b(i, s):
2367          i1 = math_ops.add(i, 1)
2368          with ops.device("/cpu:0"):
2369            s1 = math_ops.add(i, s)
2370          return i1, s1
2371
2372        _, r_s = control_flow_ops.while_loop(c, b, [n, s])
2373        return r_s
2374
2375      c = lambda x: math_ops.less(x, 200)
2376      b = lambda x: math_ops.add(x, cpu_sum(n))
2377      r = control_flow_ops.while_loop(c, b, [n])
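      # cpu_sum(0) == 0 + 1 + ... + 9 == 45, so the outer loop adds 45 per
      # step (0, 45, ..., 225) and exits at the first value >= 200.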
2378      self.assertEqual(225, self.evaluate(r))
2379
2380  def testNestedWhile_1(self):
2381    self._testNestedWhile_1(use_gpu=False)
2382    self._testNestedWhile_1(use_gpu=True)
2383
2384  def _testNestedWhile_2(self, use_gpu):
    # Test the cases where the edges A -> Enter and Exit -> A are partitioned.
2386    with self.cached_session(use_gpu=use_gpu):
2387      s0 = constant_op.constant(2.0)
2388
2389      def inner_loop(s):
2390        c = lambda s: math_ops.less(s, 20.0)
2391
2392        def b(s):
2393          s1 = math_ops.add(s, s)
2394          return s1
2395
2396        r_s = control_flow_ops.while_loop(c, b, [s], parallel_iterations=1)
2397        return r_s
2398
2399      outer_c = lambda x: math_ops.less(x, 3000.0)
2400
2401      def outer_b(x):
2402        x = logging_ops.Print(x, [x])  # Edge "Print -> Enter" is partitioned
2403        x = inner_loop(x)
2404        with ops.device("/cpu:0"):
2405          x = math_ops.square(x)  # Edge "Exit -> Square" is partitioned
2406        return x
2407
2408      r = control_flow_ops.while_loop(
2409          outer_c, outer_b, [s0], parallel_iterations=1)
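      # First outer step: 2 doubles to 32, then squares to 1024.  Second
      # step: 1024 is already >= 20, so it only squares, giving 1048576.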
2410      self.assertEqual(1048576.0, self.evaluate(r))
2411
2412  def testNestedWhile_2(self):
2413    self._testNestedWhile_2(use_gpu=False)
2414    self._testNestedWhile_2(use_gpu=True)
2415
2416  @test_util.run_v1_only("b/120545219")
2417  def testWhileWithControl_1(self):
2418    with self.cached_session():
2419      n = constant_op.constant(0)
2420      r = constant_op.constant(0)
2421      condition = lambda n_, r_: math_ops.less(n_, 10)
2422
2423      def body(n_, r_):
2424        n_ = math_ops.add(n_, 1)
2425        with r_.graph.control_dependencies([r_]):
2426          r_ = constant_op.constant(12)
2427        return [n_, r_]
2428
2429      res = control_flow_ops.while_loop(
2430          condition, body, [n, r], parallel_iterations=1)
2431      self.assertAllEqual(12, res[1])
2432
2433  @test_util.run_deprecated_v1
2434  def testWhileWithControl_2(self):
2435    with self.cached_session():
2436      r = constant_op.constant(0)
2437      condition = lambda r_: math_ops.less(r_, 10)
2438
2439      def body(r_):
2440        with r_.graph.control_dependencies([r_]):
2441          r_ = constant_op.constant(12)
2442        return [r_]
2443
2444      res = control_flow_ops.while_loop(
2445          condition, body, [r], parallel_iterations=1)
2446      self.assertAllEqual(12, self.evaluate(res))
2447
2448  @test_util.run_v1_only("b/120545219")
2449  def testWhileWithControl_3(self):
2450    with self.cached_session() as sess:
2451      b = array_ops.placeholder(dtypes.bool)
2452      c = constant_op.constant(1)
2453      x0 = constant_op.constant(0)
2454      with ops.control_dependencies([b]):
2455        r = control_flow_ops.while_loop(lambda x: x < 10, lambda x: x + c, [x0])
2456      self.assertEqual(10, sess.run(r, {b: True}))
2457
2458  @test_util.run_v1_only("b/120545219")
2459  def testWhileWithControl_4(self):
2460    with self.cached_session() as sess:
2461      b = array_ops.placeholder(dtypes.bool)
2462      c = constant_op.constant(1)
2463      x0 = constant_op.constant(0)
2464      with ops.control_dependencies([b]):
2465        r = control_flow_ops.while_loop(
2466            lambda x: x < 10, lambda x: x + array_ops.identity(c), [x0])
2467      self.assertEqual(10, sess.run(r, {b: True}))
2468
2469  @test_util.run_v1_only("b/120545219")
2470  def testWhileWithControl_5(self):
2471    with self.cached_session() as sess:
2472      b = array_ops.placeholder(dtypes.bool)
2473      c = constant_op.constant(1)
2474      x0 = constant_op.constant(0)
2475
2476      def body(x):
2477        with ops.control_dependencies([b]):
2478          return x + c
2479
2480      r = control_flow_ops.while_loop(lambda x: x < 10, body, [x0])
2481      self.assertEqual(10, sess.run(r, {b: True}))
2482
2483  def testWhileCondWithControl(self):
    # Ensure that no control edges from an outer control dependency context
    # are added to nodes inside cond/while contexts.
2486    with self.cached_session() as sess:
2487      const_true = lambda: constant_op.constant(True)
2488      const_false = lambda: constant_op.constant(False)
2489      cond = lambda i: control_flow_ops.cond(i > 0, const_true, const_false)
2490      body = lambda i: control_flow_ops.cond(i > 0, lambda: i - 1, lambda: i)
2491
2492      with ops.control_dependencies([control_flow_ops.no_op()]):
2493        loop = control_flow_ops.while_loop(cond, body,
2494                                           (constant_op.constant(5),))
2495      self.assertEqual(0, self.evaluate(loop))
2496
2497  @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
2498  @test_util.run_v1_only("b/120545219")
2499  def testWhileCondWithControl_1(self):
2500    with self.cached_session():
2501      v = variable_scope.get_variable(
2502          "v", [], initializer=init_ops.constant_initializer(2))
2503      i0 = constant_op.constant(0)
2504      with ops.control_dependencies([i0]):
2505
2506        def loop_condition(i):
2507          return i < 4
2508
2509        def loop_body(i):
2510          some_cond = control_flow_ops.cond(
2511              constant_op.constant(True),
2512              lambda: state_ops.assign(v, math_ops.square(v)), lambda: v)
2513          with ops.control_dependencies([some_cond]):
2514            return i + 1
2515
2516      r = control_flow_ops.while_loop(loop_condition, loop_body, (i0,))
2517      self.evaluate(variables.global_variables_initializer())
2518      self.assertEqual(4, self.evaluate(r))
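      # v is squared once per iteration: 2 -> 4 -> 16 -> 256 -> 65536.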
2519      self.assertAllClose(65536.0, self.evaluate(v))
2520
2521  @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
2522  @test_util.run_v1_only("b/120545219")
2523  def testWhileCondExitControl(self):
2524
2525    with self.cached_session():
2526      v = variables.Variable(1)
2527
2528      def false_branch():
2529        cond = lambda i: i < 100
2530
2531        def body(i):
2532          x = state_ops.assign(v, i)
2533          return x + 1
2534
2535        loop = control_flow_ops.while_loop(cond, body, [0])
        # Make sure the control edge from Exit to a node is handled correctly.
2537        with ops.control_dependencies([loop]):
2538          return constant_op.constant(6.0)
2539
2540      r = control_flow_ops.cond(
2541          constant_op.constant(False), lambda: constant_op.constant(1.0),
2542          false_branch)
2543      self.evaluate(variables.global_variables_initializer())
2544      self.assertEqual(6.0, self.evaluate(r))
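      # The false branch runs the loop, whose last iteration assigns v = 99.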
2545      self.assertEqual(99, self.evaluate(v))
2546
2547  def testCondWhile_1(self):
2548
2549    with self.cached_session():
2550      n = ops.convert_to_tensor(0, name="n")
2551      c = lambda x: math_ops.less(x, 10)
2552      b = lambda x: math_ops.add(x, 1)
2553      r = control_flow_ops.cond(
2554          math_ops.less(0, 1), lambda: control_flow_ops.while_loop(c, b, [n]),
2555          lambda: n)
2556      self.assertAllEqual(10, self.evaluate(r))
2557
2558  def testCondWhile_2(self):
2559
2560    with self.cached_session():
2561      n = ops.convert_to_tensor(0)
2562      c = lambda x: math_ops.less(x, 10)
2563      b = lambda x: math_ops.add(x, 1)
2564      r = control_flow_ops.cond(
2565          math_ops.less(1, 0), lambda: math_ops.add(n, 1),
2566          lambda: control_flow_ops.while_loop(c, b, [n]))
2567      self.assertAllEqual(10, self.evaluate(r))
2568
2569  def _testCondWhile_3(self, use_gpu):
2570    with self.cached_session(use_gpu=use_gpu) as sess:
2571      p = array_ops.placeholder(dtypes.bool)
2572      n = constant_op.constant(0.0)
2573
2574      def c(x):
2575        return math_ops.less(x, 10.0)
2576
2577      def b(x):
2578        with ops.device("/cpu:0"):
2579          x1 = math_ops.add(x, 1.0)
2580        return x1
2581
2582      r = control_flow_ops.cond(p,
2583                                lambda: control_flow_ops.while_loop(c, b, [n]),
2584                                lambda: math_ops.multiply(n, 2.0))
2585      r1 = gradients_impl.gradients(r, [n])
2586      self.assertEqual(10., sess.run(r, {p: True}))
2587      self.assertEqual([1.0], sess.run(r1, {p: True}))
2588      self.assertEqual(0.0, sess.run(r, {p: False}))
2589      self.assertEqual([2.0], sess.run(r1, {p: False}))
2590
2591  @test_util.run_deprecated_v1
2592  def testCondWhile_3(self):
2593    self._testCondWhile_3(use_gpu=False)
2594    self._testCondWhile_3(use_gpu=True)
2595
2596  def testWhileCond_1(self):
2597
2598    with self.cached_session():
2599      i = ops.convert_to_tensor(0, name="i")
2600      n = ops.convert_to_tensor(10, name="n")
2601      one = ops.convert_to_tensor(1, name="one")
2602      c = lambda x: math_ops.less(x, n)
2603      # pylint: disable=undefined-variable
2604      # for OSS build
2605      b = lambda x: control_flow_ops.cond(
2606          constant_op.constant(True),
2607          lambda: math_ops.add(x, one), lambda: math_ops.subtract(x, one))
2608      # pylint: enable=undefined-variable
2609      r = control_flow_ops.while_loop(c, b, [i])
2610      self.assertAllEqual(10, self.evaluate(r))
2611
2612  def testWhileCond_2(self):
2613
2614    with self.cached_session():
2615      n = ops.convert_to_tensor(0, name="n")
2616      c = lambda x: math_ops.less(x, 10)
      b = lambda x: control_flow_ops.cond(
          constant_op.constant(True), lambda: math_ops.add(x, 1), lambda: n)
2618      r = control_flow_ops.while_loop(c, b, [n])
2619      self.assertAllEqual(10, self.evaluate(r))
2620
2621  def testWhileCond_3(self):
2622
2623    with self.cached_session():
2624      n = ops.convert_to_tensor(0)
2625      c = lambda x: math_ops.less(x, 10)
2626      # pylint: disable=undefined-variable
2627      # for OSS build
2628      b = lambda x: control_flow_ops.cond(math_ops.less(0, 1),
2629                                          lambda: math_ops.add(x, 1),
2630                                          lambda: math_ops.subtract(x, 1))
2631      # pylint: enable=undefined-variable
2632      r = control_flow_ops.while_loop(c, b, [n])
2633      self.assertAllEqual(10, self.evaluate(r))
2634
2635  @test_util.run_deprecated_v1
2636  def testWhileCondGradMultiDevice(self):
2637    config = config_pb2.ConfigProto(device_count={"CPU": 2},
2638                                    allow_soft_placement=True)
2639    with self.cached_session(config=config) as sess:
2640      pred = array_ops.placeholder(dtypes.bool, [])
2641      x_init = constant_op.constant(1.0)
2642
2643      with ops.device("/cpu:0"):
2644        z = control_flow_ops.while_loop(
2645            lambda i, _: i < 3,
2646            lambda i, x: (i + 1, control_flow_ops.cond(
2647                pred, lambda: x * 2.0, lambda: 10.0)),
2648            [0, x_init])
2649
2650      with ops.device("/cpu:1"):
2651        grad = gradients_impl.gradients(z, x_init)[0]
2652
2653      with ops.device("/cpu:0"):
2654        grad_grad = gradients_impl.gradients(grad, x_init)[0]
2655
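      # With pred True the loop computes x * 2**3, so the gradient is 8.0;
      # with pred False the cond returns a constant, so the gradient is 0.0.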
2656      self.assertEqual(sess.run(grad, {pred: True}), 8.0)
2657      self.assertEqual(sess.run(grad, {pred: False}), 0.0)
2658
2659      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
2660        return
2661
2662      self.assertEqual(sess.run(grad_grad, {pred: True}), 0.0)
2663      self.assertEqual(sess.run(grad_grad, {pred: False}), 0.0)
2664
2665  # NOTE: It is ok to have parallel_iterations > 1
2666  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
2667  @test_util.run_deprecated_v1
2668  def testWhileUpdateVariable_1(self):
2669    with self.cached_session():
2670      select = variables.Variable([3.0, 4.0, 5.0])
2671      n = constant_op.constant(0)
2672
2673      def loop_iterator(j):
2674        return math_ops.less(j, 3)
2675
2676      def loop_body(j):
2677        ns = state_ops.scatter_update(select, j, 10.0)
2678        nj = math_ops.add(j, 1)
2679        op = control_flow_ops.group(ns)
2680        nj = control_flow_ops.with_dependencies([op], nj)
2681        return [nj]
2682
2683      r = control_flow_ops.while_loop(
2684          loop_iterator, loop_body, [n], parallel_iterations=1)
2685      self.evaluate(variables.global_variables_initializer())
2686      self.assertEqual(3, self.evaluate(r))
2687      result = self.evaluate(select)
2688      self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
2689
2690  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
2691  @test_util.run_v1_only("b/120545219")
2692  def testWhileUpdateVariable_2(self):
2693    with self.cached_session():
2694      select1 = variables.Variable([3.0, 4.0, 5.0])
2695      select2 = variables.Variable([3.0, 4.0, 5.0])
2696      n = constant_op.constant(0)
2697
2698      def loop_iterator(j):
2699        return math_ops.less(j, 3)
2700
2701      def loop_body(j):
2702        ns1 = state_ops.scatter_update(select1, j, 10.0)
2703        ns2 = state_ops.scatter_update(select2, j, 10.0)
2704        nj = math_ops.add(j, 1)
2705        op = control_flow_ops.group(ns1, ns2)
2706        nj = control_flow_ops.with_dependencies([op], nj)
2707        return [nj]
2708
2709      r = control_flow_ops.while_loop(
2710          loop_iterator, loop_body, [n], parallel_iterations=1)
2711      self.evaluate(variables.global_variables_initializer())
2712      self.assertEqual(3, self.evaluate(r))
2713      result1 = self.evaluate(select1)
2714      self.assertAllClose(np.array([10.0, 10.0, 10.0]), result1)
2715      result2 = self.evaluate(select2)
2716      self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2)
2717
2718  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
2719  @test_util.run_v1_only("b/120545219")
2720  def testWhileUpdateVariable_3(self):
2721    with self.cached_session():
2722      select = variables.Variable([3.0, 4.0, 5.0])
2723      n = constant_op.constant(0)
2724
2725      def loop_iterator(j, _):
2726        return math_ops.less(j, 3)
2727
2728      def loop_body(j, _):
2729        ns = state_ops.scatter_update(select, j, 10.0)
2730        nj = math_ops.add(j, 1)
2731        return [nj, ns]
2732
2733      r = control_flow_ops.while_loop(
2734          loop_iterator,
2735          loop_body, [n, array_ops.identity(select)],
2736          parallel_iterations=1)
2737      self.evaluate(variables.global_variables_initializer())
2738      result = r[1]
2739    self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
2740
2741  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
2742  @test_util.run_v1_only("b/120545219")
2743  def testWhileUpdateVariable_4(self):
2744    with self.cached_session():
2745      var_a = variables.Variable(0, name="a")
2746      var_b = variables.Variable(0, name="b")
2747      self.evaluate(variables.global_variables_initializer())
2748
2749      c = constant_op.constant(0, name="c")
2750      asn1 = state_ops.assign_add(var_a, 1, name="a_add")
2751
2752      # Loop condition
2753      def pred(i):
2754        return math_ops.less(i, 10)
2755
2756      # Loop body
2757      def loop_body(i):
2758        asn2 = state_ops.assign_add(var_b, asn1, name="b_add")
2759        with ops.control_dependencies([asn2]):
2760          ni = math_ops.add(i, 1, name="i_add")
2761        return ni
2762
2763      lpa = control_flow_ops.while_loop(
2764          pred, loop_body, [c], parallel_iterations=1)
2765
2766      self.assertEqual(0, self.evaluate(var_b))
2767      self.evaluate(lpa)  # Run the loop
2768      self.assertEqual(10, self.evaluate(var_b))
2769
2770  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
2771  @test_util.run_v1_only("b/120545219")
2772  def testWhileUpdateVariable_5(self):
2773    with self.cached_session():
2774      # Create some variables.
2775      var_a = variables.Variable(0, name="a")
2776      var_b = variables.Variable(0, name="b")
2777      self.evaluate(variables.global_variables_initializer())
2778
2779      # Change condition to check var_b
2780      def pred(_):
2781        return math_ops.less(var_b, 10)
2782
2783      # Change body to increment var_b
2784      def loop_body(i):
2785        asn1 = state_ops.assign_add(
2786            var_a, constant_op.constant(1), name="a_add")
2787        asn2 = state_ops.assign_add(
2788            var_b, constant_op.constant(1), name="b_add")
2789        with ops.control_dependencies([asn1, asn2]):
2790          inc_b = array_ops.identity(var_b)
2791        return inc_b
2792
2793      lpa = control_flow_ops.while_loop(
2794          pred, loop_body, [var_b], parallel_iterations=1, name="loop")
2795
2796      self.assertEqual(0, self.evaluate(var_b))
2797      self.evaluate(lpa)  # Run the loop
2798      self.assertEqual(10, self.evaluate(var_a))
2799      self.assertEqual(10, self.evaluate(var_b))
2800
2801  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
2802  @test_util.run_v1_only("b/120545219")
2803  def testWhileUpdateVariable_6(self):
2804    with self.cached_session():
2805      # Create some variables.
2806      var_a = variables.Variable(0, name="a")
2807      var_b = variables.Variable(0, name="b")
2808      c = constant_op.constant(0)
2809      self.evaluate(variables.global_variables_initializer())
2810
2811      # Loop condition
2812      def pred(i):
2813        return math_ops.less(i, 10)
2814
2815      # Loop body
2816      def loop_body(i):
2817        asn1 = state_ops.assign_add(var_a, 1, name="a_add")
2818        with ops.control_dependencies([asn1]):
2819          asn2 = state_ops.assign_add(var_b, var_a, name="b_add")
2820        with ops.control_dependencies([asn2]):
2821          ni = math_ops.add(i, 1, name="i_add")
2822          return ni
2823
2824      lpa = control_flow_ops.while_loop(
2825          pred, loop_body, [c], parallel_iterations=1, name="loop")
2826
2827      self.assertEqual(0, self.evaluate(var_b))
2828      self.evaluate(lpa)  # Run the loop
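      # Each of the ten iterations increments var_a and then adds its new
      # value to var_b, so var_b = 1 + 2 + ... + 10 = 55 and var_a = 10.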
2829      self.assertEqual(55, self.evaluate(var_b))
2830      self.assertEqual(10, self.evaluate(var_a))
2831
2832  @test_util.run_v1_only("b/120545219")
2833  def testWhileQueue_1(self):
2834    with self.cached_session():
2835      q = data_flow_ops.FIFOQueue(-1, dtypes.int32)
2836      i = constant_op.constant(0)
2837
2838      def c(i):
2839        return math_ops.less(i, 10)
2840
2841      def b(i):
2842        ni = math_ops.add(i, 1)
2843        ni = control_flow_ops.with_dependencies([q.enqueue((i,))], ni)
2844        return ni
2845
2846      r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1)
2847      self.assertEqual([10], self.evaluate(r))
2848      for i in range(10):
2849        self.assertEqual([i], self.evaluate(q.dequeue()))
2850
2851  @test_util.run_v1_only("b/120545219")
2852  def testWhileTimeOut(self):
2853    run_options = config_pb2.RunOptions(timeout_in_ms=1)
2854    with self.cached_session() as sess:
2855      n = constant_op.constant(0)
2856      c = lambda x: True
2857      b = lambda x: math_ops.add(x, 1)
2858      r = control_flow_ops.while_loop(c, b, [n])
2859      with self.assertRaises(errors_impl.DeadlineExceededError):
2860        sess.run(r, options=run_options)
2861
2862  @test_util.disable_control_flow_v2("b/117119329 (stack)")
2863  @test_util.run_v1_only("b/120545219")
2864  def testWhileStack_1(self):
2865    with self.cached_session():
2866      s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo")
2867      i = constant_op.constant(0)
2868
2869      def c(i):
2870        return math_ops.less(i, 10)
2871
2872      def b(i):
2873        ni = math_ops.add(i, 1)
2874        ni = control_flow_ops.with_dependencies(
2875            [gen_data_flow_ops.stack_push_v2(s, i)], ni)
2876        return ni
2877
2878      r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1)
2879
2880      x = constant_op.constant(0)
2881
2882      def c1(i, _):
2883        return math_ops.greater(i, 0)
2884
2885      def b1(i, x):
2886        ni = math_ops.subtract(i, 1)
2887        nx = x + gen_data_flow_ops.stack_pop_v2(s, dtypes.int32)
2888        return [ni, nx]
2889
2890      _, rx = control_flow_ops.while_loop(
2891          c1,
2892          b1, [r, x],
2893          [r.get_shape(), tensor_shape.unknown_shape()],
2894          parallel_iterations=1)
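      # The first loop pushes 0..9 onto the stack; the second pops all ten
      # values and sums them: 0 + 1 + ... + 9 = 45.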
2895      self.assertEqual(45, self.evaluate(rx))
2896
2897  def _testWhileGrad_ColocateGradients(self, colocate):
    gpu_dev_name = (test.gpu_device_name() if test.is_gpu_available()
                    else "/device:CPU:0")
2900
2901    graph = ops.Graph()
2902    with graph.as_default():
2903      v = constant_op.constant(2.0, name="v")
2904      c = lambda v: math_ops.less(v, 100.0)
2905
2906      def b(x):
2907        with ops.device(gpu_dev_name):
2908          return math_ops.square(x)
2909
2910      loop = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
2911      r = gradients_impl.gradients(
2912          loop, v, colocate_gradients_with_ops=colocate)[0]
2913
2914    r_ops = graph.get_operations()
2915    r_devices = [(op.name, op.device) for op in r_ops]
2916
2917    self.assertTrue(any("Square" in op.name for op in r_ops))
2918
2919    for (name, dev) in r_devices:
2920      if not colocate and name.endswith("Square"):
        # Only the forward Square op should be placed on the GPU device.
2922        self.assertTrue(gpu_dev_name in dev)
2923      elif colocate and "Square" in name:
        # Both the forward and gradient Square/Square_grad ops should be
        # placed on the GPU device.
2925        self.assertTrue(gpu_dev_name in dev)
2926      else:
2927        self.assertFalse(gpu_dev_name in dev)
2928
2929    with self.session(graph=graph) as sess:
2930      self.assertAllClose(1024.0, self.evaluate(r))
2931
2932  @test_util.disable_control_flow_v2("b/116351701 (colocation)")
2933  @test_util.run_v1_only("b/120545219")
2934  def testWhileGrad_ColocateGradients(self):
2935    self._testWhileGrad_ColocateGradients(colocate=False)
2936    self._testWhileGrad_ColocateGradients(colocate=True)
2937
2938  @test_util.run_v1_only("b/120545219")
2939  def testWhileGrad_Square(self):
2940    with self.cached_session():
2941      v = constant_op.constant(2.0, name="v")
2942      c = lambda v: math_ops.less(v, 100.0)
2943      b = math_ops.square
2944      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
2945      r = control_flow_ops.cond(math_ops.less(1, 2), lambda: r, lambda: v)
2946
2947      r = gradients_impl.gradients(r, v)[0]
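      # v is squared until it reaches 100 (2 -> 4 -> 16 -> 256), so r = v**8
      # and dr/dv = 8 * v**7 = 1024.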
2948      self.assertAllClose(1024.0, self.evaluate(r))
2949
2950  @test_util.run_v1_only("b/120545219")
2951  def testWhileGrad_Shape(self):
2952    with self.cached_session():
2953      x = array_ops.placeholder(dtypes.float32, shape=[None])
2954      v = constant_op.constant([2.0], name="v")
2955      n = constant_op.constant(0, name="n")
2956      c = lambda i, v: math_ops.less(i, 5)
2957      b = lambda i, v: [i + 1, math_ops.multiply(x, v)]
2958      r = control_flow_ops.while_loop(
2959          c,
2960          b, [n, v],
2961          [n.get_shape(), tensor_shape.unknown_shape()],
2962          parallel_iterations=1)
2963
2964      r = gradients_impl.gradients(r[1], x)[0]
2965      self.assertEqual([None], r.get_shape().as_list())
2966      self.assertAllClose([810.0, 2560.0], r.eval(feed_dict={x: [3.0, 4.0]}))
2967
2968  @test_util.run_deprecated_v1
2969  def testWhileGrad_BaseShape(self):
2970    with self.cached_session() as sess:
2971      x = array_ops.placeholder(dtypes.float32, [None])
2972      v0 = constant_op.constant([2.0, 2.0], name="v")
2973      c = lambda v: constant_op.constant(False)
2974      b = lambda v: math_ops.multiply(v, x)
2975      r = control_flow_ops.while_loop(c, b, [v0])
2976      y = math_ops.square(x)
2977
2978      r = gradients_impl.gradients([r, y], x)[0]
2979      self.assertAllClose([2.0, 4.0], sess.run(r, feed_dict={x: [1.0, 2.0]}))
2980
2981  @test_util.run_deprecated_v1
2982  @test_util.enable_output_all_intermediates
2983  def testWhileGradAfterSessionRun(self):
2984    v0 = constant_op.constant(2.)
2985    r = control_flow_ops.while_loop(
2986        lambda _: True, lambda v: v * v, [v0], maximum_iterations=3)
2987
2988    self.assertAllEqual(r, 256.)
2989    grad = gradients_impl.gradients(r, v0)[0]
2990    self.assertAllClose(grad, 1024.)
2991
2992  @test_util.run_deprecated_v1
2993  @test_util.enable_output_all_intermediates
2994  def testNestedWhileGradAfterSessionRun(self):
2995    v0 = constant_op.constant(2.)
2996
2997    def body(v):
2998      inner_v0 = constant_op.constant(1.)
2999      return control_flow_ops.while_loop(
3000          lambda _: True, lambda x: x * v, [inner_v0], maximum_iterations=2)
3001
3002    r = control_flow_ops.while_loop(
3003        lambda _: True, body, [v0], maximum_iterations=3)
3004
3005    self.assertAllEqual(r, 256.)
3006    grad = gradients_impl.gradients(r, v0)[0]
3007    self.assertAllClose(grad, 1024.)
3008
3009  @test_util.run_v1_only("b/120545219")
3010  def testWhileGrad_MultipleUses(self):
3011    with self.cached_session():
3012      v = constant_op.constant(2.0, name="v")
3013      c = lambda v: math_ops.less(v, 100.0)
3014      b = math_ops.square
3015      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
3016      r = math_ops.multiply(r, r)
3017
3018      r = gradients_impl.gradients(r, v)[0]
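      # The loop computes v**8 = 256, so r = (v**8)**2 = v**16 and
      # dr/dv = 16 * v**15 = 524288.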
3019      self.assertEqual(524288.0, self.evaluate(r))
3020
3021  @test_util.run_v1_only("b/120545219")
3022  def testWhileGrad_LoopAdd(self):
3023    with self.cached_session():
3024      v = constant_op.constant(2.0, name="v")
3025      c = lambda v: math_ops.less(v, 100.0)
3026      b = math_ops.square
3027      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
3028      r = math_ops.add(r, r)
3029
3030      r = gradients_impl.gradients(r, v)[0]
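      # r = v**8 + v**8 = 2 * v**8, so dr/dv = 16 * v**7 = 2048.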
3031      self.assertAllClose(2048.0, self.evaluate(r))
3032
3033  def _testWhileGrad_Mul(self, use_gpu, p_iters):
3034    with self.cached_session(use_gpu=use_gpu) as sess:
3035      a = constant_op.constant(3.0, name="a")
3036      v = constant_op.constant(2.0, name="v")
3037      c = lambda v: math_ops.less(v, 100.0)
3038      b = lambda v: math_ops.multiply(v, a)
3039      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=p_iters)
3040
3041      grad_a, grad_v = gradients_impl.gradients(r, [a, v])
3042      grad_a_val, grad_v_val = self.evaluate([grad_a, grad_v])
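      # Four iterations multiply v by a (2 -> 6 -> 18 -> 54 -> 162), so
      # r = v * a**4, dr/da = 4 * v * a**3 = 216, and dr/dv = a**4 = 81.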
3043      self.assertAllClose(216.0, grad_a_val)
3044      self.assertAllClose(81.0, grad_v_val)
3045
3046  @test_util.run_deprecated_v1
3047  def testWhileGrad_Mul(self):
3048    self._testWhileGrad_Mul(use_gpu=False, p_iters=1)
3049    self._testWhileGrad_Mul(use_gpu=False, p_iters=10)
3050    self._testWhileGrad_Mul(use_gpu=True, p_iters=1)
3051    self._testWhileGrad_Mul(use_gpu=True, p_iters=10)
3052
3053  def testWhileGradInControlDeps(self):
3054
3055    @def_function.function
3056    def f():
3057      x_init = constant_op.constant(2.)
3058      loop_cond = lambda i, x: math_ops.less(i, 2)
3059      loop_body = lambda i, x: [i + 1, x**2]
3060      _, x = control_flow_ops.while_loop(loop_cond, loop_body, [0, x_init])
3061      with ops.control_dependencies([x]):
3062        (grad,) = gradients_impl.gradients(x, x_init)
3063        return grad
3064
3065    self.assertAllEqual(f(), 4. * 2.**3)  # 4 * x_init ^ 3
3066
3067  @test_util.run_deprecated_v1
3068  def testTfFunctionInV1WhileLoop(self):
3069
    # This test specifically checks that creating a Const node inside a
    # tf.function inside a v1 while_loop works when function inlining is
    # turned on.
3072    config = opt_cfg()
3073    assert config.graph_options.optimizer_options.do_function_inlining
3074    with session.Session(config=config):
3075
3076      @def_function.function
3077      def loop_body(i):
3078        # Here we create the const.
3079        return i + 1.
3080
3081      loop_cond = lambda i: True
3082      x = control_flow_ops.while_loop(
3083          loop_cond, loop_body, [0.], maximum_iterations=5)
3084      self.assertAllEqual(x, 5.)
3085
3086  def _testNestedWhileCondWhileGrad(self, use_gpu):
3087
3088    with self.cached_session(use_gpu=use_gpu):
3089      v = constant_op.constant(1.0)
3090
3091      def inner_loop(s):
3092        z = constant_op.constant(0)
3093        c = lambda i, x: math_ops.less(i, 4)
3094        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
3095        return control_flow_ops.while_loop(c, b, [z, s])
3096
3097      c = lambda x: math_ops.less(x, 128.0)
3098
3099      def b(x):
3100        return control_flow_ops.cond(
3101            constant_op.constant(True),
3102            lambda: math_ops.square(inner_loop(x)[1]),
3103            lambda: math_ops.multiply(x, 2.0))
3104
3105      r = control_flow_ops.while_loop(c, b, [v])
3106      r = gradients_impl.gradients(r, v)[0]
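      # The inner loop multiplies x by 2 four times, so the cond's true branch
      # computes (16 * x)**2 = 256 * x**2; one outer iteration gives
      # r = 256 * v**2 and dr/dv = 512 * v = 512.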
3107      self.assertAllClose(512.0, self.evaluate(r))
3108
3109  @test_util.run_deprecated_v1
3110  def testNestedWhileCondWhileGrad(self):
3111    self._testNestedWhileCondWhileGrad(use_gpu=False)
3112
3113  @test_util.run_deprecated_v1
3114  def testNestedWhileCondWhileGradGpu(self):
3115    self._testNestedWhileCondWhileGrad(use_gpu=True)
3116
3117  @test_util.run_v1_only("b/120545219")
3118  def testWhileGrad_Variable(self):
3119    with self.cached_session():
3120      a = variables.Variable(3.0)
3121      v = constant_op.constant(2.0, name="v")
3122      c = lambda v: math_ops.less(v, 100.0)
3123      b = lambda v: math_ops.multiply(v, a)
3124      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
3125
3126      r = gradients_impl.gradients(r, a)
3127      self.evaluate(variables.global_variables_initializer())
3128      self.assertAllClose(216.0, r[0])
3129
3130  @test_util.run_deprecated_v1
3131  def testWhileGrad_ResourceVariable(self):
3132    with self.cached_session():
3133      a = resource_variable_ops.ResourceVariable(3.0)
3134      v = constant_op.constant(2.0, name="v")
3135      c = lambda v: math_ops.less(v, 100.0)
3136      b = lambda v: math_ops.multiply(v, a)
3137      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
3138
3139      g = gradients_impl.gradients(r, a)
3140      self.evaluate(variables.global_variables_initializer())
3141      self.assertAllClose(216.0, g[0])
3142
3143  def testWhileGrad_EagerResourceVariable(self):
3144    with context.eager_mode():
3145      a = resource_variable_ops.ResourceVariable(
3146          np.ones([2, 2], dtype=np.float32))
3147      v = constant_op.constant(1.0)
3148
3149      @eager_function.defun
3150      def fn():
3151        r = control_flow_ops.while_loop(
3152            lambda i, _: i < 2,
3153            lambda i, x: (i + 1, x * math_ops.reduce_sum(a) * v),
3154            [0, 1.0])[1]
3155        return gradients_impl.gradients(r, [v])[0]
3156
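      # Two iterations multiply x by reduce_sum(a) * v = 4 * v, so
      # r = (4 * v)**2 = 16 * v**2 and dr/dv = 32 * v = 32.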
3157      self.assertEqual(self.evaluate(fn()), 32.)
3158
3159  def testWhileGrad_ResourceVarInFunctionCall(self):
3160
3161    @def_function.function
3162    def foo(x, var):
3163      return x + math_ops.reduce_sum(var.sparse_read([1, 3]))
3164
3165    @def_function.function
3166    def bar(var):
3167      r = control_flow_ops.while_loop(
3168          lambda i, _: i < 2,
3169          lambda i, x: (i + 1, foo(x, var)),
3170          [0, 0.0])[1]
3171      return gradients_impl.gradients(r, var)[0]
3172
3173    var = resource_variable_ops.ResourceVariable([1., 2., 3., 4.])
3174    self.evaluate(variables.global_variables_initializer())
3175    grad = self.evaluate(bar(var))
3176    self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 2., 0., 2.])
3177
3178  def testWhileGrad_ResourceVarInNestedFunctionCall(self):
3179
3180    @def_function.function
3181    def foo(x, var):
3182      return x + math_ops.reduce_sum(var.sparse_read([1, 3]))
3183
3184    @def_function.function
3185    def foo2(x, var):
3186      return foo(x, var)
3187
3188    @def_function.function
3189    def bar(var):
3190      r = control_flow_ops.while_loop(
3191          lambda i, _: i < 2,
3192          lambda i, x: (i + 1, foo2(x, var)),
3193          [0, 0.0])[1]
3194      return gradients_impl.gradients(r, var)[0]
3195
3196    var = resource_variable_ops.ResourceVariable([1., 1., 1., 1.])
3197    self.evaluate(variables.global_variables_initializer())
3198    grad = self.evaluate(bar(var))
3199    self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 2., 0., 2.])
3200
3201  def testWhileGrad_ResourceVarInLoopInFunctionCall(self):
3202    if test.is_gpu_available():
3203      self.skipTest("b/128635252")
3204
3205    @def_function.function
3206    def foo(x, var):
3207      return control_flow_ops.while_loop(
3208          lambda j, _: j < 3,
3209          lambda j, y: (j + 1,
3210                        y + math_ops.reduce_sum(var.sparse_read([1, 2]))),
3211          [0, x])[1]
3212
3213    @def_function.function
3214    def bar(var):
3215      r = control_flow_ops.while_loop(
3216          lambda i, _: i < 2,
3217          lambda i, x: (i + 1, foo(x, var)),
3218          [0, 0.0])[1]
3219      return gradients_impl.gradients(r, var)[0]
3220
3221    var = resource_variable_ops.ResourceVariable([1., 1., 1., 1.])
3222    self.evaluate(variables.global_variables_initializer())
3223    grad = self.evaluate(bar(var))
3224    self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 6., 6., 0.])
3225
3226  def testWhileCondGrad_ResourceVarInFunctionCall(self):
3227
3228    @def_function.function
3229    def foo(x, var):
3230      return x + var.sparse_read([1])[0]
3231
3232    def body(i, x):
3233      return (i + 1, control_flow_ops.cond(
3234          math_ops.equal(i % 2, 0),
3235          lambda: foo(x, var1),
3236          lambda: foo(x, var2)))
3237
3238    @def_function.function
3239    def bar(var1, var2):
3240      r = control_flow_ops.while_loop(
3241          lambda i, _: i < 4, body, [0, 0.0])
3242      return gradients_impl.gradients(r, [var1, var2])
3243
3244    var1 = resource_variable_ops.ResourceVariable([1., 2., 3.])
3245    var2 = resource_variable_ops.ResourceVariable([4., 5.])
3246    self.evaluate(variables.global_variables_initializer())
3247    grads = self.evaluate(bar(var1, var2))
3248    self.assertAllEqual(gradient_checker_v2._to_numpy(grads[0]), [0., 2., 0.])
3249    self.assertAllEqual(gradient_checker_v2._to_numpy(grads[1]), [0., 2.])
3250
3251  @test_util.run_deprecated_v1
3252  def testWhileGrad_ResourceVarSparseRead(self):
3253    # NOTE(skyewm): this test is interesting because the gradient is the
3254    # aggregation result of IndexedSlices and Tensors.
3255    var = resource_variable_ops.ResourceVariable(np.ones(5),
3256                                                 dtype=dtypes.float32)
3257    r = control_flow_ops.while_loop(
3258        lambda i, _: i < 3,
3259        lambda i, x: (i + 1, x * math_ops.reduce_sum(var.sparse_read([1, 3]))),
3260        [0, constant_op.constant(1.0)])[1]
3261    grad = gradients_impl.gradients(r, var)[0]
3262
3263    self.evaluate(variables.global_variables_initializer())
3264    grad_val = self.evaluate(grad)
3265    arr = gradient_checker_v2._to_numpy(grad_val)
3266    self.assertAllEqual(arr, [0., 12., 0., 12., 0.])
3267
3268  @test_util.run_deprecated_v1
3269  def testWhileGrad_MultiResourceVarSparseRead(self):
3270    # NOTE(skyewm): this test is interesting because the gradient is the
3271    # aggregation result of IndexedSlices and Tensors.
3272    var1 = resource_variable_ops.ResourceVariable(np.ones(5),
3273                                                  dtype=dtypes.float32)
3274    var2 = resource_variable_ops.ResourceVariable(np.ones(3),
3275                                                  dtype=dtypes.float32)
3276    x1_init = constant_op.constant([0., 0.])
3277    x2_init = constant_op.constant(1.)
3278    x3_init = constant_op.constant(1.)
3279
3280    def body(i, unused_x1, x2, x3):
3281      y1 = var1.sparse_read([1, 3])
3282      y2 = x2 * 2
3283      y3 = x3 * math_ops.reduce_sum(var2.sparse_read([0]))
3284      return i + 1, y1, y2, y3
3285
3286    r = control_flow_ops.while_loop(
3287        lambda i, x1, x2, x3: i < 3, body,
3288        [0, x1_init, x2_init, x3_init])[1:]
3289    var1_grad, var2_grad = gradients_impl.gradients(r, [var1, var2])
3290
3291    self.evaluate(variables.global_variables_initializer())
3292    var1_grad_val = self.evaluate(var1_grad)
3293    var2_grad_val = self.evaluate(var2_grad)
3294    self.assertAllEqual(gradient_checker_v2._to_numpy(var1_grad_val),
3295                        [0., 1., 0., 1., 0.])
3296    self.assertAllEqual(gradient_checker_v2._to_numpy(var2_grad_val),
3297                        [3., 0., 0.])
3298
3299  def testWhileGrad_Gather(self):
3300    # NOTE(skyewm): this test is interesting because the gather gradient
3301    # function returns an IndexedSlices.
3302    @tf_function_in_tf2
3303    def fn():
3304      x = constant_op.constant([1., 1., 1., 1., 1.])
3305      y = control_flow_ops.while_loop(
3306          lambda i, _: i < 3,
3307          lambda i, x: (i + 1, x + array_ops.gather(x, [0])),
3308          [0, x[:1]])[1]
3309      z = y * 3.0
3310      grad = gradients_impl.gradients(z, x)[0]
3311      return y, grad
3312    y, grad = fn()
3313    self.assertEqual(self.evaluate(y), 8.)
3314    self.assertAllEqual(self.evaluate(grad), [24., 0., 0., 0., 0.])
3315
3316  def testWhileGrad_GatherNoFanOut(self):
3317    # NOTE(skyewm): this test is interesting because the gather gradient
3318    # function returns an IndexedSlices.
3319    @tf_function_in_tf2
3320    def fn():
3321      x = constant_op.constant([1., 1., 1., 1., 1.])
3322      y = control_flow_ops.while_loop(
3323          lambda i, _: i < 3,
3324          lambda i, x: (i + 1, array_ops.gather(x, [0])),
3325          [0, x[:1]])[1]
3326      z = y * 3.0
3327      grad = gradients_impl.gradients(z, x)[0]
3328      return y, grad
3329    y, grad = fn()
3330    self.assertEqual(self.evaluate(y), 1.)
3331    self.assertAllEqual(self.evaluate(grad), [3., 0., 0., 0., 0.])
3332
3333  @test_util.run_v1_only("b/120545219")
3334  def testWhileGradInCond(self):
3335
3336    with self.cached_session():
3337      n = ops.convert_to_tensor(1.0, name="n")
3338      x = array_ops.placeholder(dtypes.float32, shape=None)
3339      c = lambda n: math_ops.less(n, 10.0)
3340      b = lambda n: math_ops.add(n, x)
3341
3342      def fn1():
3343        r = control_flow_ops.while_loop(c, b, [n],
3344                                        [tensor_shape.unknown_shape()])
3345        return gradients_impl.gradients(r, x)[0]
3346
3347      r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x)
3348      self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
3349
3350  @test_util.disable_control_flow_v2("b/116340060")
3351  @test_util.run_v1_only("b/120545219")
3352  def testGradInWhileWrtInitialLoopVal(self):
3353    with self.cached_session():
3354      x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
3355      y = x + 1
3356
3357      def body(i, v):
3358        z = v * 2
3359        return i + 1, gradients_impl.gradients(z, x)[0]
3360
3361      with self.assertRaisesRegex(
3362          ValueError,
3363          "Cannot compute gradient inside while loop with respect to op 'x'. "
3364          "We do not support taking the gradient wrt or through the initial "
3365          "value of a loop variable. Gradients can be computed through "
3366          "loop invariants or wrt the input parameters to the loop body."):
3367        control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y])
3368
3369  @test_util.run_v1_only("b/120545219")
3370  def testWhileGradInWhile(self):
3371    with self.cached_session():
3372      n = ops.convert_to_tensor(1.0, name="n")
3373      x = array_ops.placeholder(dtypes.float32, shape=None)
3374      c = lambda n: math_ops.less(n, 10.0)
3375      b = lambda n: math_ops.add(n, x)
3376
3377      def b1(n):
3378        r = control_flow_ops.while_loop(c, b, [n],
3379                                        [tensor_shape.unknown_shape()])
3380        return gradients_impl.gradients(r, x)
3381
3382      r = control_flow_ops.while_loop(lambda n: n < 6.0, b1, [n],
3383                                      [tensor_shape.unknown_shape()])
3384      self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
3385
3386  @test_util.run_v1_only("b/120545219")
3387  def testCondGradInNestedWhiles(self):
3388
3389    def outer_body(i, x):
3390      _, x = control_flow_ops.while_loop(
3391          lambda j, x: j < 3, inner_body, [0, 0.0])
3392      return i + 1, x
3393
3394    def inner_body(j, x):
3395      y = control_flow_ops.cond(math_ops.less(x, 1), lambda: 2 * x, lambda: x)
3396      return j + 1, gradients_impl.gradients(y, x)[0]
3397
3398    i, x = control_flow_ops.while_loop(lambda i, x: i < 3, outer_body, [0, 0.0])
3399
3400    with self.cached_session() as sess:
3401      i_val, x_val = self.evaluate([i, x])
3402      self.assertEqual(i_val, 3)
3403      self.assertAllClose(x_val, 1.0)
3404
3405  @test_util.run_gpu_only
3406  def testGpuResourceAccess(self):
3407    with ops.device(test.gpu_device_name()):
3408      var = resource_variable_ops.ResourceVariable(constant_op.constant(3.0))
3409
3410    @def_function.function
3411    def foo():
3412      return control_flow_ops.while_loop(
3413          lambda i, _: i < 3,
3414          lambda i, x: (i + 1, control_flow_ops.cond(
3415              constant_op.constant(True),
3416              lambda: x + var,
3417              lambda: x)),
3418          [0, 0.0])[1]
3419
3420    self.evaluate(variables.global_variables_initializer())
3421    self.assertEqual(self.evaluate(foo()), 9.0)
3422
3423  def testNestedResourceAccess(self):
3424    var = resource_variable_ops.ResourceVariable(constant_op.constant(3.0))
3425
3426    @eager_function.defun
3427    def test_fn():
3428      x = constant_op.constant(0.0)
3429      r = control_flow_ops.while_loop(
3430          # Outer loop condition
3431          lambda i, y: i < 2,
3432          # Outer loop body
3433          lambda i, y: (i + 1, y + control_flow_ops.cond(
3434              constant_op.constant(True),
3435              # True branch
3436              lambda: control_flow_ops.while_loop(
3437                  # Inner loop condition
3438                  lambda j, z: j < 3,
3439                  # Inner loop body
3440                  lambda j, z: (j + 1, z + math_ops.square(var)),
3441                  # Inner initial loop value
3442                  [0, y])[1],
3443              # False branch
3444              lambda: (0.0))),
3445          # Outer initial loop value
3446          [0, x])[1]
3447
3448      grad = gradients_impl.gradients(r, x)[0]
3449      return r, grad
3450
3451    self.evaluate(variables.global_variables_initializer())
3452    r, grad = self.evaluate(test_fn())
    # r = outer_loop(0) = 81; see the gradient derivation below.
3454    self.assertEqual(r, 81.0)
3455    # v1 control flow gets the wrong answer!!!
3456    # Gradient computation:
3457    #   f(x) = x + 3^2
3458    #   inner_loop(x) = f(f(f(x))) = x + 3*3^2 = x + 27
3459    #   g(x) = x + inner_loop(x) = 2x + 27
3460    #   outer_loop(x) = g(g(x)) = 4x + 81
3461    #   outer_loop'(x) = 4
3462    # Note that v1 control flow gets 4.0 as well if the cond is removed.
3463    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
3464      self.assertEqual(grad, 4.0)
3465
3466  def testWhile_NestedInput(self):
3467    with self.cached_session() as sess:
3468      named = collections.namedtuple("named", ("a", "b"))
3469      loop_vars = [
3470          named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
3471          (constant_op.constant(2.0), constant_op.constant(3.0)),
3472          constant_op.constant(4.0)
3473      ]
3474      c = lambda lv0, _1, _2: lv0.a < 100.0
3475
3476      def b(lv0, lv1, lv2):
3477        lv0 = named(a=lv0.a + 1, b=lv0.b)
3478        lv1 = (lv1[0] + 1, lv1[1])
3479        lv2 += 2
3480        return [lv0, lv1, lv2]
3481
3482      r = control_flow_ops.while_loop(c, b, loop_vars)
3483
3484      self.assertTrue(isinstance(r, list))
3485      self.assertTrue(isinstance(r[0], named))
3486      self.assertTrue(isinstance(r[1], tuple))
3487      self.assertTrue(isinstance(r[2], ops.Tensor))
3488
3489      r_flattened = nest.flatten(r)
3490      self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0],
3491                       self.evaluate(r_flattened))
3492
3493  @test_util.run_v1_only("b/120545219")
3494  def testWhile_NestedBadArityFails(self):
3495    with self.cached_session():
3496      named = collections.namedtuple("named", ("a", "b"))
3497      loop_vars = [
3498          named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
3499          (constant_op.constant(2.0), constant_op.constant(3.0)),
3500          constant_op.constant(4.0)
3501      ]
3502      c = lambda lv0, _1, _2: lv0.a < 100.0
3503
3504      def b(lv0, lv1, _):
3505        return [lv0, lv1]
3506
3507      with self.assertRaisesRegex(ValueError, "the same number of elements"):
3508        control_flow_ops.while_loop(c, b, loop_vars)
3509
3510  @test_util.run_v1_only("b/120545219")
3511  def testWhileGrad_ys_xs(self):
3512    with self.cached_session():
3513      x = constant_op.constant(3.0, name="x")
3514      y = constant_op.constant(2.0, name="y")
3515
3516      c = lambda x, y: math_ops.less(x, 100.0)
3517
3518      def b(x, y):
3519        y1 = math_ops.add(x, y)
3520        x1 = math_ops.multiply(x, y1)
3521        return x1, y1
3522
3523      rx, ry = control_flow_ops.while_loop(c, b, [x, y], parallel_iterations=1)
3524
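      # After two iterations rx = x * (x + 1) * (x + y)**2 = 300 and
      # ry = (x + 1) * (x + y) = 20; the gradients below follow from these
      # closed forms at x = 3, y = 2.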
3525      r = gradients_impl.gradients([rx, ry], x)
3526      self.assertAllClose(304.0, r[0])
3527      r = gradients_impl.gradients([rx, ry], y)
3528      self.assertAllClose(124.0, r[0])
3529      r = gradients_impl.gradients([rx], x)
3530      self.assertAllClose(295.0, r[0])
3531      r = gradients_impl.gradients([rx], y)
3532      self.assertAllClose(120.0, r[0])
3533
3534  @test_util.run_deprecated_v1
3535  def testWhileGrad_Dependency(self):
3536    with self.cached_session():
3537      i = constant_op.constant(0, name="i")
3538      x = constant_op.constant(2.0, name="x")
3539
3540      c = lambda i, x: math_ops.less(i, 10)
3541
3542      def b(i, x):
3543        x = math_ops.multiply(x, 2.0)
3544        i = math_ops.add(i, 1)
3545        return i, x
3546
3547      ri, rx = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
3548
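      # Ten iterations double x, so rx = x * 2**10 and drx/dx = 1024.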
3549      r = gradients_impl.gradients([ri, rx], x)
3550      self.assertAllClose(1024.0, r[0])
3551      r = gradients_impl.gradients([rx], x)
3552      self.assertAllClose(1024.0, r[0])
3553
3554  @test_util.run_v1_only("b/120545219")
3555  def testWhileGrad_NoGradient(self):
3556    with self.cached_session():
3557      v = constant_op.constant(2.0, name="v")
3558      c = lambda v: math_ops.less(v, 100.0)
3559      b = math_ops.square
3560      r = control_flow_ops.while_loop(c, b, [v], back_prop=False)
3561      r = math_ops.add(r, v)
3562      r = gradients_impl.gradients(r, v)
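      # With back_prop=False the loop contributes no gradient; only the direct
      # use of v in the final add does, so dr/dv = 1.0.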
3563      self.assertAllClose(1.0, r[0])
3564
3565  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
3566  @test_util.run_v1_only("b/120545219")
3567  def testWhileGrad_NoDependency(self):
3568    with self.cached_session() as sess:
3569      variable = variables.Variable(array_ops.ones([2, 3]))
3570      duration = array_ops.zeros([], dtype=dtypes.int32)
3571
3572      def cond(duration, tensor, _):
3573        del tensor
3574        return duration < 10
3575
3576      def body(duration, tensor, _):
3577        return (duration + 1, tensor, tensor)
3578
3579      loop_vars = [duration, variable, variable]
3580      tensors = control_flow_ops.while_loop(
3581          cond=cond, body=body, loop_vars=loop_vars)
3582      cost = math_ops.reduce_sum(tensors[2])
3583      grad = gradients_impl.gradients(cost, [variable])
3584      self.evaluate(variables.global_variables_initializer())
3585      self.assertAllClose(np.ones([2, 3]), sess.run(grad[0]))
3586
3587  @test_util.run_deprecated_v1
3588  def testWhileGrad_Const(self):
3589    with self.cached_session() as sess:
3590      c0 = constant_op.constant(0.0, name="c0")
3591      c1 = constant_op.constant(1.0, name="c1")
3592      duration = constant_op.constant(0, name="t")
3593
3594      def cond(duration, _):
3595        return duration < 1
3596
3597      def body(duration, _):
3598        return duration + 1, c1
3599
3600      loop_vars = [duration, c0]
3601      tensors = control_flow_ops.while_loop(
3602          cond=cond, body=body, loop_vars=loop_vars)
3603      cost = math_ops.reduce_sum(tensors[1])
3604      grad = gradients_impl.gradients(cost, [c0])
3605      self.assertAllClose(0.0, sess.run(grad[0]))
3606
3607  @test_util.run_v1_only("b/120545219")
3608  def testWhileGrad_SerialTwoLoops(self):
3609    with self.cached_session():
3610      i = constant_op.constant(0, name="i")
3611      x = constant_op.constant(2.0, name="x")
3612
3613      c = lambda i, x: math_ops.less(i, 5)
3614
3615      def b(i, x):
3616        x = math_ops.multiply(x, 2.0)
3617        i = math_ops.add(i, 1)
3618        return i, x
3619
3620      _, rx = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
3621      _, rx = control_flow_ops.while_loop(c, b, [i, rx], parallel_iterations=1)
3622
3623      r = gradients_impl.gradients([rx], x)
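      # Each of the two serial loops doubles x five times, so
      # rx = x * 2**10 and drx/dx = 1024.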
3624      self.assertAllClose(1024.0, r[0])
3625
3626  @test_util.run_v1_only("b/120545219")
3627  def testWhileGrad_ParallelTwoLoops(self):
3628    with self.cached_session():
3629      i = constant_op.constant(0, name="i")
3630      x = constant_op.constant(2.0, name="x")
3631
3632      c = lambda i, x: math_ops.less(i, 5)
3633
3634      def b(i, x):
3635        x = math_ops.multiply(x, 2.0)
3636        i = math_ops.add(i, 1)
3637        return i, x
3638
3639      _, r1 = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
3640      _, r2 = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
3641      rx = math_ops.add(r1, r2)
3642
3643      r = gradients_impl.gradients([rx], x)
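      # Each parallel loop computes x * 2**5, so rx = 2 * 32 * x and
      # drx/dx = 64.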
3644      self.assertAllClose(64.0, r[0])
3645
3646  @test_util.run_v1_only("b/120545219")
3647  def testWhileGrad_OneOutputWithControlDependencyOnSecond(self):
3648    with self.cached_session():
3649      i = constant_op.constant(0, name="i")
3650      x = constant_op.constant(1.0, name="x")
3651      y = constant_op.constant(1.0, name="y")
3652      c = lambda i, *_: math_ops.less(i, 1, name="cond_less")
3653
3654      def b(i, xi, yi):
3655        # return (i + 1, xi, xi + yi)
3656        return (math_ops.add(i, 1, name="inc"), array_ops.identity(
3657            xi, name="xi"), math_ops.add(xi, yi, name="xi_plus_yi"))
3658
3659      _, x_f, y_f = control_flow_ops.while_loop(c, b, [i, x, y])
3660      with ops.control_dependencies([x_f]):
3661        y_f_d = array_ops.identity(y_f, name="y_f_d")
3662
3663      self.assertAllClose(2.0, self.evaluate(y_f_d))  # y_f_d = 1.0 + 1.0
3664      g = gradients_impl.gradients([y_f_d], [x])[0]
3665      self.assertTrue(g is not None)
3666      self.assertAllClose(1.0,
3667                          self.evaluate(g))  # y_f_d = x + 1.0, dy_f_d/dx = 1.0
3668
3669  def _testNestedWhileGrad_Simple(self, use_gpu):
3670    with self.cached_session(use_gpu=use_gpu):
3671      v = constant_op.constant(1.0)
3672
3673      def inner_loop(s):
3674        c = lambda x: math_ops.less(x, 4.0)
3675        b = lambda x: math_ops.multiply(x, 2.0)
3676        return control_flow_ops.while_loop(c, b, [s])
3677
3678      c = lambda x: math_ops.less(x, 2.0)
3679      b = lambda x: math_ops.multiply(inner_loop(x), 2.0)
3680      r = control_flow_ops.while_loop(c, b, [v])
3681
3682      r = gradients_impl.gradients(r, v)[0]
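      # For v = 1 the inner loop doubles x twice (to 4) and the outer body
      # doubles the result once more, so locally r = 8 * v and dr/dv = 8.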
3683      self.assertAllClose(8.0, self.evaluate(r))
3684
3685  @test_util.run_deprecated_v1
3686  def testNestedWhileGrad_Simple(self):
3687    self._testNestedWhileGrad_Simple(use_gpu=False)
3688    self._testNestedWhileGrad_Simple(use_gpu=True)
3689
3690  @test_util.run_v1_only("b/120545219")
3691  def testNestedWhileGrad_SerialInner(self):
3692    with self.cached_session():
3693      v = constant_op.constant(1.0)
3694
3695      def inner_loop1(s):
3696        z = constant_op.constant(0)
3697        c = lambda i, x: math_ops.less(i, 4)
3698        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
3699        return control_flow_ops.while_loop(c, b, [z, s])
3700
3701      def inner_loop2(s):
3702        z = constant_op.constant(0)
3703        c = lambda i, x: math_ops.less(i, 4)
3704        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
3705        return control_flow_ops.while_loop(c, b, [z, s])
3706
3707      c = lambda x: math_ops.less(x, 128.0)
3708      b = lambda x: inner_loop2(inner_loop1(x)[1])[1]
3709      r = control_flow_ops.while_loop(c, b, [v])
3710
3711      r = gradients_impl.gradients(r, v)[0]
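      # Each inner loop multiplies x by 2**4, so b(x) = 256 * x; a single
      # outer iteration gives r = 256 * v and dr/dv = 256.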
3712      self.assertAllClose(256.0, self.evaluate(r))
3713
3714  @test_util.run_deprecated_v1
3715  def testNestedWhileGrad_ParallelInner(self):
3716    with self.cached_session():
3717      v = constant_op.constant(1.0)
3718
3719      def inner_loop1(s):
3720        z = constant_op.constant(0)
3721        c = lambda i, x: math_ops.less(i, 4)
3722        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
3723        return control_flow_ops.while_loop(c, b, [z, s])
3724
3725      def inner_loop2(s):
3726        z = constant_op.constant(0)
3727        c = lambda i, x: math_ops.less(i, 4)
3728        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
3729        return control_flow_ops.while_loop(c, b, [z, s])
3730
3731      c = lambda x: math_ops.less(x, 128.0)
3732      b = lambda x: math_ops.multiply(inner_loop1(x)[1], inner_loop2(x)[1])
3733      r = control_flow_ops.while_loop(c, b, [v])
3734
3735      r = gradients_impl.gradients(r, v)[0]
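      # b(x) = (16 * x) * (16 * x) = 256 * x**2; a single outer iteration
      # gives r = 256 * v**2 and dr/dv = 512 * v = 512.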
3736      self.assertAllClose(512.0, self.evaluate(r))
3737
3738  @test_util.run_v1_only("b/120545219")
3739  def testNestedWhileGrad_ParallelIterations(self):
3740    # Make sure the stack pushes and pops of an inner loop are executed in
3741    # the sequential order of the iterations of its outer loop.
3742    with self.cached_session() as sess:
3743
3744      def inner_loop(t):
3745        fn = lambda n: n + math_ops.square(var)
3746        return map_fn.map_fn(fn=fn, elems=t, parallel_iterations=10)
3747
3748      def outer_loop(inp):
3749        return map_fn.map_fn(
3750            fn=inner_loop, elems=inp, parallel_iterations=10)
3751
3752      var = variables.Variable(constant_op.constant(3.0))
3753      inp = constant_op.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
3754      res = outer_loop(inp)
3755      optimizer = adam.AdamOptimizer(learning_rate=0.001)
3756      train_op = optimizer.minimize(math_ops.reduce_mean(math_ops.square(res)))
3757      self.evaluate(variables.global_variables_initializer())
3758      self.evaluate(train_op)
3759      self.assertAllClose(2.999, var.read_value())
3760
3761  def _testWhileCondGrad_Simple(self, use_gpu):
3762    with self.cached_session(use_gpu=use_gpu):
3763      v = ops.convert_to_tensor(2.0, name="v")
3764      n = ops.convert_to_tensor(100.0, name="n")
3765      one = ops.convert_to_tensor(1.0, name="one")
3766      c = lambda x: math_ops.less(x, n)
3767      # pylint: disable=undefined-variable
3768      # for OSS build
3769      b = lambda x: control_flow_ops.cond(constant_op.constant(True),
3770                                          lambda: math_ops.square(x),
3771                                          lambda: math_ops.subtract(x, one))
3772      # pylint: enable=undefined-variable
3773      r = control_flow_ops.while_loop(c, b, [v])
3774      r = gradients_impl.gradients(r, v)[0]
3775      self.assertAllClose(1024.0, self.evaluate(r))
3776
3777  @test_util.run_deprecated_v1
3778  def testWhileCondGrad_Simple(self):
3779    self._testWhileCondGrad_Simple(use_gpu=False)
3780    self._testWhileCondGrad_Simple(use_gpu=True)
3781
3782  @test_util.run_deprecated_v1
3783  def testWhileCondGrad_UnknownShape(self):
3784    with self.cached_session() as sess:
3785      v = array_ops.placeholder(dtypes.float32)
3786      n = ops.convert_to_tensor(100.0, name="n")
3787      one = ops.convert_to_tensor(1.0, name="one")
3788      c = lambda x: math_ops.less(x, n)
3789      # pylint: disable=undefined-variable
3790      # for OSS build
3791      b = lambda x: control_flow_ops.cond(constant_op.constant(True),
3792                                          lambda: math_ops.square(x),
3793                                          lambda: math_ops.subtract(x, one))
3794      # pylint: enable=undefined-variable
3795      r = control_flow_ops.while_loop(c, b, [v])
3796      r = gradients_impl.gradients(r, v)[0]
3797      r = sess.run(r, feed_dict={v: 2.0})
3798      self.assertAllClose(1024.0, r)
3799
3800  @test_util.run_deprecated_v1
3801  def testWhileGrad_Concat(self):
3802    with self.cached_session() as sess:
3803      x = variable_scope.get_variable("x", initializer=[[1., 2.]])
3804      i0 = constant_op.constant(0)
3805      h0 = array_ops.zeros([0, 2])
3806
3807      def condition(i, _):
3808        return i < 2
3809
3810      def body(i, h):
3811        return i + 1, array_ops.concat([h, x], 0)
3812
3813      _, h = control_flow_ops.while_loop(
3814          condition, body, [i0, h0],
3815          [i0.get_shape(), tensor_shape.TensorShape([None, 2])])
3816      s = math_ops.reduce_sum(h)
3817
3818      optimizer = gradient_descent.GradientDescentOptimizer(0.01)
3819      op = optimizer.minimize(s)
3820
3821      self.evaluate(variables.global_variables_initializer())
3822      self.evaluate(op)
3823      self.assertAllClose([[0.98000002, 1.98000002]], self.evaluate(x))
3824
3825  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
3826  @test_util.run_v1_only("b/120545219")
3827  def testWhileWithRefsWithGradients_1(self):
3828    with self.cached_session() as sess:
3829      x = variables.VariableV1(0.)._ref()  # pylint: disable=protected-access
3830      i = constant_op.constant(0)
3831      c = lambda i, x: math_ops.less(i, 10)
3832
3833      self.assertEqual(x.dtype, dtypes.float32_ref)
3834
3835      def body(i, x):
3836        self.assertEqual(x.dtype, dtypes.float32_ref)
3837        return [i + 1, gen_array_ops.ref_identity(x)]
3838
3839      r = control_flow_ops.while_loop(c, body, [i, x], parallel_iterations=5)
3840
3841      grad_ys = [variables.VariableV1(73)._ref()]  # pylint: disable=protected-access
3842      grad = gradients_impl.gradients([r[1]], [x], grad_ys=grad_ys)
3843
3844      self.evaluate(variables.global_variables_initializer())
3845
3846      self.assertEqual(r[0].dtype, dtypes.int32)
3847      self.assertEqual(r[1].dtype, dtypes.float32_ref)
3848
3849      value_i, value_x, value_x_grad = sess.run(r + grad)
3850
3851    self.assertEqual(10, value_i)
3852    self.assertEqual(0, value_x)
3853    self.assertEqual(73, value_x_grad)
3854
3855  @test_util.deprecated_graph_mode_only
3856  def testWhileGrad_IndexedSlices(self):
3857    with self.cached_session():
3858      values = constant_op.constant([2.0, 4.0], name="values")
3859      indices = constant_op.constant([0, 3], name="indices")
3860      shape = constant_op.constant([10], name="dense_shape")
3861      i = constant_op.constant(0)
3862      x = indexed_slices.IndexedSlices(values, indices, dense_shape=shape)
3863
3864      def c(i, _):
3865        return i < 10
3866
3867      def b(i, x):
3868        return [
3869            i + 1,
3870            indexed_slices.IndexedSlices(x.values * 2.0, x.indices,
3871                                         x.dense_shape)
3872        ]
3873
3874      _, r = control_flow_ops.while_loop(c, b, [i, x])
3875      r = gradients_impl.gradients(r.values, values)[0]
3876      self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r))
3877
3878  @test_util.deprecated_graph_mode_only
3879  def testWhileGrad_SparseTensor(self):
3880    with self.cached_session():
3881      values = constant_op.constant([2.0, 4.0], name="values")
3882      indices = constant_op.constant(
3883          [[0], [3]], dtype=dtypes.int64, name="indices")
3884      shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
3885      i = constant_op.constant(0)
3886      x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
3887
3888      def c(i, _):
3889        return i < 10
3890
3891      def b(i, x):
3892        return [
3893            i + 1,
3894            sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape)
3895        ]
3896
3897      _, r = control_flow_ops.while_loop(c, b, [i, x])
3898      r = gradients_impl.gradients(r.values, values)[0]
3899      self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r))
3900
3901  @test_util.deprecated_graph_mode_only
3902  def testCallGradInLoop(self):
3903    with self.cached_session() as sess:
3904      i0 = constant_op.constant(0)
3905      params = constant_op.constant(5.0)
3906      params_1 = math_ops.square(params)
3907
3908      def c(i, _):
3909        return i < 10
3910
3911      def b(i, x):
3912        data = constant_op.constant([1.0, 2.0, 3.0])
3913        data = math_ops.multiply(data, params_1)
3914        x1 = x + gradients_impl.gradients(data, params)[0]
3915        return i + 1, x1
3916
3917      output_grad = control_flow_ops.while_loop(
3918          c, b, [i0, constant_op.constant(0.0)])
3919      self.assertAllClose(600.0, self.evaluate(output_grad)[1])
3920
3921  @test_util.run_deprecated_v1
3922  def testWhileAndTensorArray(self):
3923    with self.cached_session() as sess:
3924      param = constant_op.constant(2.0)
3925      n0 = constant_op.constant(0)
3926      y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
3927
3928      def c(i, _):
3929        return i < 10
3930
3931      def b(i, y):
3932        return [
3933            i + 1,
3934            map_fn.map_fn(lambda x: math_ops.multiply(x, param), y)
3935        ]
3936
3937      r = control_flow_ops.while_loop(c, b, [n0, y0], parallel_iterations=1)
3938      r = gradients_impl.gradients(r, param)[0]
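      # Ten iterations scale every element of y by param, so the final y is
      # y0 * param**10 and d(sum(y))/d(param) = sum(y0) * 10 * param**9
      # = 21 * 10 * 512 = 107520.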
3939      self.assertAllClose(107520.0, self.evaluate(r))
3940
3941  @test_util.run_deprecated_v1
3942  def testNestedWhileAndTensorArray(self):
3943    n = constant_op.constant(3.0)
3944
3945    def Body(row, ta):
3946
3947      def InnerBody(row, col, ta):
3948        # Note: row and col are 1-based.
3949        ta = ta.write(
3950            math_ops.cast(n * (row - 1.) + col - 1., dtypes.int32), row * col)
3951        return row, col + 1., ta
3952
3953      ta = control_flow_ops.while_loop(
3954          lambda _, col, _1: col <= n,
3955          InnerBody, [row, constant_op.constant(1.), ta],
3956          return_same_structure=False)[2]
3957      return row + 1., ta
3958
3959    ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=9)
3960    ta = control_flow_ops.while_loop(
3961        lambda row, _: row <= n,
3962        Body, [constant_op.constant(1.), ta],
3963        return_same_structure=False)[1]
3964
3965    output = array_ops.reshape(ta.stack(), [3, 3])
3966    self.assertAllEqual(
3967        self.evaluate(output), [[1., 2., 3.], [2., 4., 6.], [3., 6., 9.]])
3968    # TODO(b/117675481): This does not work with current TA. Enable with new TA.
3969    # grad = gradients_impl.gradients(output, [n])
3970    # self.assertEqual(self.evaluate(grad), 3.5)
3971
3972  @test_util.run_deprecated_v1
3973  def testWhileGrad_StopGrad(self):
3974    with self.cached_session():
3975      x = constant_op.constant(3.0, name="x")
3976      y = constant_op.constant(2.0, name="y")
3977
3978      c = lambda x, y: math_ops.less(x, 100.0)
3979
3980      def b(x, y):
3981        y1 = math_ops.square(y)
3982        x1 = math_ops.add(math_ops.square(x), y1)
3983        return x1, y1
3984
3985      rx, ry = control_flow_ops.while_loop(c, b, [x, y])
3986
3987      r = gradients_impl.gradients(rx, y)[0]
3988      self.assertEqual(136.0, self.evaluate(r))
3989      r = gradients_impl.gradients(ry, y)[0]
3990      self.assertEqual(32.0, self.evaluate(r))
3991
3992      r = gradients_impl.gradients(array_ops.stop_gradient(rx), y)[0]
3993      self.assertEqual(r, None)
3994      r = gradients_impl.gradients(array_ops.stop_gradient(ry), y)[0]
3995      self.assertEqual(r, None)
3996
3997      r = gradients_impl.gradients(
3998          array_ops.stop_gradient(math_ops.square(rx)), y)[0]
3999      self.assertEqual(r, None)
4000      r = gradients_impl.gradients(
4001          array_ops.stop_gradient(math_ops.add(rx, ry)), x)[0]
4002      self.assertEqual(r, None)
4003      r = gradients_impl.gradients(
4004          array_ops.stop_gradient(math_ops.add(rx, ry)), y)[0]
4005      self.assertEqual(r, None)
4006
4007      r = gradients_impl.gradients(math_ops.add(rx, ry), y)[0]
4008      self.assertEqual(168.0, self.evaluate(r))
4009      r = gradients_impl.gradients(
4010          math_ops.add(rx, array_ops.stop_gradient(ry)), y)[0]
4011      self.assertEqual(136.0, self.evaluate(r))
4012      r = gradients_impl.gradients(
4013          math_ops.add(array_ops.stop_gradient(rx), ry), y)[0]
4014      self.assertEqual(32.0, self.evaluate(r))
4015
4016  @test_util.run_deprecated_v1
4017  def testWhileGrad_StopGradInside(self):
4018    with self.cached_session():
4019      x = constant_op.constant(3.0, name="x")
4020      y = constant_op.constant(2.0, name="y")
4021
4022      c = lambda x, y: math_ops.less(x, 100.0)
4023
4024      def b(x, y):
4025        y1 = array_ops.stop_gradient(math_ops.square(y))
4026        x1 = math_ops.add(math_ops.square(x), y1)
4027        return x1, y1
4028
4029      rx, _ = control_flow_ops.while_loop(c, b, [x, y])
4030
4031      r = gradients_impl.gradients(rx, y)[0]
4032      self.assertAllClose(0.0, self.evaluate(r))
4033      r = gradients_impl.gradients(rx, x)[0]
4034      self.assertAllClose(156.0, self.evaluate(r))
4035
4036  @test_util.run_deprecated_v1
4037  def testWhileGrad_StopGradInsideNoShape(self):
4038    with self.cached_session() as sess:
4039      x = array_ops.placeholder(dtypes.float32)
4040      y = array_ops.placeholder(dtypes.float32)
4041
4042      c = lambda x, y: math_ops.less(math_ops.reduce_sum(x), 100.0)
4043
4044      def b(x, y):
4045        y1 = array_ops.stop_gradient(math_ops.square(y, name="stopped"))
4046        x1 = math_ops.add(math_ops.square(x), y1)
4047        return x1, y1
4048
4049      rx, _ = control_flow_ops.while_loop(c, b, [x, y])
4050
4051      grad_y = gradients_impl.gradients(rx, y)[0]
4052      grad_x = gradients_impl.gradients(rx, x)[0]
4053      feed_dict = {x: [3.0, 4.0], y: [2.0, 3.0]}
4054      self.assertAllClose([0.0, 0.0], sess.run(grad_y, feed_dict=feed_dict))
4055      self.assertAllClose([156.0, 400.0], sess.run(grad_x, feed_dict=feed_dict))
4056      name = "gradients/while/stopped_grad"
4057      all_ops = x.graph.get_operations()
4058      self.assertFalse(any(name in op.name for op in all_ops))
4059
4060  @test_util.run_deprecated_v1
4061  def testWhileGradGradFail(self):
4062    theta = variables.Variable(initial_value=1.)
4063
4064    def fn(prev, x):
4065      return prev + x * theta
4066
4067    result = functional_ops.scan(fn, np.array([1., 2., 3.], dtype=np.float32))
4068    grad_theta = gradients_impl.gradients(result, theta)
4069    if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
4070      with self.assertRaisesRegex(TypeError, "Second-order gradient"):
4071        gradients_impl.gradients(grad_theta, theta)
4072    grad_theta_stopped = array_ops.stop_gradient(grad_theta)
4073    gradients_impl.gradients(grad_theta_stopped, theta)
4074
4075  @test_util.run_deprecated_v1
4076  def testStopGradOnWhileGrad(self):
4077    with self.cached_session():
4078      x = constant_op.constant(2.0, name="x")
4079      y = constant_op.constant(2.0, name="y")
4080
4081      c = lambda x: math_ops.less(x, 100.0)
4082      b = lambda x: math_ops.multiply(x, y)
4083      rx = control_flow_ops.while_loop(c, b, [x])
4084
4085      rg = gradients_impl.gradients(rx, y)[0]
4086      rg = array_ops.stop_gradient(rg)
4087      r = math_ops.add(math_ops.square(y), rx)
4088      r = math_ops.add(r, rg)
4089      r = gradients_impl.gradients(r, y)[0]
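      # rx = x * y**6 = 128 (six doublings) and rg = d(rx)/dy = 6 * x * y**5
      # = 384 is stopped, so dr/dy = 2 * y + d(rx)/dy = 4 + 384 = 388.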
4090      self.assertEqual(388.0, self.evaluate(r))
4091
4092  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
4093  @test_util.run_deprecated_v1
4094  def testWhileGradientWithNontrainablePath1(self):
4095    q = variables.Variable([7., 8.])
4096
4097    def cond(_, y):
4098      del y
4099      return False
4100
4101    def body(x, _):
4102      return x, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q)
4103
4104    _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
4105    dy_dq, = gradients_impl.gradients(y, q)
4106    self.assertIsNotNone(dy_dq)
4107    with self.cached_session() as sess:
4108      self.evaluate(q.initializer)
4109      self.assertAllClose([0., 0.], self.evaluate(dy_dq))
4110
4111  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
4112  @test_util.run_v1_only("b/120545219")
4113  def testWhileGradientWithNontrainablePath2(self):
4114    q = variables.Variable([7., 8.])
4115
4116    def cond(_, y):
4117      return math_ops.equal(y, 0.)
4118
4119    def body(x, _):
4120      zero = constant_op.constant(0, dtype=dtypes.int64)
4121      return zero, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q)
4122
4123    _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
4124    dy_dq, = gradients_impl.gradients(y, q)
4125    self.assertIsNotNone(dy_dq)
4126    with self.cached_session() as sess:
4127      self.evaluate(q.initializer)
4128      self.assertAllClose([1., 1.], self.evaluate(dy_dq))
4129
4130  @test_util.run_v1_only("b/120545219")
4131  def testIssue16504(self):
4132    c = constant_op.constant(np.arange(100), dtype=dtypes.float32)
4133    w = variables.Variable(
4134        initial_value=np.ones(100), dtype=dtypes.float32) / 100
4135    k = variables.Variable(0, dtype=dtypes.int32)
4136    chg_w = constant_op.constant(np.inf, dtype=dtypes.float32)
4137
4138    def cond(k, _, chg_w):
4139      return math_ops.logical_and(k < 10, chg_w > 1e-3)
4140
4141    def body(k, w, chg_w):
4142      grad, = gradients_impl.gradients(-math_ops.reduce_sum(w * c), w)
4143      w_n = w * math_ops.exp(-0.1 * grad)
4144      w_n /= math_ops.reduce_sum(w_n)
4145      chg_w = (
4146          math_ops.reduce_sum(math_ops.abs(w_n - w)) / math_ops.reduce_sum(
4147              math_ops.abs(w)))
4148      return k + 1, w_n, chg_w
4149
4150    _, w, _ = control_flow_ops.while_loop(cond, body, [k, w, chg_w])
4151    grad, = gradients_impl.gradients(w, c)
4152    self.assertIsNotNone(grad)
4153
4154  @test_util.run_v1_only("b/120545219")
4155  def testStopGradMultiFlows(self):
    with self.cached_session():

      def body(i, y, r):
        x = variable_scope.get_variable(
            "x",
            shape=(),
            dtype=dtypes.float32,
            initializer=init_ops.ones_initializer())
        y *= x
        return [i + 1, y, r + math_ops.reduce_sum(y)]

      i0 = constant_op.constant(0)
      y0 = array_ops.ones(5)
      r0 = constant_op.constant(0.0)
      cond = lambda i, y, r: i < 1
      _, _, r = control_flow_ops.while_loop(
          cond, body, [i0, y0, r0], back_prop=True)

      vars_ = variables.global_variables()
      grads = linalg_ops.norm(gradients_impl.gradients(r, vars_)[0])
      z = math_ops.add(r, array_ops.stop_gradient(math_ops.reduce_sum(grads)))
      result = gradients_impl.gradients(z, vars_)[0]
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(5.0, self.evaluate(result))

  @test_util.run_v1_only("b/120545219")
  def testOneValueCond(self):

    with self.cached_session():
      c = array_ops.placeholder(dtypes.int32, shape=[])
      one = ops.convert_to_tensor(1, name="one")
      two = ops.convert_to_tensor(2, name="two")
      p = math_ops.greater_equal(c, 1)
      i = control_flow_ops.cond(p, lambda: one, lambda: two)
      self.assertTrue(isinstance(i, ops.Tensor))

      # True case: c = 2 is >= 1
      self.assertEqual([1], i.eval(feed_dict={c: 2}))

      # False case: c = 0 is not >= 1
      self.assertEqual([2], i.eval(feed_dict={c: 0}))

  @test_util.run_deprecated_v1
  def testExampleCond(self):

    with self.cached_session():
      x = ops.convert_to_tensor([-2.0, 2.0], name="x")
      d = array_ops.placeholder(dtypes.int32, shape=[])

      def l2():
        return math_ops.sqrt(math_ops.reduce_sum(math_ops.square(x)))

      def l1():
        return math_ops.reduce_sum(math_ops.abs(x))

      i = control_flow_ops.cond(math_ops.equal(d, 2), l2, l1)
      self.assertAllClose(4.0, i.eval(feed_dict={d: 1}))
      self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))

  @test_util.run_v1_only("b/120545219")
  def testCase(self):
    with self.cached_session():
      x = constant_op.constant(1)
      y = constant_op.constant(2)
      z = constant_op.constant(3)
      f1 = lambda: constant_op.constant(17)
      f2 = lambda: constant_op.constant(23)
      f3 = lambda: constant_op.constant(-1)

      r1 = control_flow_ops.case(
          {
              x < y: f1,
              x > z: f2
          }, default=f3, exclusive=True)
      self.assertAllEqual(r1, 17)

      r2 = control_flow_ops.case([(y > z, f1), (y > x, f2)], default=f3)
      self.assertAllEqual(r2, 23)

      # Duplicate predicates are allowed; the first true one is selected.
      r3 = control_flow_ops.case([(x < y, f1), (x < y, f2)], default=f3)
      self.assertAllEqual(r3, 17)

      # Duplicate true predicates cause an error if exclusive=True.
      r4 = control_flow_ops.case(
          [(x < y, f1), (x < y, f2)], default=f3, exclusive=True)
      with self.assertRaisesOpError("Input error:"):
        self.evaluate(r4)

      # Check that the default is called if none of the others are
      r5 = control_flow_ops.case({x > y: f1}, default=f3)
      self.assertAllEqual(r5, -1)

      ran_once = [False, False, False]

      def break_run_twice(ix):

        def _break():
          ran_once[ix] = True
          return constant_op.constant(ix)

        return _break

      # Should not fail - each conditional gets called exactly once
      # except default.  Default gets called twice: once to create an
      # empty output and once for the actual cond switch.
      r6 = control_flow_ops.case(
          [(x < y, break_run_twice(0)), (x > y, break_run_twice(1))],
          default=lambda: constant_op.constant(2))

      self.assertAllEqual(r6, 0)

  @test_util.run_v1_only("b/120545219")
  def testCaseSideEffects(self):
    with self.cached_session() as sess:
      v0 = variables.Variable(-1)
      v1 = variables.Variable(-1)
      v2 = variables.Variable(-1)

      a = lambda: control_flow_ops.with_dependencies(
          [state_ops.assign(v0, 0)], 0)
      b = lambda: control_flow_ops.with_dependencies(
          [state_ops.assign(v1, 1)], 1)
      c = lambda: control_flow_ops.with_dependencies(
          [state_ops.assign(v2, 2)], 2)

      x = constant_op.constant(1)
      y = constant_op.constant(2)

      r0 = control_flow_ops.case(
          ((x < y, a), (x > y, b)), default=c, exclusive=True)
      r1 = control_flow_ops.case(
          ((x > y, a), (x < y, b)), default=c, exclusive=True)
      r2 = control_flow_ops.case(
          ((x > y, a), (x > y, b)), default=c, exclusive=True)

      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3)
      self.assertEqual(2, self.evaluate(r2))
      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1, -1, 2])

      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3)
      self.assertEqual(1, self.evaluate(r1))
      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1, 1, -1])

      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3)
      self.assertEqual(0, self.evaluate(r0))
      self.assertAllEqual(self.evaluate([v0, v1, v2]), [0, -1, -1])

  @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
  @test_util.run_v1_only("b/120545219")
  def testOneOpCond(self):
    with self.cached_session():
      v = variables.Variable(0)
      c = ops.convert_to_tensor(0)
      one = ops.convert_to_tensor(1)
      two = ops.convert_to_tensor(2)
      p = math_ops.greater_equal(c, 1)

      def a():
        return state_ops.assign(v, one)

      def b():
        return state_ops.assign(v, two)

      i = control_flow_ops.cond(p, a, b)
      self.assertTrue(isinstance(i, ops.Tensor))
      self.evaluate(variables.global_variables_initializer())

      self.assertEqual(0, self.evaluate(v))

      # True case: c = 2 is >= 1, v is set to 1.
      self.assertEqual(1, i.eval(feed_dict={c.name: 2}))
      self.assertEqual(1, self.evaluate(v))

      # False case: c = 0 is not >= 1, v is set to 2.
      self.assertEqual(2, i.eval(feed_dict={c.name: 0}))
      self.assertEqual(2, self.evaluate(v))

  @test_util.run_v1_only("b/120545219")
  def testWithOpsDependencies(self):
    with self.cached_session() as sess:
      v = variables.VariableV1(0.0)
      c = constant_op.constant(10)

      # Fetching v directly will result in an uninitialized error
      with self.assertRaisesOpError("Attempting to use uninitialized value"):
        self.evaluate([c, v])

      # Use a control dependency to ensure v's initializer is run
      # while asking for c.
      real_v = control_flow_ops.with_dependencies(
          name="real_tensor",
          output_tensor=v._ref(),  # pylint: disable=protected-access
          dependencies=[v.initializer])
      c_val, real_v_val = self.evaluate([c, real_v])

    # Ensure that the fetched value of 'c' is unchanged
    self.assertAllEqual(10, c_val)

    # Ensure that 'v' is initialized
    self.assertAllClose(0.0, real_v_val)

  @test_util.run_v1_only("b/120545219")
  def testWithTensorDependencies(self):
    with self.cached_session():
      v = variables.VariableV1(0.0)
      c1 = constant_op.constant(10)
      c2 = constant_op.constant(20)

      # c1_with_init_v depends on the init op for v
      c1_with_init_v = control_flow_ops.with_dependencies(
          name="c1_with_init_v", output_tensor=c1, dependencies=[v.initializer])
      # c2_with_c1 depends on the value of c1_with_init_v
      c2_with_c1_dep = control_flow_ops.with_dependencies(
          name="c2_with_c1_dep",
          output_tensor=c2,
          dependencies=[c1_with_init_v])

      # Fetching v directly will result in an uninitialized error
      with self.assertRaisesOpError("Attempting to use uninitialized value"):
        self.evaluate(v)

      # Get the value of 'c2_with_c1_dep', which should cause 'v'
      # to be initialized.
      self.assertAllEqual(20, self.evaluate(c2_with_c1_dep))

      # Ensure that 'v' is initialized
      self.assertAllClose(0.0, self.evaluate(v))

  @test_util.run_v1_only("b/120545219")
  def testWithIndexedSlicesDependencies(self):
    with self.cached_session():
      v = variables.VariableV1(
          np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(np.float32))
      v_at_1 = indexed_slices.IndexedSlices(v, constant_op.constant([1]))
      gather_v_at_1 = array_ops.gather(v_at_1.values, v_at_1.indices)
      v_at_1_after_init = control_flow_ops.with_dependencies([v.initializer],
                                                             v_at_1)
      gather_v_at_1_after_init = array_ops.gather(v_at_1_after_init.values,
                                                  v_at_1_after_init.indices)

      # Fetching gather_v_at_1 will result in an uninitialized error
      with self.assertRaisesOpError("Attempting to use uninitialized value"):
        self.evaluate(gather_v_at_1)

      # Getting gather_v_at_1_after_init will work, and initialize v.
      self.assertAllEqual([[10.0, 11.0]],
                          self.evaluate(gather_v_at_1_after_init))

      # Double check that 'v' is initialized
      self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]],
                          self.evaluate(v))

  def testDependenciesDevice(self):
    with ops.Graph().as_default():
      # device set on tensor => same device on dep.
      with ops.device("/job:ps"):
        vd = variables.VariableV1([0.0])
      with_vd_dep = control_flow_ops.with_dependencies([vd.initializer], vd)
      self.assertTrue("/job:ps" in with_vd_dep.device)

      # No device set on tensor => no device on dep.
      vnod = variables.VariableV1([0.0])
      with_vnod_dep = control_flow_ops.with_dependencies([vnod.initializer],
                                                         vnod)
      self.assertDeviceEqual(None, with_vnod_dep.device)

      # device set on tensor, default device on graph => default device on dep.
      vdef = variables.VariableV1([0.0], name="vdef")
      with ops.device("/job:worker/device:GPU:1"):
        with_vdef_dep = control_flow_ops.with_dependencies([vdef.initializer],
                                                           vdef)
        # The device is empty, but the colocation constraint is set.
        self.assertDeviceEqual("", with_vdef_dep.device)
        self.assertEqual([b"loc:@vdef"], with_vdef_dep.op.colocation_groups())

  @test_util.run_v1_only("b/120545219")
  def testGroup(self):
    with self.cached_session() as sess:
      v1 = variables.VariableV1([0.0])
      v2 = variables.VariableV1([1.0])

      # Group the initializers of v1 and v2 and run the group.
      init = control_flow_ops.group(v1.initializer, v2.initializer)
      # Fetching v1 directly will result in an uninitialized error
      with self.assertRaisesOpError("Attempting to use uninitialized value"):
        self.evaluate(v1)

      # Runs "init" before fetching v1 and v2.
      init.run()
      v1_val, v2_val = self.evaluate([v1, v2])

    # Ensure that v1 and v2 are initialized
    self.assertAllClose([0.0], v1_val)
    self.assertAllClose([1.0], v2_val)

  @test_util.run_v1_only("b/120545219")
  def testGroupEmpty(self):
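    # Grouping no inputs should still produce a NoOp with no control inputs.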
    op = control_flow_ops.group()
    self.assertEqual(op.type, "NoOp")
    self.assertEqual(op.control_inputs, [])

  @test_util.run_deprecated_v1
  def testMergeShapes(self):
    # All inputs unknown.
    p1 = array_ops.placeholder(dtypes.float32)
    p2 = array_ops.placeholder(dtypes.float32)
    p3 = array_ops.placeholder(dtypes.float32)
    m, index = control_flow_ops.merge([p1, p2, p3])
    self.assertIs(None, m.get_shape().ndims)
    self.assertEqual([], index.get_shape())

    # All inputs known with different ranks.
    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
    p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2, 3])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertIs(None, m.get_shape().ndims)
    self.assertEqual([], index.get_shape())

    # All inputs known with some dimensions different.
    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
    p2 = array_ops.placeholder(dtypes.float32, shape=[2, 1])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([None, None], m.get_shape().as_list())
    self.assertEqual([], index.get_shape())

    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
    p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([None, 2], m.get_shape().as_list())
    self.assertEqual([], index.get_shape())

    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
    p2 = array_ops.placeholder(dtypes.float32, shape=[2, 2])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([None, 2], m.get_shape().as_list())
    self.assertEqual([], index.get_shape())

    # All inputs known with same dimensions.
    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
    p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([1, 2], m.get_shape().as_list())
    self.assertEqual([], index.get_shape())

    p1 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
    p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([None, 2], m.get_shape().as_list())
    self.assertEqual([], index.get_shape())

    p1 = array_ops.placeholder(dtypes.float32, shape=[None, None])
    p2 = array_ops.placeholder(dtypes.float32, shape=[None, None])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([None, None], m.get_shape().as_list())
    self.assertEqual([], index.get_shape())

  @test_util.run_v1_only("b/120545219")
  def testRefSelect(self):
    index = array_ops.placeholder(dtypes.int32)

    # All inputs unknown.
    p1 = array_ops.placeholder(dtypes.float32)
    p2 = array_ops.placeholder(dtypes.float32)
    p3 = array_ops.placeholder(dtypes.float32)
    v1 = variables.VariableV1(p1, validate_shape=False)
    v2 = variables.VariableV1(p2, validate_shape=False)
    v3 = variables.VariableV1(p3, validate_shape=False)
    self.assertIs(None, v1.get_shape().ndims)
    s = control_flow_ops.ref_select(index, [v1, v2, v3])
    self.assertIs(None, s.get_shape().ndims)

    # All inputs known but different.
    v1 = variables.VariableV1([[1, 2]])
    v2 = variables.VariableV1([[2], [1]])
    s = control_flow_ops.ref_select(index, [v1, v2])
    self.assertIs(None, s.get_shape().ndims)

    # All inputs known and same.
    v1 = variables.VariableV1([[1, 2]])
    v2 = variables.VariableV1([[1, 2]])
    s = control_flow_ops.ref_select(index, [v1, v2])
    self.assertEqual([1, 2], s.get_shape())

    # Possibly the same but not guaranteed.
    v1 = variables.VariableV1([[1., 2.]])
    p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
    v2 = variables.VariableV1(p2, validate_shape=False)
    s = control_flow_ops.ref_select(index, [v1, v2])
    self.assertEqual(None, s.get_shape())

  @test_util.run_deprecated_v1
  def testRunLoopTensor(self):
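    # Tensors created inside the loop body but never returned from it cannot
    # be fetched after the loop has been built.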
    with self.cached_session() as sess:
      tensor_list = []

      def condition(t):
        return t < constant_op.constant(5)

      def body(_):
        tensor_list.append(constant_op.constant(5))
        return constant_op.constant(10)

      result = control_flow_ops.while_loop(condition, body,
                                           [constant_op.constant(4)])
      self.assertEqual(10, self.evaluate(result))

      # Ensure that we cannot run a tensor that escapes the loop body
      # accidentally.
      with self.assertRaises(ValueError):
        sess.run(tensor_list[0])

  @test_util.run_v1_only("b/120545219")
  def testWhilePyFuncBasic(self):

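    # Each iteration squares the running value via py_func, so four iterations
    # starting from 2.0 yield 2.0 ** 16 = 65536.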
    def func(x):
      return np.square(x)

    with self.cached_session():
      r = control_flow_ops.while_loop(
          lambda i, v: i < 4,
          lambda i, v: [i + 1,
                        script_ops.py_func(func, [v], [dtypes.float32])[0]],
          [constant_op.constant(0), constant_op.constant(2.0, dtypes.float32)],
          [tensor_shape.unknown_shape(), tensor_shape.unknown_shape()])
      self.assertEqual(self.evaluate(r[1]), 65536.0)

  @test_util.run_v1_only("b/120545219")
  def testWhileFuncBasic(self):

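    # func computes x ** 4, so two iterations give 2.0 ** 16 = 65536, and the
    # gradient d(x ** 16)/dx at x = 2 is 16 * 2 ** 15 = 524288.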
    @function.Defun(dtypes.float32)
    def func(x):
      return math_ops.square(math_ops.square(x))

    with self.cached_session():
      x = constant_op.constant(2.0, dtypes.float32)
      r = control_flow_ops.while_loop(
          lambda i, v: i < 2, lambda i, v: [i + 1, func(v)],
          [constant_op.constant(0), x],
          [tensor_shape.unknown_shape(),
           tensor_shape.unknown_shape()])
      grad = gradients_impl.gradients(r, x)[0]
      self.assertEqual(self.evaluate(r[1]), 65536.0)
      self.assertEqual(self.evaluate(grad), 524288.0)
      # while_v2 does not have stacks.
      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
        self.assertEqual(
            len([op for op in x.graph.get_operations() if op.type == "StackV2"
                ]), 1)

  @test_util.run_v1_only("b/120545219")
  def testQIntSwitchMerge(self):
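    # switch/merge should handle quantized integer (qint8) inputs.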
    with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
      constant_qint = constant_op.constant(np.array([42]), dtypes.qint8)
      cond = constant_op.constant(True, dtypes.bool)
      v_f, v_t = control_flow_ops.switch(constant_qint, cond)
      result = control_flow_ops.merge([v_f, v_t])
      self.evaluate(result)

  @test_util.run_v1_only("b/120545219")
  def testQIntRefSwitchMerge(self):
    with self.cached_session(use_gpu=test.is_gpu_available()) as sess:
      var_qint = gen_state_ops.variable(
          shape=[1], dtype=dtypes.qint8, name="v", container="", shared_name="")
      assign_op = state_ops.assign(
          var_qint, constant_op.constant(np.array([42]), dtypes.qint8))
      self.evaluate(assign_op)

      cond = constant_op.constant(True, dtypes.bool)
      v_f, v_t = control_flow_ops.ref_switch(var_qint, cond)
      result = control_flow_ops.ref_merge([v_f, v_t])
      self.evaluate(result)

  @test_util.run_v1_only("b/120545219")
  def testUInt64SwitchMerge(self):
    with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
      constant_uint64 = constant_op.constant(np.array([42]), dtypes.uint64)
      cond = constant_op.constant(True, dtypes.bool)
      v_f, v_t = control_flow_ops.switch(constant_uint64, cond)
      result = control_flow_ops.merge([v_f, v_t])
      self.evaluate(result)

  def testSwitchEagerMode(self):
    if not context.executing_eagerly():
      return
    input_data = [1, 2, 3, 4]
    vf, vt = control_flow_ops.switch(input_data, False)
    self.assertAllEqual(vf, input_data)
    self.assertAllEqual(vt, [])

  @test_util.run_deprecated_v1
  def testQIntArgAndRet(self):

    @function.Defun(dtypes.qint8)
    def func(x):
      return x

    with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
      qint = constant_op.constant(np.array([42]), dtypes.qint8)
      result = func(qint)
      self.evaluate(result)

  def testSparseIdentity(self):
    st1 = sparse_tensor.SparseTensor([[0, 5]], ['x'], [10, 10])
    st2 = control_flow_ops._Identity(st1)
    self.assertAllEqual(st1.indices, st2.indices)
    self.assertAllEqual(st1.values, st2.values)
    self.assertAllEqual(st1.dense_shape, st2.dense_shape)

  def testSparseEnterExit(self):
    st1 = sparse_tensor.SparseTensor([[0, 5]], ['x'], [10, 10])
    st2 = control_flow_ops._Enter(st1, "foo_1")
    st3 = control_flow_ops.exit(st2)
    self.assertAllEqual(st1.indices, st3.indices)
    self.assertAllEqual(st1.values, st3.values)
    self.assertAllEqual(st1.dense_shape, st3.dense_shape)

  def _buildWhileWithShapeInvariants(self, shape_invariants):
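    """Builds a while loop constrained by the given `shape_invariants`."""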
    r = constant_op.constant([1, 2])

    def cond(_):
      return False

    def body(_):
      return constant_op.constant([1])

    return control_flow_ops.while_loop(
        cond, body, [r], shape_invariants=shape_invariants)

  def testWhileOutputShapeWithShapeInvariantsUnknownRank(self):
    @def_function.function
    def runTest():
      while_output = self._buildWhileWithShapeInvariants(
          [tensor_shape.TensorShape(None)])
      self.assertIsNone(while_output.shape.rank)
    runTest()

  def testWhileOutputShapeWithShapeInvariantsPartialShape(self):
    @def_function.function
    def runTest():
      while_output = self._buildWhileWithShapeInvariants(
          [tensor_shape.TensorShape([None])])
      self.assertAllEqual(while_output.shape.as_list(), [None])
    runTest()

  def testFunctionInWhile(self):

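    # A tf.function can be used directly as the body of a while_loop.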
    @def_function.function
    def body(x):
      return x + 1

    r = control_flow_ops.while_loop(lambda x: x < 5, body, [0])
    self.assertAllEqual(r, 5.)


class ControlFlowContextCheckTest(test.TestCase):

  def _getWhileTensor(self):
    """Creates and returns a tensor from a while context."""
    tensor = []

    def body(i):
      if not tensor:
        tensor.append(constant_op.constant(1))
      return i + tensor[0]

    control_flow_ops.while_loop(lambda i: i < 10, body, [0])
    return tensor[0]

  def _getCondTensor(self):
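    """Creates and returns a tensor from a cond context."""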
    cond_tensor = []

    def true_fn():
      if not cond_tensor:
        cond_tensor.append(constant_op.constant(1))
      return cond_tensor[0]

    control_flow_ops.cond(
        math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0))
    return cond_tensor[0]

  @test_util.run_v1_only("b/120545219")
  def testInvalidContext(self):
    # Accessing a while loop tensor outside of control flow is illegal.
    while_tensor = self._getWhileTensor()
    with self.assertRaisesRegex(
        ValueError,
        "Cannot use 'while/Const_1' as input to 'Add' because 'while/Const_1' "
        "is in a while loop. See info log for more details."):
      math_ops.add(1, while_tensor)

  @test_util.run_v1_only("b/120545219")
  def testInvalidContextInCond(self):
    # Accessing a while loop tensor in cond is illegal.
    while_tensor = self._getWhileTensor()
    with self.assertRaisesRegex(
        ValueError, "Cannot use 'while/Const_1' as input to 'cond/Add' because "
        "'while/Const_1' is in a while loop. See info log for more details."):
      # TODO(skyewm): this passes if we return while_tensor directly instead
      # of using it as input to another op.
      control_flow_ops.cond(
          math_ops.less(1, 2), lambda: math_ops.add(1, while_tensor),
          lambda: constant_op.constant(0))

  @test_util.run_v1_only("b/120545219")
  def testInvalidContextInWhile(self):
    # Accessing a while loop tensor in a different while loop is illegal.
    while_tensor = self._getWhileTensor()
    with self.assertRaisesRegex(
        ValueError,
        "Cannot use 'while/Const_1' as input to 'while_1/Add' because they are "
        "in different while loops. See info log for more details."):
      control_flow_ops.while_loop(lambda i: i < 10,
                                  lambda x: math_ops.add(1, while_tensor), [0])

    with self.assertRaisesRegex(
        ValueError,
        "Cannot use 'while/Const_1' as input to 'while_2/NextIteration' "
        "because they are in different while loops. See info log for more "
        "details."):
      control_flow_ops.while_loop(lambda i: i < 10, lambda i: while_tensor, [0])

  def testValidCondContext(self):
    # Accessing a tensor from a cond context is OK (although dangerous).
    cond_tensor = self._getCondTensor()
    math_ops.add(1, cond_tensor)

  def testValidCondContextBranches(self):
    # Accessing a tensor from a cond context from the other branch's cond
    # context is OK (although dangerous).
    cond_tensor = []

    def branch_fn():
      if not cond_tensor:
        cond_tensor.append(constant_op.constant(1))
      return cond_tensor[0]

    control_flow_ops.cond(math_ops.less(1, 2), branch_fn, branch_fn)

  @test_util.run_v1_only("b/120545219")
  def testValidWhileContext(self):
    # Accessing a tensor in a nested while is OK.
    def body(_):
      c = constant_op.constant(1)
      return control_flow_ops.while_loop(lambda i: i < 3, lambda i: i + c, [0])

    control_flow_ops.while_loop(lambda i: i < 5, body, [0])

  @test_util.run_v1_only("b/120545219")
  def testValidNestedContexts(self):
    # Accessing a tensor from a cond context in a while context, all inside an
    # outer while context, is OK.
    def body(_):
      cond_tensor = self._getCondTensor()
      # Create another cond containing the while loop for good measure
      return control_flow_ops.cond(
          math_ops.less(1, 2),
          lambda: control_flow_ops.while_loop(lambda i: i < 3,
                                              lambda i: i + cond_tensor, [0]),
          lambda: constant_op.constant(0))

    control_flow_ops.while_loop(lambda i: i < 5, body, [0])

  @test_util.run_v1_only("b/120545219")
  def testInvalidNestedContexts(self):
    # Accessing a tensor from a while context in a different while context, all
    # inside a cond context, is illegal.
    def true_fn():
      while_tensor = self._getWhileTensor()
      return control_flow_ops.while_loop(lambda i: i < 3,
                                         lambda i: i + while_tensor, [0])

    with self.assertRaisesRegex(
        ValueError,
        "Cannot use 'cond/while/Const_1' as input to 'cond/while_1/add' because"
        " they are in different while loops. See info log for more details."):
      control_flow_ops.cond(
          math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0))


class TupleTest(test.TestCase):

  @test_util.run_v1_only("b/120545219")
  def testTensors(self):
    for v1_first in [True, False]:
      with self.cached_session():
        v1 = variables.VariableV1([1.0])
        add1 = math_ops.add(
            control_flow_ops.with_dependencies([v1.initializer], v1._ref()),  # pylint: disable=protected-access
            2.0)
        v2 = variables.VariableV1([10.0])
        add2 = math_ops.add(
            control_flow_ops.with_dependencies([v2.initializer], v2._ref()),  # pylint: disable=protected-access
            20.0)
        t1, _, t2 = control_flow_ops.tuple([add1, None, add2])

        # v1 is not initialized.
        with self.assertRaisesOpError("Attempting to use uninitialized value"):
          self.evaluate(v1)

        # v2 is not initialized.
        with self.assertRaisesOpError("Attempting to use uninitialized value"):
          self.evaluate(v2)

        if v1_first:
          # Getting t1 initializes v2.
          self.assertAllClose([3.0], self.evaluate(t1))
          self.assertAllClose([10.0], self.evaluate(v2))
        else:
          # Getting t2 initializes v1.
          self.assertAllClose([30.0], self.evaluate(t2))
          self.assertAllClose([1.0], self.evaluate(v1))

  @test_util.run_v1_only("b/120545219")
  def testIndexedSlices(self):
    for v1_first in [True, False]:
      with self.cached_session():
        v1 = variables.VariableV1(
            np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(
                np.float32))
        v1_at_1 = indexed_slices.IndexedSlices(
            control_flow_ops.with_dependencies([v1.initializer], v1._ref()),  # pylint: disable=protected-access
            constant_op.constant([1]))

        v2 = variables.VariableV1(
            np.array([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]]).astype(
                np.float32))
        v2_at_1 = indexed_slices.IndexedSlices(
            control_flow_ops.with_dependencies([v2.initializer], v2._ref()),  # pylint: disable=protected-access
            constant_op.constant([1]))

        st1, st2 = control_flow_ops.tuple([v1_at_1, v2_at_1])
        g1 = array_ops.gather(st1.values, st1.indices)
        g2 = array_ops.gather(st2.values, st2.indices)

        # v1 is not initialized.
        with self.assertRaisesOpError("Attempting to use uninitialized value"):
          self.evaluate(v1)

        # v2 is not initialized.
        with self.assertRaisesOpError("Attempting to use uninitialized value"):
          self.evaluate(v2)

        if v1_first:
          # Getting g1 initializes v2.
          self.assertAllClose([[10.0, 11.0]], self.evaluate(g1))
          self.assertAllClose([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]],
                              self.evaluate(v2))
        else:
          # Getting g2 initializes v1.
          self.assertAllClose([[10.1, 11.1]], self.evaluate(g2))
          self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]],
                              self.evaluate(v1))

  def testAcceptTensorsAsControlInputs(self):
    with self.cached_session():
      var = variables.VariableV1(0)
      assign = state_ops.assign(var, 1)
      t, = control_flow_ops.tuple(
          [constant_op.constant(0)], control_inputs=[assign])

      # Should trigger the assign.
      self.evaluate(t)

      self.assertEqual(1, self.evaluate(var))


class AssertTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testGuardedAssertDoesNotCopyWhenTrue(self):
    if test_util.is_gpu_available():
      self.skipTest("b/128646478 fails in opensource")

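    # When the condition is true, the guarded Assert should not copy its data
    # tensors from GPU to host, while the raw assert op always does.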
    with self.session() as sess:
      with ops.device(test.gpu_device_name()):
        value = constant_op.constant(1.0)
      with ops.device("/cpu:0"):
        true = constant_op.constant(True)
        guarded_assert = control_flow_ops.Assert(true, [value], name="guarded")
        unguarded_assert = gen_logging_ops._assert(
            true, [value], name="unguarded")
      opts = config_pb2.RunOptions(
          trace_level=config_pb2.RunOptions.FULL_TRACE)
      guarded_metadata = config_pb2.RunMetadata()
      sess.run(guarded_assert, options=opts, run_metadata=guarded_metadata)
      unguarded_metadata = config_pb2.RunMetadata()
      sess.run(unguarded_assert, options=opts, run_metadata=unguarded_metadata)
      guarded_nodestat_names = [
          n.node_name
          for d in guarded_metadata.step_stats.dev_stats
          for n in d.node_stats
      ]
      unguarded_nodestat_names = [
          n.node_name
          for d in unguarded_metadata.step_stats.dev_stats
          for n in d.node_stats
      ]
      guarded_memcpy_nodestat_names = [
          n for n in guarded_nodestat_names if "MEMCPYDtoH" in n
      ]
      unguarded_memcpy_nodestat_names = [
          n for n in unguarded_nodestat_names if "MEMCPYDtoH" in n
      ]
      if "GPU" in [d.device_type for d in device_lib.list_local_devices()]:
        # A copy was performed for the unguarded assert
        self.assertLess(0, len(unguarded_memcpy_nodestat_names),
                        str(unguarded_nodestat_names))
      # No copy was performed for the guarded assert
      self.assertEqual([], guarded_memcpy_nodestat_names)


class WhileOpBenchmark(test.Benchmark):
  """Evaluate the performance of while_loop op."""

  def _getInitVariables(self):
    batch_size = 10
    image_size = 256
    kernel_size = 3
    depth = 16

    init_step = constant_op.constant(-1)
    image = variable_scope.get_variable(
        "image",
        initializer=random_ops.random_normal(
            [batch_size, image_size, image_size, depth],
            dtype=dtypes.float32,
            stddev=1e-1))
    kernel = variable_scope.get_variable(
        "weights",
        initializer=random_ops.truncated_normal(
            [kernel_size, kernel_size, depth, depth],
            dtype=dtypes.float32,
            stddev=1e-1))
    return init_step, image, kernel

  def _runOneBenchmark(self,
                       default_device,
                       num_iters=10,
                       static_unroll=False,
                       steps=10):
    """Evaluate the while loop performance.

    Args:
      default_device: The default device to run all ops except the loop_body.
        loop_body is always run on GPU.
      num_iters: Number of iterations to run.
      static_unroll: If true, run unrolled version; otherwise, run while_loop.
      steps: Total number of repeated steps to run the loop.

    Returns:
      The duration of the run in seconds.
    """

    def loop_body(i, x):
      with ops.device("/gpu:0"):
        # Always put loop body on GPU.
        nx = nn_ops.conv2d(
            input=x,
            filter=kernel,
            strides=[1, 1, 1, 1],
            padding="SAME",
            data_format="NHWC",
            name="conv2d")
        ni = math_ops.add(i, 1)
        return ni, nx

    ops.reset_default_graph()
    with session.Session() as sess, ops.device(default_device):
      # Get the initial step counter i, the input image x, and the kernel.
      i, x, kernel = self._getInitVariables()
      self.evaluate(variables.global_variables_initializer())

      if static_unroll:
        for _ in range(steps):
          i, x = loop_body(i, x)
      else:
        i, x = control_flow_ops.while_loop(
            lambda i, _: i < steps,
            loop_body, [i, x],
            parallel_iterations=steps,
            swap_memory=True)

      r = math_ops.reduce_sum(x)
      dx, dk = gradients_impl.gradients(r, [x, kernel])
      # Use group to avoid fetching back results.
      r = control_flow_ops.group(dx, dk)

      for _ in range(3):
        # Warm-up runs, excluded from the timed measurement.
        self.evaluate(r)

      start_time = time.time()
      for _ in range(num_iters):
        self.evaluate(r)
      return (time.time() - start_time) / num_iters

  def benchmarkWhileOpCrossDevicePlacement(self):
    iters = 10
    # Run loop body on GPU, but other ops on CPU.
    duration = self._runOneBenchmark("cpu", iters, static_unroll=False)
    self.report_benchmark(
        name="while_op_cross_device", iters=iters, wall_time=duration)

  def benchmarkWhileOpSameDevicePlacement(self):
    iters = 10
    # Run all ops on the same GPU device.
    duration = self._runOneBenchmark("gpu", iters, static_unroll=False)
    self.report_benchmark(
        name="while_op_same_device", iters=iters, wall_time=duration)

  def benchmarkWhileOpUnrollCrossDevicePlacement(self):
    iters = 10
    # Run loop body on GPU, but other ops on CPU.
    duration = self._runOneBenchmark("cpu", iters, static_unroll=True)
    self.report_benchmark(
        name="unroll_cross_device_cpu", iters=iters, wall_time=duration)

  def benchmarkWhileOpUnrollSameDevicePlacement(self):
    iters = 10
    # Run all ops on GPU.
    duration = self._runOneBenchmark("gpu", iters, static_unroll=True)
    self.report_benchmark(
        name="unroll_same_device", iters=iters, wall_time=duration)


@test_util.with_control_flow_v2
class EagerTest(test.TestCase):

  def testCond(self):
    with context.eager_mode():
      pred = math_ops.less(1, 2)
      fn1 = lambda: [constant_op.constant(10)]
      fn2 = lambda: [constant_op.constant(20)]
      r = control_flow_ops.cond(pred, fn1, fn2)

      self.assertAllEqual(r.numpy(), 10)
      self.assertFalse(isinstance(r, list))

  # TODO(b/117279927): Re-enable once msan failure is fixed.
  def DISABLED_testCondInDefun(self):
    with context.eager_mode():

      @eager_function.defun
      def foo(pred):
        # TODO(b/111124878): this only needs to output one element.
        fn1 = lambda: (constant_op.constant(10), constant_op.constant(100))
        fn2 = lambda: (constant_op.constant(20), constant_op.constant(200))
        return control_flow_ops.cond(constant_op.constant(pred), fn1, fn2)

      r = foo(True)
      self.assertAllEqual(r[0].numpy(), 10)
      self.assertNotIsInstance(r, list)

      r = foo(False)
      self.assertAllEqual(r[0].numpy(), 20)
      self.assertFalse(isinstance(r, list))

  def testWhileLoop(self):
    with context.eager_mode():
      tensor = constant_op.constant([1, 2, 3, 4, 5])
      self.assertAllEqual(isum(tensor).numpy(), [46, 47, 48, 49, 50])

  def testWhileLoopWithMaxIterations(self):
    with context.eager_mode():
      tensor = constant_op.constant([1, 2, 3, 4, 5])
      self.assertAllEqual(
          isum(tensor, maximum_iterations=3).numpy(),
          [1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3])

  @test_util.run_v1_only("b/120545219")
  def testWhileWithMaximumIterationsAndSingleArgument(self):
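    # maximum_iterations caps the loop at a single step even though the cond
    # alone would allow three iterations.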
    with context.eager_mode():
      tensor = constant_op.constant(0)
      r = control_flow_ops.while_loop(
          lambda i: i < 3, lambda i: i + 1, [tensor], maximum_iterations=1)
      self.assertEqual(1, r.numpy())

  def testWithDependencies(self):
    with context.eager_mode():
      t1 = constant_op.constant(1)
      t2 = constant_op.constant(2)
      t3 = control_flow_ops.with_dependencies(t1, t2)
      self.assertAllEqual(t2.numpy(), t3.numpy())

  def testTuple(self):
    with context.eager_mode():
      t1 = constant_op.constant(1)
      t2 = constant_op.constant(2)
      tup1, tup2 = control_flow_ops.tuple([t1, t2])
      self.assertAllEqual(t1.numpy(), tup1.numpy())
      self.assertAllEqual(t2.numpy(), tup2.numpy())

  @test_util.run_v1_only("b/120545219")
  def testCase(self):
    with context.eager_mode():
      x = constant_op.constant(1)
      y = constant_op.constant(2)
      z = constant_op.constant(3)
      f1 = lambda: constant_op.constant(17)
      f2 = lambda: constant_op.constant(23)
      f3 = lambda: constant_op.constant(-1)

      r1 = control_flow_ops.case(
          [(x < y, f1), (x > z, f2)], default=f3, exclusive=True)
      self.assertAllEqual(r1.numpy(), 17)


if __name__ == "__main__":
  test.main()