1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OiR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16# pylint: disable=g-long-lambda 17"""Tests for tensorflow.ops.control_flow_ops.""" 18 19import collections 20import math 21import re 22import sys 23import time 24 25from absl.testing import parameterized 26import numpy as np 27 28from tensorflow.core.protobuf import config_pb2 29from tensorflow.core.protobuf import rewriter_config_pb2 30from tensorflow.python import tf2 31from tensorflow.python.client import device_lib 32from tensorflow.python.client import session 33from tensorflow.python.data.experimental.ops import cardinality 34from tensorflow.python.data.ops import dataset_ops 35from tensorflow.python.eager import context 36from tensorflow.python.eager import def_function 37from tensorflow.python.eager import function as eager_function 38from tensorflow.python.eager import wrap_function 39from tensorflow.python.framework import constant_op 40from tensorflow.python.framework import dtypes 41from tensorflow.python.framework import errors_impl 42from tensorflow.python.framework import function 43from tensorflow.python.framework import indexed_slices 44from tensorflow.python.framework import ops 45from tensorflow.python.framework import sparse_tensor 46from tensorflow.python.framework import tensor_shape 47from tensorflow.python.framework import tensor_spec 48from tensorflow.python.framework import test_util 49from tensorflow.python.ops import array_ops 50from tensorflow.python.ops import control_flow_ops 51from tensorflow.python.ops import control_flow_util 52from tensorflow.python.ops import data_flow_ops 53from tensorflow.python.ops import functional_ops 54from tensorflow.python.ops import gen_array_ops 55from tensorflow.python.ops import gen_control_flow_ops 56from tensorflow.python.ops import gen_data_flow_ops 57from tensorflow.python.ops import gen_logging_ops 58from tensorflow.python.ops import gen_state_ops 59from tensorflow.python.ops import gradient_checker_v2 60from tensorflow.python.ops import gradients_impl 61from tensorflow.python.ops import init_ops 62from tensorflow.python.ops import linalg_ops 63from tensorflow.python.ops import logging_ops 64from tensorflow.python.ops import map_fn 65from tensorflow.python.ops import math_ops 66from tensorflow.python.ops import nn_grad # pylint: disable=unused-import 67from tensorflow.python.ops import nn_ops 68from tensorflow.python.ops import random_ops 69from tensorflow.python.ops import resource_variable_ops 70from tensorflow.python.ops import script_ops 71from tensorflow.python.ops import sparse_ops 72from tensorflow.python.ops import state_ops 73from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import 74from tensorflow.python.ops import tensor_array_ops 75from tensorflow.python.ops import variable_scope 76from tensorflow.python.ops import variables 77from tensorflow.python.ops import while_v2 # pylint: disable=unused-import 78# pylint: disable=unused-import 79from tensorflow.python.ops.ragged import ragged_factory_ops 80from tensorflow.python.ops.ragged import ragged_tensor 81import tensorflow.python.ops.tensor_array_grad 82# pylint: enable=unused-import 83from tensorflow.python.platform import test 84from tensorflow.python.training import adam 85from tensorflow.python.training import gradient_descent 86from tensorflow.python.util import nest 87 88 89def check_consumers(graph): 90 """Sanity check on the consumer list of the tensors.""" 91 92 consumer_count = {} 93 for op in graph.get_operations(): 94 for v in op.inputs: 95 cnt = consumer_count.get(v, 0) 96 consumer_count[v] = cnt + 1 97 for k, v in consumer_count.items(): 98 if len(k.consumers()) != v: 99 return False 100 return True 101 102 103def all_fetchables(): 104 tensor_names = [] 105 graph = ops.get_default_graph() 106 for op in graph.get_operations(): 107 for t in op.outputs: 108 if graph.is_fetchable(t): 109 tensor_names.append(t.name) 110 return tensor_names 111 112 113def all_feedables(): 114 feedable_tensors = [] 115 graph = ops.get_default_graph() 116 for op in graph.get_operations(): 117 for t in op.inputs: 118 if graph.is_feedable(t): 119 feedable_tensors.append(t) 120 return feedable_tensors 121 122 123def opt_cfg(do_constant_folding=True): 124 return config_pb2.ConfigProto( 125 allow_soft_placement=True, 126 graph_options=config_pb2.GraphOptions( 127 optimizer_options=config_pb2.OptimizerOptions( 128 opt_level=config_pb2.OptimizerOptions.L1, 129 do_function_inlining=True, 130 do_constant_folding=do_constant_folding))) 131 132 133def isum(s, maximum_iterations=None): 134 i = constant_op.constant(0, name="i") 135 c = lambda i, s: math_ops.less(i, 10) 136 b = lambda i, s: [math_ops.add(i, 1), math_ops.add(i, s)] 137 _, r_s = control_flow_ops.while_loop( 138 c, b, [i, s], maximum_iterations=maximum_iterations) 139 return r_s 140 141 142def enqueue_print_op(s): 143 """Enqueues an op that prints a message to be captured in the test.""" 144 return logging_ops.print_v2("ControlFlowOpsTest: " + s) 145 146 147def filter_test_messages(s): 148 """Returns a list of messages printed by enqueue_print_op.""" 149 prefix = "ControlFlowOpsTest: " 150 return [l[len(prefix):] for l in s.split("\n") if l.startswith(prefix)] 151 152 153def tf_function_in_tf2(f): 154 if tf2.enabled(): 155 # In TF1 do not wrap with tf.function so that we can test the v1 control 156 # flow code path. 157 return def_function.function(f) 158 return f 159 160 161@test_util.with_eager_op_as_function 162@test_util.with_control_flow_v2 163class ControlFlowTest(test.TestCase, parameterized.TestCase): 164 165 @test_util.run_v1_only("b/120545219") 166 def testRefIdentity(self): 167 with self.cached_session(): 168 v = variables.VariableV1(7) 169 170 v = control_flow_ops._Identity(v) 171 op = state_ops.assign(v, 9) 172 v2 = control_flow_ops.with_dependencies([op], v) 173 174 self.assertTrue(isinstance(v2, ops.Tensor)) 175 self.evaluate(variables.global_variables_initializer()) 176 self.assertEqual(9, self.evaluate(v2)) 177 178 @test_util.run_v1_only("b/120545219") 179 def testRefEnter(self): 180 with self.cached_session(): 181 v = variables.VariableV1(7) 182 183 enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True) 184 nine = constant_op.constant(9) 185 enter_nine = gen_control_flow_ops.enter(nine, "foo_1") 186 op = state_ops.assign(enter_v, enter_nine) 187 v2 = control_flow_ops.with_dependencies([op], enter_v) 188 v3 = control_flow_ops.exit(v2) 189 self.evaluate(variables.global_variables_initializer()) 190 self.assertEqual(9, self.evaluate(v3)) 191 192 @test_util.run_v1_only("b/120545219") 193 def testRefSwitch(self): 194 with self.cached_session(): 195 v = variables.VariableV1(7) 196 197 p = constant_op.constant(True) 198 v1 = control_flow_ops._SwitchRefOrTensor(v._ref(), p) # pylint: disable=protected-access 199 v2 = state_ops.assign(v1[1], 9) 200 self.evaluate(variables.global_variables_initializer()) 201 self.assertEqual(9, self.evaluate(v2)) 202 203 def testEnterMulExit(self): 204 with self.cached_session(): 205 data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 206 enter_data = gen_control_flow_ops.enter(data, "foo_1", False) 207 five = constant_op.constant(5) 208 enter_five = gen_control_flow_ops.enter(five, "foo_1", False) 209 mul_op = math_ops.multiply(enter_data, enter_five) 210 exit_op = control_flow_ops.exit(mul_op) 211 212 result = self.evaluate(exit_op) 213 self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result) 214 215 @test_util.run_deprecated_v1 216 def testEnterShapePropagation(self): 217 with self.cached_session(): 218 v = variables.Variable([0.0, 0.0], dtype=dtypes.float32) 219 220 # If is_constant=True, the shape information should be propagated. 221 enter_v_constant = gen_control_flow_ops.enter( 222 v, "frame1", is_constant=True) 223 self.assertEqual(enter_v_constant.shape, [2]) 224 225 # Otherwise, the shape should be unknown. 226 enter_v_non_constant = gen_control_flow_ops.enter( 227 v, "frame2", is_constant=False) 228 self.assertEqual(enter_v_non_constant.shape, None) 229 230 @test_util.run_v1_only("b/120545219") 231 def testSwitchMergeIndexedSlices(self): 232 with self.cached_session(): 233 values = constant_op.constant([1, 2, 3, 4, 5, 6]) 234 indices = constant_op.constant([0, 2, 4, 6, 8, 10]) 235 data = indexed_slices.IndexedSlices(values, indices) 236 pred = ops.convert_to_tensor(True) 237 switch_op = control_flow_ops.switch(data, pred) 238 merge_op = control_flow_ops.merge(switch_op)[0] 239 240 val = merge_op.values 241 ind = merge_op.indices 242 self.assertAllEqual(np.arange(1, 7), val) 243 self.assertAllEqual(np.arange(0, 12, 2), ind) 244 245 @test_util.run_v1_only("b/120545219") 246 def testSwitchDeadBranch(self): 247 with self.cached_session(): 248 data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 249 ports = ops.convert_to_tensor(True, name="ports") 250 switch_op = control_flow_ops.switch(data, ports) 251 dead_branch = array_ops.identity(switch_op[0]) 252 253 with self.assertRaisesWithPredicateMatch( 254 errors_impl.InvalidArgumentError, 255 lambda e: "Retval[0] does not have value" in str(e)): 256 self.evaluate(dead_branch) 257 258 @test_util.run_v1_only("b/120545219") 259 def testSwitchMergeLess(self): 260 with self.cached_session(): 261 data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 262 zero = ops.convert_to_tensor(0) 263 one = ops.convert_to_tensor(1) 264 less_op = math_ops.less(zero, one) 265 switch_op = control_flow_ops.switch(data, less_op) 266 merge_op = control_flow_ops.merge(switch_op)[0] 267 268 result = self.evaluate(merge_op) 269 self.assertAllEqual(np.arange(1, 7), result) 270 271 @test_util.run_v1_only("b/120545219") 272 def testSwitchMergeAddIdentity(self): 273 with self.cached_session(): 274 data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 275 ports = ops.convert_to_tensor(False, name="ports") 276 switch_op = control_flow_ops.switch(data, ports) 277 one = constant_op.constant(1) 278 add_op = math_ops.add(switch_op[0], one) 279 id_op = array_ops.identity(switch_op[1]) 280 merge_op = control_flow_ops.merge([add_op, id_op])[0] 281 282 result = self.evaluate(merge_op) 283 self.assertAllEqual(np.array([x + 1 for x in [1, 2, 3, 4, 5, 6]]), result) 284 285 @test_util.run_v1_only("b/120545219") 286 def testSwitchMergeAddMul(self): 287 with self.cached_session(): 288 data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 289 ports = ops.convert_to_tensor(True, name="ports") 290 switch_op = control_flow_ops.switch(data, ports) 291 one = constant_op.constant(1) 292 add_op = math_ops.add(switch_op[0], one) 293 five = constant_op.constant(5) 294 mul_op = math_ops.multiply(switch_op[1], five) 295 merge_op = control_flow_ops.merge([add_op, mul_op])[0] 296 297 result = self.evaluate(merge_op) 298 self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result) 299 300 @test_util.run_v1_only("b/120545219") 301 def testLoop_false(self): 302 with self.cached_session(): 303 false = ops.convert_to_tensor(False) 304 n = constant_op.constant(10) 305 306 enter_false = gen_control_flow_ops.enter(false, "foo_1", False) 307 enter_n = gen_control_flow_ops.enter(n, "foo_1", False) 308 309 merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0] 310 switch_n = control_flow_ops.switch(merge_n, enter_false) 311 exit_n = control_flow_ops.exit(switch_n[0]) 312 next_n = control_flow_ops.next_iteration(switch_n[0]) 313 merge_n.op._update_input(1, next_n) 314 315 result = self.evaluate(exit_n) 316 self.assertAllEqual(10, result) 317 318 @test_util.run_deprecated_v1 319 def testLoop_1(self): 320 with self.cached_session(): 321 zero = constant_op.constant(0) 322 one = constant_op.constant(1) 323 n = constant_op.constant(10) 324 325 enter_i = gen_control_flow_ops.enter(zero, "foo", False) 326 enter_one = gen_control_flow_ops.enter(one, "foo", True) 327 enter_n = gen_control_flow_ops.enter(n, "foo", True) 328 329 with ops.device(test.gpu_device_name()): 330 merge_i = control_flow_ops.merge([enter_i, enter_i])[0] 331 332 less_op = math_ops.less(merge_i, enter_n) 333 cond_op = control_flow_ops.loop_cond(less_op) 334 switch_i = control_flow_ops.switch(merge_i, cond_op) 335 336 add_i = math_ops.add(switch_i[1], enter_one) 337 338 next_i = control_flow_ops.next_iteration(add_i) 339 merge_i.op._update_input(1, next_i) 340 341 exit_i = control_flow_ops.exit(switch_i[0]) 342 result = self.evaluate(exit_i) 343 self.assertAllEqual(10, result) 344 345 @test_util.run_v1_only("b/120545219") 346 def testLoop_2(self): 347 with self.cached_session(): 348 zero = constant_op.constant(0) 349 one = constant_op.constant(1) 350 n = constant_op.constant(10) 351 352 enter_i = gen_control_flow_ops.enter(zero, "foo", False) 353 enter_one = gen_control_flow_ops.enter(one, "foo", True) 354 enter_n = gen_control_flow_ops.enter(n, "foo", True) 355 356 merge_i = control_flow_ops.merge([enter_i, enter_i])[0] 357 358 less_op = math_ops.less(merge_i, enter_n) 359 cond_op = control_flow_ops.loop_cond(less_op) 360 switch_i = control_flow_ops.switch(merge_i, cond_op) 361 362 add_i = math_ops.add(switch_i[1], enter_one) 363 364 with ops.device(test.gpu_device_name()): 365 next_i = control_flow_ops.next_iteration(add_i) 366 merge_i.op._update_input(1, next_i) 367 368 exit_i = control_flow_ops.exit(switch_i[0]) 369 result = self.evaluate(exit_i) 370 self.assertAllEqual(10, result) 371 372 @test_util.run_v1_only("b/120545219") 373 def testDifferentFrame(self): 374 with self.cached_session(): 375 data = array_ops.placeholder(dtypes.float32, shape=[]) 376 enter_1 = gen_control_flow_ops.enter(data, "foo_1", False) 377 enter_2 = gen_control_flow_ops.enter(data, "foo_2", False) 378 res = math_ops.add(enter_1, enter_2) 379 with self.assertRaisesOpError("has inputs from different frames"): 380 res.eval(feed_dict={data: 1.0}) 381 382 @test_util.run_deprecated_v1 383 def testCondBool(self): 384 values = constant_op.constant(10) 385 fn1 = lambda: math_ops.add(values, 1) 386 fn2 = lambda: math_ops.subtract(values, 1) 387 with self.assertRaisesRegex(TypeError, "must not be a Python bool"): 388 _ = control_flow_ops.cond(False, fn1, fn2) 389 390 @test_util.run_deprecated_v1 391 def testCondInt(self): 392 p = array_ops.placeholder(dtypes.bool, shape=[]) 393 v = constant_op.constant(10) 394 fn1 = lambda: math_ops.add(v, 1) 395 fn2 = lambda: math_ops.subtract(v, 1) 396 y = control_flow_ops.cond(p, fn1, fn2) 397 grad = gradients_impl.gradients(y, [v]) 398 self.assertAllEqual([None], grad) 399 400 def testCondOutputShape(self): 401 x = constant_op.constant(1.0) 402 b = control_flow_ops.cond( 403 constant_op.constant(True), lambda: math_ops.square(x), 404 lambda: math_ops.subtract(x, 1.)) 405 self.assertEqual(b.shape, tensor_shape.TensorShape([])) 406 407 @test_util.run_v1_only("b/120545219") 408 def testFetchable(self): 409 with self.cached_session() as sess: 410 x = array_ops.placeholder(dtypes.float32) 411 control_flow_ops.cond( 412 constant_op.constant(True), lambda: x + 2, lambda: x + 0) 413 graph = ops.get_default_graph() 414 for op in graph.get_operations(): 415 for t in op.inputs: 416 if graph.is_fetchable(t.op): 417 sess.run(t, feed_dict={x: 3}) 418 else: 419 with self.assertRaisesRegex(ValueError, 420 "has been marked as not fetchable"): 421 sess.run(t, feed_dict={x: 3}) 422 423 @test_util.disable_control_flow_v2("Not relevant") 424 @test_util.run_v1_only("b/120545219") 425 def testFeedable(self): 426 with self.cached_session() as sess: 427 c = constant_op.constant(2) 428 i0 = constant_op.constant(0) 429 r = control_flow_ops.while_loop(lambda i: i < 1000, 430 lambda i: math_ops.square(c) + i, [i0]) 431 self.assertEqual(1000, r.eval(feed_dict={i0: 0})) 432 feedable_tensors = all_feedables() 433 for t in feedable_tensors: 434 sess.run(r, feed_dict={t: 3}) 435 graph = ops.get_default_graph() 436 for op in graph.get_operations(): 437 for t in op.inputs: 438 if t not in feedable_tensors and t.dtype is dtypes.int32: 439 with self.assertRaisesRegex(ValueError, "may not be fed"): 440 sess.run(r, feed_dict={t: 3}) 441 442 @test_util.run_v1_only("b/120545219") 443 def testCondIndexedSlices(self): 444 with self.cached_session(): 445 values = constant_op.constant([10]) 446 indices = constant_op.constant([0]) 447 x = indexed_slices.IndexedSlices(values, indices) 448 pred = math_ops.less(1, 2) 449 fn1 = lambda: indexed_slices.IndexedSlices( 450 math_ops.add(x.values, 1), indices) 451 fn2 = lambda: indexed_slices.IndexedSlices( 452 math_ops.subtract(x.values, 1), indices) 453 r = control_flow_ops.cond(pred, fn1, fn2) 454 455 val = r.values 456 ind = r.indices 457 self.assertAllEqual([11], val) 458 self.assertAllEqual([0], ind) 459 460 def testCondMismatchedIndexedSlices(self): 461 @def_function.function 462 def foo(): 463 values = constant_op.constant([10]) 464 indices = constant_op.constant([0]) 465 x = indexed_slices.IndexedSlices(values, indices) 466 with self.assertRaisesRegex(TypeError, 467 "Cannot reconcile tf.cond 0-th outputs"): 468 control_flow_ops.cond( 469 constant_op.constant(True), lambda: indexed_slices.IndexedSlices( 470 math_ops.add(x.values, 1), indices), 471 lambda: math_ops.add(x.values, 1), indices) 472 foo() 473 474 def testCondSparseTensor(self): 475 values = constant_op.constant([2.0, 4.0], name="values") 476 indices = constant_op.constant([[0], [3]], 477 dtype=dtypes.int64, 478 name="indices") 479 shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape") 480 x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape) 481 pred = math_ops.less(1, 2) 482 fn1 = lambda: sparse_tensor.SparseTensor( 483 indices + 1, x.values + 1, dense_shape=shape) 484 fn2 = lambda: sparse_tensor.SparseTensor( 485 indices, x.values - 1, dense_shape=shape) 486 r = control_flow_ops.cond(pred, fn1, fn2) 487 self.assertAllEqual([3.0, 5.0], r.values) 488 self.assertAllEqual([[1], [4]], r.indices) 489 self.assertAllEqual(r.values.get_shape(), (2,)) 490 491 def testCondRaggedTensor(self): 492 rt = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]]) 493 pred = math_ops.less(1, 2) 494 fn1 = lambda: array_ops.concat([rt + 2, [[100]]], axis=0) 495 fn2 = lambda: rt[:2] - 2 496 result = control_flow_ops.cond(pred, fn1, fn2) 497 self.assertAllEqual([3, 4, 5, 6, 7, 8, 100], result.values) 498 self.assertAllEqual([0, 2, 3, 6, 7], result.row_splits) 499 500 @test_util.run_v1_only("b/120545219") 501 def testCondResource(self): 502 503 with self.cached_session(): 504 rv = resource_variable_ops.ResourceVariable(True) 505 self.evaluate(variables.global_variables_initializer()) 506 t = ops.convert_to_tensor(1.0) 507 508 def case(): 509 assign = resource_variable_ops.assign_variable_op(rv.handle, False) 510 with ops.control_dependencies([assign]): 511 return array_ops.identity(t) 512 513 self.assertEqual( 514 1.0, self.evaluate(control_flow_ops.cond(rv, case, lambda: t))) 515 516 @test_util.run_deprecated_v1 517 def testCondResourceGradShape(self): 518 rv1 = resource_variable_ops.ResourceVariable([1.0, 2.0]) 519 rv2 = resource_variable_ops.ResourceVariable([3.0, 4.0]) 520 pred = constant_op.constant(True) 521 result = control_flow_ops.cond(pred, lambda: rv1, lambda: rv2) 522 grads = gradients_impl.gradients(result, [rv1, rv2]) 523 self.assertAllEqual(grads[0].shape.as_list(), [2]) 524 self.assertAllEqual(grads[1].shape.as_list(), [2]) 525 526 @test_util.run_v1_only("b/120545219") 527 def testCondWithTensorArrayGrad(self): 528 with self.cached_session() as sess: 529 with ops.device(test.gpu_device_name()): 530 pred = array_ops.placeholder(dtypes.bool, []) 531 x = constant_op.constant([1.0, 2.0, 3.0]) 532 y = control_flow_ops.cond( 533 pred, lambda: map_fn.map_fn(lambda z: z * 2.0, x), 534 lambda: constant_op.constant([1.0, 1.0, 1.0])) 535 g = gradients_impl.gradients(y, x)[0] 536 537 self.assertAllEqual(sess.run(g, {pred: True}), [2.0, 2.0, 2.0]) 538 self.assertAllEqual(sess.run(g, {pred: False}), [0.0, 0.0, 0.0]) 539 540 @test_util.run_v1_only("b/120545219") 541 def testCondIndexedSlicesDifferentTypes(self): 542 with self.cached_session(): 543 values = constant_op.constant([10]) 544 i_32 = ops.convert_to_tensor([0], name="one", dtype=dtypes.int32) 545 i_64 = ops.convert_to_tensor([0], name="one", dtype=dtypes.int64) 546 x = indexed_slices.IndexedSlices(values, i_32) 547 pred = math_ops.less(1, 2) 548 fn1 = lambda: indexed_slices.IndexedSlices( 549 math_ops.add(x.values, 1), i_32) 550 fn2 = lambda: indexed_slices.IndexedSlices( 551 math_ops.subtract(x.values, 1), i_64) 552 r = control_flow_ops.cond(pred, fn1, fn2) 553 554 val = r.values 555 ind = r.indices 556 self.assertAllEqual([11], val) 557 self.assertAllEqual([0], ind) 558 self.assertTrue(ind.dtype == np.int64) 559 560 @test_util.run_v1_only("b/120545219") 561 def testCondColocation(self): 562 with self.session(): 563 with ops.device("/cpu:0"): 564 v = variables.Variable(7.0) 565 566 x = constant_op.constant(10.0) 567 pred = math_ops.less(1.0, 2.0) 568 fn1 = lambda: math_ops.add(v, 1.0) 569 fn2 = lambda: math_ops.subtract(x, 1.0) 570 r = control_flow_ops.cond(pred, fn1, fn2) 571 572 for op in x.graph.get_operations(): 573 if op.name == "cond/Add/Switch": 574 self.assertDeviceEqual(op.device, "/cpu:0") 575 576 def _testCond_1(self, use_gpu): 577 with self.cached_session(use_gpu=use_gpu): 578 x = constant_op.constant(10) 579 pred = math_ops.less(1, 2) 580 fn1 = lambda: math_ops.add(x, 1) 581 fn2 = lambda: math_ops.subtract(x, 1) 582 r = control_flow_ops.cond(pred, fn1, fn2) 583 584 result = self.evaluate(r) 585 self.assertAllEqual(11, result) 586 587 def testCond_1(self): 588 589 self._testCond_1(use_gpu=False) 590 # TODO(b/116526896): Enable GPU tests. 591 # self._testCond_1(use_gpu=True) 592 593 def testCond_2(self): 594 595 with self.cached_session(): 596 x = constant_op.constant(10) 597 r = control_flow_ops.cond( 598 math_ops.less(1, 0), lambda: math_ops.add(x, 1), 599 lambda: math_ops.subtract(x, 1)) 600 result = self.evaluate(r) 601 self.assertAllEqual(9, result) 602 603 def testCond_3(self): 604 605 with self.cached_session(): 606 x = constant_op.constant(10) 607 pred = math_ops.less(1, 2) 608 fn1 = lambda: math_ops.add(x, 1) 609 fn2 = lambda: math_ops.subtract(x, 1) 610 fn3 = lambda: math_ops.add(control_flow_ops.cond(pred, fn1, fn2), 1) 611 r = control_flow_ops.cond(pred, fn3, fn2) 612 613 result = self.evaluate(r) 614 self.assertAllEqual(12, result) 615 616 @test_util.run_in_graph_and_eager_modes 617 def testCondPruning(self): 618 v1 = variables.Variable(7) 619 v2 = variables.Variable(7) 620 v3 = variables.Variable(7) 621 622 def f(): 623 age = constant_op.constant(3) 624 max_age = constant_op.constant(2) 625 pred = math_ops.greater(age, max_age) 626 fn1 = lambda: [state_ops.assign(v1, 1).op, state_ops.assign(v2, 2).op] 627 fn2 = lambda: [state_ops.assign(v3, 3).op, constant_op.constant(10).op] 628 r = control_flow_ops.cond(pred, fn1, fn2) 629 self.assertEqual(len(r), 2) 630 return r[1] 631 632 f_defun = eager_function.defun(f) 633 634 if not context.executing_eagerly(): 635 with self.cached_session(): 636 self.evaluate(variables.global_variables_initializer()) 637 result = self.evaluate(f()) 638 self.assertEqual(True, result) 639 # Only second cond result was fetched, so v1 assign shouldn't run. 640 self.assertEqual(7, self.evaluate(v1)) 641 self.assertEqual(2, self.evaluate(v2)) 642 self.assertEqual(7, self.evaluate(v3)) 643 644 result = f_defun() 645 self.assertEqual(True, self.evaluate(result)) 646 # Both v1 and v2 branch assignments should be run in defun. 647 self.assertEqual(1, self.evaluate(v1)) 648 self.assertEqual(2, self.evaluate(v2)) 649 self.assertEqual(7, self.evaluate(v3)) 650 651 def testCond_5(self): 652 with self.cached_session(): 653 alive = constant_op.constant(True, name="alive") 654 count = constant_op.constant(0, name="count") 655 656 def body(i): 657 return control_flow_ops.cond( 658 alive, lambda: [math_ops.less(i, 3), math_ops.add(count, 1)], 659 lambda: [alive, count]) 660 661 for i in range(10): 662 alive, count = body(i) 663 self.assertAllEqual(4, self.evaluate(count)) 664 665 @test_util.run_v1_only("b/120545219") 666 def testCond_6(self): 667 with self.cached_session(): 668 v1 = variables.Variable([7]) 669 670 age = constant_op.constant(3) 671 pred = math_ops.greater(age, 4) 672 fn1 = lambda: age 673 fn2 = lambda: v1 674 r = control_flow_ops.cond(pred, fn1, fn2) 675 676 self.evaluate(variables.global_variables_initializer()) 677 result = self.evaluate(r) 678 self.assertAllEqual(np.array([7]), result) 679 680 def testCond_7(self): 681 with self.cached_session() as sess: 682 x = constant_op.constant(10) 683 y = constant_op.constant(200) 684 pred = math_ops.less(1, 2) 685 fn1 = lambda: [math_ops.add(x, 1), math_ops.add(x, 2)] 686 fn2 = lambda: [y, y] 687 r = control_flow_ops.cond(pred, fn1, fn2) 688 self.assertAllEqual([11, 12], self.evaluate(r)) 689 690 @parameterized.parameters(dtypes.float32, dtypes.float64) 691 @test_util.run_v1_only("Uses tf.gradients") 692 def testCondResourceGrad(self, dtype): 693 init = constant_op.constant([7.], dtype=dtype) 694 v1 = variables.Variable(init) 695 696 age = constant_op.constant(3., dtype=dtype) 697 pred = math_ops.greater(age, 4.) 698 fn1 = lambda: age 699 fn2 = lambda: v1 700 r = control_flow_ops.cond(pred, fn1, fn2) 701 702 grad = gradients_impl.gradients(r, v1)[0] 703 self.evaluate(variables.global_variables_initializer()) 704 self.assertAllEqual(grad, [1.]) 705 706 @test_util.run_gpu_only 707 @test_util.run_deprecated_v1 708 def testCond_Device(self): 709 x = constant_op.constant(-10.) 710 711 # True branch function defined outside of device scope 712 def true_fn(): 713 return math_ops.exp(x) 714 715 with ops.device("CPU:0"): 716 r = control_flow_ops.cond( 717 constant_op.constant(True), true_fn, lambda: 0.) 718 self.assertIn("cpu", r.device.lower()) 719 720 with session.Session() as sess: 721 options = config_pb2.RunOptions(output_partition_graphs=True) 722 run_metadata = config_pb2.RunMetadata() 723 sess.run(r, options=options, run_metadata=run_metadata) 724 # We expect that everything runs on CPU, even if GPU is available. 725 self.assertEqual(len(run_metadata.partition_graphs), 1) 726 727 def _count_matching_switch_nodes_on_device(self, run_metadata, device_str, 728 dtype): 729 # Returns the number of Switch nodes with type dtype placed on 730 # `device_str`. 731 device_graphs = [ 732 g for g in run_metadata.partition_graphs 733 if device_str in g.node[0].device 734 ] 735 self.assertLen(device_graphs, 1) 736 switch_nodes = [ 737 n for n in device_graphs[0].node 738 if n.op == "Switch" and n.attr["T"].type == dtype.as_datatype_enum 739 ] 740 return len(switch_nodes) 741 742 @test_util.run_gpu_only 743 @test_util.run_deprecated_v1 744 def testCondSwitchColocatedWithInputWhenInputExplicitlyPlacedOnCPU(self): 745 x = array_ops.placeholder(dtypes.float32) 746 747 # `arg` is used in the cond then branch so a Switch node is created for it. 748 # We test that the Switch node gets placed on the same device as `arg`. 749 # We force `arg` to be on CPU here. 750 with ops.device("CPU:0"): 751 arg = x + 10. 752 753 def true_fn(): 754 with ops.device("CPU:0"): 755 return arg + 1 756 757 r = control_flow_ops.cond(constant_op.constant(True), true_fn, lambda: 0.) 758 759 # Disable Loop_optimizer grappler pass for this test because it replaces 760 # Switch with Identity when it's part of a dead branch. 761 config = config_pb2.ConfigProto() 762 config.graph_options.rewrite_options.loop_optimization = ( 763 rewriter_config_pb2.RewriterConfig.OFF) 764 765 with self.session(config=config) as sess: 766 run_metadata = config_pb2.RunMetadata() 767 options = config_pb2.RunOptions(output_partition_graphs=True) 768 sess.run( 769 r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata) 770 self.assertLen(run_metadata.partition_graphs, 2) 771 # Check that the Switch for `arg` gets placed on CPU. 772 self.assertEqual( 773 self._count_matching_switch_nodes_on_device(run_metadata, "CPU", 774 dtypes.float32), 1) 775 self.assertEqual( 776 self._count_matching_switch_nodes_on_device(run_metadata, "GPU", 777 dtypes.float32), 0) 778 779 @test_util.run_gpu_only 780 @test_util.run_deprecated_v1 781 def testCondSwitchColocatedWithInputWhenInputPlacedOnCPU(self): 782 x = array_ops.placeholder(dtypes.float32) 783 784 # `arg` is used in the cond then branch so a Switch node is created for it. 785 # We test that the Switch node gets placed on the same device as `arg`. 786 # Since arg is a dataset (and only has a CPU kernel), it gets placed on CPU 787 # by placer. 788 arg = dataset_ops.Dataset.range(8) 789 790 def true_fn(): 791 return cardinality.cardinality(arg) 792 793 r = control_flow_ops.cond( 794 constant_op.constant(True), true_fn, 795 lambda: constant_op.constant(0, dtypes.int64)) 796 797 # Disable Loop_optimizer grappler pass for this test because it replaces 798 # Switch with Identity when it's part of a dead branch. 799 config = config_pb2.ConfigProto() 800 config.graph_options.rewrite_options.loop_optimization = ( 801 rewriter_config_pb2.RewriterConfig.OFF) 802 803 with session.Session(config=config) as sess: 804 run_metadata = config_pb2.RunMetadata() 805 options = config_pb2.RunOptions(output_partition_graphs=True) 806 sess.run( 807 r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata) 808 self.assertLen(run_metadata.partition_graphs, 2) 809 # Check that the Switch for `arg` gets placed on CPU. 810 self.assertEqual( 811 self._count_matching_switch_nodes_on_device(run_metadata, "CPU", 812 dtypes.variant), 1) 813 self.assertEqual( 814 self._count_matching_switch_nodes_on_device(run_metadata, "GPU", 815 dtypes.variant), 0) 816 817 @test_util.run_gpu_only 818 @test_util.run_deprecated_v1 819 def testCondSwitchColocatedWithInputWhenInputOnGPU(self): 820 x = array_ops.placeholder(dtypes.float32) 821 822 # `arg` is used in the cond then branch so a Switch node is created for it. 823 # We test that the Switch node gets placed on the same device as `arg`. 824 # Note: `arg` gets placed on GPU by default by the placer. 825 arg = x + 10. 826 827 def true_fn(): 828 with ops.device("CPU:0"): 829 return arg + 1 830 831 r = control_flow_ops.cond(constant_op.constant(True), true_fn, lambda: 0.) 832 833 # Disable Loop_optimizer grappler pass for this test because it replaces 834 # Switch with Identity when it's part of a dead branch. 835 config = config_pb2.ConfigProto() 836 config.graph_options.rewrite_options.loop_optimization = ( 837 rewriter_config_pb2.RewriterConfig.OFF) 838 839 with session.Session(config=config) as sess: 840 run_metadata = config_pb2.RunMetadata() 841 options = config_pb2.RunOptions(output_partition_graphs=True) 842 sess.run( 843 r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata) 844 self.assertEqual(len(run_metadata.partition_graphs), 2) 845 # Check that the Switch for `arg` gets placed on GPU. 846 self.assertEqual( 847 self._count_matching_switch_nodes_on_device(run_metadata, "CPU", 848 dtypes.float32), 0) 849 self.assertEqual( 850 self._count_matching_switch_nodes_on_device(run_metadata, "GPU", 851 dtypes.float32), 1) 852 853 def testCondAccessTrueBranchTensorInFalseBranchRaises(self): 854 855 @def_function.function 856 def f(): 857 c = constant_op.constant(1.) 858 inputs = {"c": c} 859 860 def true_fn(inputs): 861 inputs["c"] = array_ops.identity(inputs["c"], name="true_branch") 862 return inputs["c"] 863 864 def false_fn(inputs): 865 return array_ops.identity(inputs["c"]) 866 867 pred = constant_op.constant(True) 868 return control_flow_ops.cond( 869 pred, lambda: true_fn(inputs), lambda: false_fn(inputs)) 870 871 # This was needed for backwards compatibility with TF2 Estimators which 872 # rely on variable names. 873 prefix = "cond/" if context.executing_eagerly() else "" 874 875 with self.assertRaisesRegex( 876 ValueError, 877 "Tensor %strue_branch:0 in true_fn is accessed from false_fn." % 878 prefix): 879 f() 880 881 def testSwitchCaseAccessBranch1TensorInBranch4Raises(self): 882 883 @def_function.function 884 def f(): 885 c = constant_op.constant(1.) 886 inputs = {"c": c} 887 888 def br1_fn(inputs): 889 inputs["c"] = array_ops.identity(inputs["c"], name="br1_identity") 890 return inputs["c"] 891 892 def br4_fn(inputs): 893 return array_ops.identity(inputs["c"]) 894 895 def other_fn(): 896 return array_ops.identity(c) 897 898 return control_flow_ops.switch_case( 899 constant_op.constant(2), 900 [other_fn, lambda: br1_fn(inputs), other_fn, other_fn, 901 lambda: br4_fn(inputs)]) 902 903 # This was needed for backwards compatibility with TF2 Estimators which 904 # rely on variable names. 905 prefix = "switch_case/indexed_case/" if context.executing_eagerly() else "" 906 with self.assertRaisesRegex( 907 ValueError, "Tensor %sbr1_identity:0 in branch 1 is " 908 "accessed from branch 4." % prefix): 909 f() 910 911 def testCondListOutput(self): 912 with self.cached_session() as sess: 913 x = constant_op.constant(10) 914 y = constant_op.constant(200) 915 pred = math_ops.less(1, 2) 916 fn1 = lambda: [math_ops.add(x, y), math_ops.add(x, y)] 917 fn2 = lambda: [y, y] 918 r = control_flow_ops.cond(pred, fn1, fn2) 919 test_result = self.evaluate(r) 920 self.assertListEqual([210, 210], test_result) 921 922 def testTupleOutput(self): 923 with self.cached_session() as sess: 924 x = constant_op.constant(10) 925 y = constant_op.constant(200) 926 pred = math_ops.less(1, 2) 927 fn1 = lambda: (math_ops.add(x, y), math_ops.add(x, y)) 928 fn2 = lambda: (y, y) 929 r = control_flow_ops.cond(pred, fn1, fn2) 930 test_result = self.evaluate(r) 931 self.assertTupleEqual((210, 210), test_result) 932 933 def testDictOutput(self): 934 with self.cached_session() as sess: 935 x = constant_op.constant(10) 936 y = constant_op.constant(200) 937 pred = math_ops.less(1, 2) 938 fn1 = lambda: {"a": math_ops.add(x, y), "b": math_ops.add(x, y)} 939 fn2 = lambda: {"a": y, "b": y} 940 r = control_flow_ops.cond(pred, fn1, fn2) 941 test_result = self.evaluate(r) 942 self.assertDictEqual({"a": 210, "b": 210}, test_result) 943 944 def testEmbeddedListOutput(self): 945 x = constant_op.constant(10) 946 y = constant_op.constant(200) 947 pred = math_ops.less(1, 2) 948 fn1 = lambda: [[math_ops.add(x, y), math_ops.add(x, y)]] 949 fn2 = lambda: [[y, y]] 950 # Pass strict=True flag as cond_v2 allows for tensors to be 951 # in nested output structures as singletons 952 r = control_flow_ops.cond(pred, fn1, fn2, strict=True) 953 test_result = self.evaluate(r) 954 self.assertListEqual([[210, 210]], test_result) 955 956 def testEmbeddedTupleOutput(self): 957 with self.cached_session() as sess: 958 x = constant_op.constant(10) 959 y = constant_op.constant(200) 960 pred = math_ops.less(1, 2) 961 fn1 = lambda: ((math_ops.add(x, y), math_ops.add(x, y))) 962 fn2 = lambda: ((y, y)) 963 r = control_flow_ops.cond(pred, fn1, fn2) 964 test_result = self.evaluate(r) 965 self.assertTupleEqual(((210, 210)), test_result) 966 967 def testEmbeddedDictOutput(self): 968 with self.cached_session() as sess: 969 x = constant_op.constant(10) 970 y = constant_op.constant(200) 971 pred = math_ops.less(1, 2) 972 fn1 = lambda: {"a": {"c": math_ops.add(x, y)}, 973 "b": {"d": math_ops.add(x, y)}} 974 fn2 = lambda: {"a": {"c": y}, 975 "b": {"d": y}} 976 r = control_flow_ops.cond(pred, fn1, fn2) 977 test_result = self.evaluate(r) 978 self.assertDictEqual({"a": {"c": 210}, "b": {"d": 210}}, test_result) 979 980 @test_util.run_v1_only("b/120545219") 981 def testCheckNestedOutputStruct(self): 982 with self.cached_session() as sess: 983 x = constant_op.constant(10) 984 y = constant_op.constant(200) 985 pred = math_ops.less(1, 2) 986 fn1 = lambda: {"a": math_ops.add(x, y), "b": math_ops.add(x, y)} 987 fn2 = lambda: {"c": y, "d": y} 988 v1_msg = "The two structures don't have the same nested structure" 989 v2_msg = ("true_fn and false_fn arguments to tf.cond must have the same " 990 "number, type, and overall structure of return values.") 991 with self.assertRaisesRegex( 992 TypeError if control_flow_util.ENABLE_CONTROL_FLOW_V2 else ValueError, 993 v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg): 994 control_flow_ops.cond(pred, fn1, fn2) 995 996 @test_util.run_v1_only("b/120545219") 997 def testCondWithControl(self): 998 with self.cached_session() as sess: 999 control_holder = array_ops.placeholder(dtypes.float32, shape=()) 1000 a = constant_op.constant(3) 1001 1002 def true_branch(): 1003 with ops.control_dependencies([control_holder]): 1004 _ = a + 1 1005 return a + 2 1006 1007 r = control_flow_ops.cond( 1008 constant_op.constant(True), true_branch, 1009 lambda: constant_op.constant(1)) 1010 result = sess.run(r, feed_dict={control_holder: 5.}) 1011 self.assertEqual(5, result) 1012 1013 @test_util.run_v1_only("b/120545219") 1014 def testUninitializedRefIdentity(self): 1015 with self.cached_session() as sess: 1016 v = gen_state_ops.variable( 1017 shape=[1], 1018 dtype=dtypes.float32, 1019 name="v", 1020 container="", 1021 shared_name="") 1022 inited = state_ops.is_variable_initialized(v) 1023 v_f, v_t = control_flow_ops.ref_switch(v, inited) 1024 # Both v_f and v_t are uninitialized references. However, an actual use 1025 # of the reference in the 'true' branch in the 'tf.identity' op will 1026 # not 'fire' when v is uninitialized, so this is a valid construction. 1027 # This test tests that ref_identity allows uninitialized ref as input 1028 # so that this construction is allowed. 1029 v_f_op = gen_array_ops.ref_identity(v_f) 1030 v_t_op = gen_array_ops.ref_identity(v_t) 1031 with ops.control_dependencies([v_f_op]): 1032 assign_v = state_ops.assign(v, [1.0]) 1033 with ops.control_dependencies([v_t_op]): 1034 orig_v = array_ops.identity(v) 1035 merged_op = control_flow_ops.merge([assign_v, orig_v]) 1036 self.assertAllEqual([1.0], self.evaluate(merged_op.output)) 1037 1038 def testCondSwitchIdentity(self): 1039 # Make sure the recv identity is not removed by optimization. 1040 with session.Session(config=opt_cfg()) as sess: 1041 pred = constant_op.constant(True) 1042 1043 def fn1(): 1044 return control_flow_ops.no_op() 1045 1046 def fn2(): 1047 return control_flow_ops.Assert(False, ["Wrong branch!!!"]) 1048 1049 r = control_flow_ops.cond(pred, fn1, fn2) 1050 self.evaluate(r) 1051 1052 def testCondRecvIdentity(self): 1053 # Make sure the switch identity is not removed by optimization. 1054 with session.Session(config=opt_cfg()) as sess: 1055 with ops.device(test.gpu_device_name()): 1056 pred = constant_op.constant(True) 1057 1058 def fn1(): 1059 return control_flow_ops.no_op() 1060 1061 def fn2(): 1062 with ops.device("/cpu:0"): 1063 return control_flow_ops.Assert(False, ["Wrong branch!!!"]) 1064 1065 r = control_flow_ops.cond(pred, fn1, fn2) 1066 self.evaluate(r) 1067 1068 @test_util.run_deprecated_v1 1069 @test_util.enable_control_flow_v2 1070 def testDisableLoweringSwitchMerge(self): 1071 if test_util.is_gpu_available(): 1072 self.skipTest( 1073 "Single threaded executor doesn't support partitioned graphs. " 1074 "Skipping GPU test.") 1075 # Make pred feedable to ensure we don't constant-fold it out. 1076 run_opts = config_pb2.RunOptions( 1077 trace_level=config_pb2.RunOptions.FULL_TRACE) 1078 run_metadata_no_lowering = config_pb2.RunMetadata() 1079 run_metadata_with_lowering = config_pb2.RunMetadata() 1080 1081 config = opt_cfg(do_constant_folding=False) 1082 1083 pred = array_ops.placeholder_with_default( 1084 constant_op.constant(True), shape=()) 1085 r = control_flow_ops.cond(pred, lambda: True, lambda: False) 1086 1087 with session.Session(config=config) as sess: 1088 r_value = sess.run( 1089 r, options=run_opts, run_metadata=run_metadata_with_lowering) 1090 self.assertEqual(r_value, True) 1091 1092 # Use the single threaded executor, which disables control flow lowering. 1093 config.experimental.executor_type = "SINGLE_THREADED_EXECUTOR" 1094 with session.Session(config=config) as sess: 1095 r_value = sess.run( 1096 r, options=run_opts, run_metadata=run_metadata_no_lowering) 1097 self.assertEqual(r_value, True) 1098 1099 self.assertTrue( # pylint: disable=g-complex-comprehension 1100 any("switch" in ns.node_name 1101 for dev_stat in run_metadata_with_lowering.step_stats.dev_stats 1102 for ns in dev_stat.node_stats)) 1103 1104 self.assertTrue( # pylint: disable=g-complex-comprehension 1105 all("switch" not in ns.node_name 1106 for dev_stat in run_metadata_no_lowering.step_stats.dev_stats 1107 for ns in dev_stat.node_stats)) 1108 1109 @test_util.run_v1_only("b/120545219") 1110 def testCondGrad_1(self): 1111 with self.cached_session(): 1112 x = constant_op.constant(10.0, name="x") 1113 pred = math_ops.less(1, 2) 1114 fn1 = lambda: array_ops.identity(x) 1115 fn2 = lambda: array_ops.identity(x) 1116 r = control_flow_ops.cond(pred, fn1, fn2) 1117 1118 grad = gradients_impl.gradients(r, [x])[0] 1119 self.assertAllEqual(1.0, self.evaluate(grad)) 1120 1121 @test_util.run_deprecated_v1 1122 @test_util.enable_control_flow_v2 1123 def testCondComputeGradAfterSessRunFails(self): 1124 with self.cached_session(): 1125 x = constant_op.constant(10.0, name="x") 1126 pred = math_ops.less(1, 2) 1127 1128 def true_fn(): 1129 a = x * x 1130 return a * a 1131 1132 def false_fn(): 1133 return x * x 1134 1135 r = control_flow_ops.cond(pred, true_fn, false_fn) 1136 1137 self.assertAllEqual(r, 10000.) 1138 grad = gradients_impl.gradients(r, [x])[0] 1139 with self.assertRaisesRegex( 1140 errors_impl.InvalidArgumentError, 1141 r"Connecting to invalid output 1 of source node cond which has 1 " 1142 r"outputs. Try using " 1143 "tf.compat.v1.experimental.output_all_intermediates\(True\)."): 1144 self.evaluate(grad) 1145 1146 @test_util.run_deprecated_v1 1147 @test_util.enable_output_all_intermediates 1148 def testCondComputeGradAfterSessRun(self): 1149 with self.cached_session(): 1150 x = constant_op.constant(10.0, name="x") 1151 pred = math_ops.less(1, 2) 1152 1153 def true_fn(): 1154 a = x * x 1155 return a * a 1156 1157 def false_fn(): 1158 return x * x 1159 1160 r = control_flow_ops.cond(pred, true_fn, false_fn) 1161 1162 self.assertAllEqual(r, 10000.) 1163 grad = gradients_impl.gradients(r, [x])[0] 1164 self.assertAllEqual(grad, 4000.) 1165 1166 @test_util.run_deprecated_v1 1167 @test_util.enable_output_all_intermediates 1168 def testNestedCondComputeGradAfterSessRun(self): 1169 with self.cached_session(): 1170 x = constant_op.constant(10.0, name="x") 1171 pred = math_ops.less(1, 2) 1172 1173 def true_fn(): 1174 1175 def inner_true_fn(): 1176 a = x * x 1177 return a * a 1178 1179 def inner_false_fn(): 1180 return x * x 1181 1182 return control_flow_ops.cond( 1183 constant_op.constant(True), inner_true_fn, inner_false_fn) 1184 1185 def false_fn(): 1186 return x * x 1187 1188 r = control_flow_ops.cond(pred, true_fn, false_fn) 1189 1190 self.assertAllEqual(r, 10000.) 1191 grad = gradients_impl.gradients(r, [x])[0] 1192 self.assertAllEqual(grad, 4000.) 1193 1194 @test_util.run_deprecated_v1 1195 def testCondGrad_2(self): 1196 with self.cached_session(): 1197 c = array_ops.placeholder(dtypes.int32, shape=[]) 1198 x = constant_op.constant(10.0) 1199 pred = math_ops.less(c, 2) 1200 fn1 = lambda: math_ops.multiply(x, 42.0) 1201 fn2 = lambda: math_ops.multiply(x, 3.0) 1202 r = control_flow_ops.cond(pred, fn1, fn2) 1203 1204 grad = gradients_impl.gradients(r, [x])[0] 1205 self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1})) 1206 self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3})) 1207 1208 @test_util.disable_control_flow_v2( 1209 "b/110550782 (gradient w.r.t external variable)") 1210 @test_util.run_deprecated_v1 1211 def testCondGrad_3(self): 1212 with self.cached_session(): 1213 c = array_ops.placeholder(dtypes.int32, shape=[]) 1214 ox = constant_op.constant(10.0) 1215 pred = math_ops.less(c, 2) 1216 1217 def fn1(x): 1218 m = x * x 1219 return gradients_impl.gradients(m, [ox])[0] 1220 1221 fn2 = lambda: math_ops.multiply(ox, 3.0) 1222 y = math_ops.multiply(7.0, ox) 1223 r = control_flow_ops.cond(pred, lambda: fn1(y), fn2) 1224 1225 self.assertAllEqual(980.0, r.eval(feed_dict={c: 1})) 1226 self.assertAllEqual(30.0, r.eval(feed_dict={c: 3})) 1227 1228 @test_util.run_deprecated_v1 1229 def testCondGradMultiDevice(self): 1230 config = config_pb2.ConfigProto(device_count={"CPU": 2}, 1231 allow_soft_placement=True) 1232 with self.cached_session(config=config) as sess: 1233 pred = array_ops.placeholder(dtypes.bool, []) 1234 x = array_ops.placeholder(dtypes.float32) 1235 y = array_ops.placeholder(dtypes.float32) 1236 1237 with ops.device("/cpu:0"): 1238 z = control_flow_ops.cond(pred, lambda: x * y * 2.0, lambda: 2.0) 1239 1240 with ops.device("/cpu:1"): 1241 grad = gradients_impl.gradients(z, x)[0] 1242 1243 with ops.device("/cpu:0"): 1244 grad_grad = gradients_impl.gradients(grad, x)[0] 1245 1246 self.assertEqual(sess.run(grad, {pred: True, x: 1.0, y: 2.0}), 4.0) 1247 self.assertEqual(sess.run(grad, {pred: False, x: 1.0, y: 2.0}), 0.0) 1248 1249 # v1 control flow gets None second derivative for some reason. 1250 if not control_flow_util.ENABLE_CONTROL_FLOW_V2: 1251 self.assertIsNone(grad_grad) 1252 return 1253 1254 self.assertEqual(sess.run(grad_grad, {pred: True, x: 1.0, y: 2.0}), 0.0) 1255 self.assertEqual(sess.run(grad_grad, {pred: False, x: 1.0, y: 2.0}), 0.0) 1256 1257 @test_util.run_v1_only("b/120545219") 1258 def testNestedCond_Simple(self): 1259 with self.cached_session(): 1260 x = constant_op.constant(0., name="X") 1261 y = control_flow_ops.cond( 1262 constant_op.constant(True), lambda: x, 1263 lambda: control_flow_ops.cond(x < 1., lambda: x, lambda: x)) 1264 result = gradients_impl.gradients(y, x)[0] 1265 self.assertEqual(1.0, self.evaluate(result)) 1266 1267 z = control_flow_ops.cond( 1268 constant_op.constant(False), lambda: x, 1269 lambda: control_flow_ops.cond(x < 1., lambda: x, lambda: x)) 1270 result = gradients_impl.gradients(z, x)[0] 1271 self.assertEqual(1.0, self.evaluate(result)) 1272 1273 @test_util.run_v1_only("b/120545219") 1274 def testCondGrad_Gather(self): 1275 with self.cached_session() as sess: 1276 v1 = variables.Variable([1.0, 42.0]) 1277 c = array_ops.placeholder(dtypes.int32, shape=[]) 1278 pred = math_ops.less(c, 2) 1279 fn1 = lambda: array_ops.identity(v1) 1280 fn2 = lambda: array_ops.gather(v1, [1, 1]) 1281 r = control_flow_ops.cond(pred, fn1, fn2) 1282 # The following `grad` is a Tensor since it is the aggregation of an 1283 # IndexedSlice and a Tensor. It is an `IndexedSlices` with control flow 1284 # v2. 1285 grad = gradients_impl.gradients(r, [v1])[0] 1286 self.evaluate(variables.global_variables_initializer()) 1287 1288 if control_flow_util.ENABLE_CONTROL_FLOW_V2: 1289 self.assertIsInstance(grad, indexed_slices.IndexedSlices) 1290 1291 grad_value = sess.run(grad, feed_dict={c: 1}) 1292 self.assertAllEqual(gradient_checker_v2._to_numpy(grad_value), [1.0, 1.0]) 1293 1294 grad_value = sess.run(grad, feed_dict={c: 3}) 1295 self.assertAllEqual(gradient_checker_v2._to_numpy(grad_value), [0.0, 2.0]) 1296 1297 @test_util.run_deprecated_v1 1298 def testCondGrad_ResourceVarSparseRead(self): 1299 # NOTE(skyewm): this test is interesting because the 1300 # ResourceVariable.sparse_read gradient function returns IndexedSlices. 1301 var = resource_variable_ops.ResourceVariable( 1302 np.ones((4, 2), dtype=np.float32)) 1303 x = constant_op.constant(1.0) 1304 r = control_flow_ops.cond( 1305 constant_op.constant(True), 1306 lambda: x * math_ops.reduce_sum(var.sparse_read([1, 2])), 1307 lambda: constant_op.constant(np.zeros((2, 3)), 1308 dtype=dtypes.float32)) 1309 grad = gradients_impl.gradients(r, var)[0] 1310 1311 self.evaluate(variables.global_variables_initializer()) 1312 grad_val = self.evaluate(grad) 1313 self.assertIsInstance(grad_val, indexed_slices.IndexedSlicesValue) 1314 self.assertAllEqual(gradient_checker_v2._to_numpy(grad_val), [[0., 0.], 1315 [1., 1.], 1316 [1., 1.], 1317 [0., 0.]]) 1318 1319 def testCondGrad_MultiGather(self): 1320 # NOTE(skyewm): this test is interesting because the array_ops.gather and 1321 # ResourceVariable.sparse_read gradient functions returns IndexedSlices. 1322 var = resource_variable_ops.ResourceVariable( 1323 np.ones((4, 2), dtype=np.float32)) 1324 x1 = constant_op.constant(np.ones((3, 3), dtype=np.float32)) 1325 x2 = constant_op.constant(2.0) 1326 1327 def true_fn(): 1328 y1 = var.sparse_read([1, 2]) 1329 y2 = array_ops.gather(x1, [2]) * x2 1330 y3 = x2 * [1., 1., 1.] 1331 return y1, y2, y3 1332 1333 def false_fn(): 1334 y1 = np.zeros((2, 2), dtype=np.float32) 1335 y2 = array_ops.gather(x1, [2]) * x2 1336 y3 = array_ops.gather(x1, [2]) 1337 return y1, y2, y3 1338 1339 @def_function.function 1340 def foo(): 1341 r = control_flow_ops.cond(constant_op.constant(True), true_fn, false_fn) 1342 return gradients_impl.gradients(r, [var, x1, x2]) 1343 1344 grad = foo() 1345 self.evaluate(variables.global_variables_initializer()) 1346 var_grad, x1_grad, x2_grad = self.evaluate(grad) 1347 self.assertIsInstance(var_grad, indexed_slices.IndexedSlicesValue) 1348 self.assertAllEqual(gradient_checker_v2._to_numpy(var_grad), [[0., 0.], 1349 [1., 1.], 1350 [1., 1.], 1351 [0., 0]]) 1352 self.assertIsInstance(x1_grad, indexed_slices.IndexedSlicesValue) 1353 self.assertAllEqual(gradient_checker_v2._to_numpy(x1_grad), [[0., 0., 0.], 1354 [0., 0., 0.], 1355 [2., 2., 2.]]) 1356 self.assertIsInstance(x1_grad, indexed_slices.IndexedSlicesValue) 1357 self.assertEqual(gradient_checker_v2._to_numpy(x2_grad), 6.) 1358 1359 @test_util.run_v1_only("b/120545219") 1360 def testCondPredicateTensor(self): 1361 """Regression test for lowering predicate from non-first output of an op.""" 1362 1363 @eager_function.defun 1364 def foo(): 1365 return constant_op.constant("foo"), constant_op.constant(True) 1366 1367 r = control_flow_ops.cond(foo()[1], lambda: 1.0, lambda: 2.0) 1368 self.assertEqual(self.evaluate(r), 1.0) 1369 1370 @test_util.run_v1_only("Tests Session.run() pruning logic.") 1371 def testCondFeedConstantPredicate(self): 1372 with self.cached_session() as sess: 1373 value = constant_op.constant(37.0) 1374 predicate = constant_op.constant(True) 1375 cond_output = control_flow_ops.cond( 1376 predicate, lambda: constant_op.constant(0.0), lambda: value) 1377 result = array_ops.identity(cond_output) 1378 self.assertEqual(37.0, sess.run(result, feed_dict={predicate: False})) 1379 self.assertEqual(0.0, sess.run(result, feed_dict={predicate: True})) 1380 self.assertEqual(0.0, sess.run(result)) 1381 1382 @test_util.run_v1_only("Tests Session.run() pruning logic.") 1383 def testCondFeedPlaceholderWithDefaultPredicate(self): 1384 with self.cached_session() as sess: 1385 value = constant_op.constant(37.0) 1386 predicate = array_ops.placeholder_with_default( 1387 constant_op.constant(True), []) 1388 cond_output = control_flow_ops.cond( 1389 predicate, lambda: constant_op.constant(0.0), lambda: value) 1390 result = array_ops.identity(cond_output) 1391 self.assertAllEqual(37.0, sess.run(result, feed_dict={predicate: False})) 1392 self.assertAllEqual(0.0, sess.run(result, feed_dict={predicate: True})) 1393 self.assertAllEqual(0.0, sess.run(result)) 1394 1395 def testCondTensorDeps(self): 1396 t = array_ops.identity(1.) 1397 1398 @def_function.function 1399 def f(): 1400 with ops.control_dependencies([t]): 1401 return array_ops.identity(2.) 1402 1403 f.get_concrete_function() 1404 1405 @test_util.run_in_graph_and_eager_modes 1406 def testCondAutoControlDeps(self): 1407 if test_util.is_gpu_available(): 1408 self.skipTest("b/128676188 causes OOM on opensource gpu tests") 1409 1410 print_prefix = "testCondAutoControlDeps: " 1411 1412 def branch_fn(): 1413 enqueue_print_op("A") 1414 enqueue_print_op("B") 1415 with ops.control_dependencies([enqueue_print_op("C")]): 1416 return constant_op.constant(10) 1417 1418 def build_cond(): 1419 return control_flow_ops.cond( 1420 constant_op.constant(True), branch_fn, lambda: 0) 1421 1422 def build_nested_cond(): 1423 return control_flow_ops.cond( 1424 constant_op.constant(True), build_cond, lambda: 0) 1425 1426 # In v1 graph mode, pruning should make only "C" print. 1427 if not context.executing_eagerly(): 1428 with self.cached_session(): 1429 with self.captureWritesToStream(sys.stderr) as printed: 1430 self.assertEqual(self.evaluate(build_cond()), 10) 1431 self.assertEqual(["C"], filter_test_messages(printed.contents())) 1432 1433 with self.captureWritesToStream(sys.stderr) as printed: 1434 self.assertEqual(self.evaluate(build_nested_cond()), 10) 1435 self.assertEqual(["C"], filter_test_messages(printed.contents())) 1436 1437 # In defuns, all prints should execute in program order. 1438 # This doesn't work with legacy control flow. 1439 if control_flow_util.ENABLE_CONTROL_FLOW_V2: 1440 1441 @eager_function.defun 1442 def cond(): 1443 return build_cond() 1444 1445 with self.captureWritesToStream(sys.stderr) as printed: 1446 self.assertEqual(self.evaluate(cond()), 10) 1447 self.assertEqual(["A", "B", "C"], 1448 filter_test_messages(printed.contents())) 1449 1450 @eager_function.defun 1451 def nested_cond(): 1452 return build_nested_cond() 1453 1454 with self.captureWritesToStream(sys.stderr) as printed: 1455 self.assertEqual(self.evaluate(nested_cond()), 10) 1456 self.assertEqual(["A", "B", "C"], 1457 filter_test_messages(printed.contents())) 1458 1459 # wrap_function should prune. 1460 def pruned_cond(): 1461 return build_cond() 1462 pruned_cond = wrap_function.wrap_function(pruned_cond, []) 1463 1464 with self.captureWritesToStream(sys.stderr) as printed: 1465 self.assertEqual(self.evaluate(pruned_cond()), 10) 1466 self.assertEqual(["C"], filter_test_messages(printed.contents())) 1467 1468 def pruned_nested_cond(): 1469 return build_nested_cond() 1470 pruned_nested_cond = wrap_function.wrap_function(pruned_nested_cond, []) 1471 1472 with self.captureWritesToStream(sys.stderr) as printed: 1473 self.assertEqual(self.evaluate(pruned_nested_cond()), 10) 1474 self.assertEqual(["C"], filter_test_messages(printed.contents())) 1475 1476 1477 @test_util.run_in_graph_and_eager_modes 1478 @test_util.disable_tfrt("b/179459136") 1479 def testWhileAutoControlDeps(self): 1480 # Legacy while_loop fails this test because it produces deprecation notices 1481 # in stderr. 1482 if not control_flow_util.ENABLE_CONTROL_FLOW_V2: return 1483 1484 def cond(i, unused_x): 1485 enqueue_print_op("A") 1486 return i < 2 1487 1488 def body(i, x): 1489 enqueue_print_op("B") 1490 with ops.control_dependencies([enqueue_print_op("C")]): 1491 x = array_ops.identity(x) 1492 with ops.control_dependencies([enqueue_print_op("D")]): 1493 return i + 1, x 1494 1495 def build_while(): 1496 return control_flow_ops.while_loop( 1497 cond, body, [constant_op.constant(0), constant_op.constant(0)]) 1498 1499 def build_nested_while(): 1500 return control_flow_ops.cond( 1501 constant_op.constant(True), build_while, lambda: [0, 0]) 1502 1503 # In v1 graph mode, pruning should make only "D" print. 1504 if not context.executing_eagerly(): 1505 with self.cached_session(): 1506 with self.captureWritesToStream(sys.stderr) as printed: 1507 self.assertEqual(self.evaluate(build_while()[0]), 2) 1508 self.assertEqual(["D", "D"], filter_test_messages(printed.contents())) 1509 1510 with self.captureWritesToStream(sys.stderr) as printed: 1511 self.assertEqual(self.evaluate(build_nested_while()[0]), 2) 1512 self.assertEqual(["D", "D"], filter_test_messages(printed.contents())) 1513 1514 # In defuns, all prints should execute in program order. 1515 @eager_function.defun 1516 def while_loop(): 1517 return build_while()[0] 1518 1519 with self.captureWritesToStream(sys.stderr) as printed: 1520 self.assertEqual(self.evaluate(while_loop()), 2) 1521 self.assertEqual(["A", "B", "C", "D", "A", "B", "C", "D", "A"], 1522 filter_test_messages(printed.contents())) 1523 1524 @eager_function.defun 1525 def nested_while_loop(): 1526 return build_nested_while()[0] 1527 1528 with self.captureWritesToStream(sys.stderr) as printed: 1529 self.assertEqual(self.evaluate(nested_while_loop()), 2) 1530 self.assertEqual(["A", "B", "C", "D", "A", "B", "C", "D", "A"], 1531 filter_test_messages(printed.contents())) 1532 1533 # wrap_function should prune. 1534 def pruned_while(): 1535 return build_while()[0] 1536 pruned_while = wrap_function.wrap_function(pruned_while, []) 1537 1538 with self.captureWritesToStream(sys.stderr) as printed: 1539 self.assertEqual(self.evaluate(pruned_while()), 2) 1540 self.assertEqual(["D", "D"], filter_test_messages(printed.contents())) 1541 1542 def pruned_nested_while(): 1543 return build_nested_while()[0] 1544 pruned_nested_while = wrap_function.wrap_function(pruned_nested_while, []) 1545 1546 with self.captureWritesToStream(sys.stderr) as printed: 1547 self.assertEqual(self.evaluate(pruned_nested_while()), 2) 1548 self.assertEqual(["D", "D"], filter_test_messages(printed.contents())) 1549 1550 # Microbenchmark: 256,000 iterations/s. 1551 def testWhile_1(self): 1552 with self.cached_session(): 1553 n = constant_op.constant(0) 1554 c = lambda x: math_ops.less(x, 10000) 1555 b = lambda x: math_ops.add(x, 1) 1556 r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20) 1557 self.assertEqual(10000, self.evaluate(r)) 1558 1559 @test_util.run_v1_only("b/120545219") 1560 def testWhileExternalControlDependencies(self): 1561 with self.cached_session(): 1562 v = variables.Variable(0.0) 1563 self.evaluate(v.initializer) 1564 increment = v.assign_add(1.0).read_value() 1565 1566 def body_fn(i): 1567 with ops.control_dependencies([increment]): 1568 return i + 1 1569 1570 result = control_flow_ops.while_loop(cond=lambda i: i < 2, 1571 body=body_fn, loop_vars=[1]) 1572 self.assertAllEqual(result, 2) 1573 self.assertAllEqual(v.read_value(), 1.0) 1574 1575 @test_util.run_v1_only("b/120545219") 1576 def testWhileExternalControlDependenciesNoInput(self): 1577 with self.cached_session(): 1578 v = variables.Variable(0.0) 1579 self.evaluate(v.initializer) 1580 # TODO(apassos): figure out why the reading is necessary here. 1581 increment = v.assign_add(1.0).read_value() 1582 1583 def body_fn(unused_i): 1584 with ops.control_dependencies([increment]): 1585 return constant_op.constant(5, name="five") 1586 1587 result = control_flow_ops.while_loop(cond=lambda i: i < 5, 1588 body=body_fn, loop_vars=[0]) 1589 self.evaluate(result) 1590 self.assertAllEqual(self.evaluate(v), 1.0) 1591 1592 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 1593 @test_util.run_v1_only("b/120545219") 1594 def testWhileWithRefs_1(self): 1595 with self.cached_session() as sess: 1596 x = variables.VariableV1(0)._ref() # pylint: disable=protected-access 1597 i = constant_op.constant(0) 1598 c = lambda i, x: math_ops.less(i, 100) 1599 1600 self.assertEqual(x.dtype, dtypes.int32_ref) 1601 1602 def b(i, x): 1603 self.assertEqual(x.dtype, dtypes.int32_ref) 1604 return (i + 1, gen_array_ops.ref_identity(x)) 1605 1606 r = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=5) 1607 1608 self.evaluate(variables.global_variables_initializer()) 1609 1610 self.assertEqual(r[0].dtype, dtypes.int32) 1611 self.assertEqual(r[1].dtype, dtypes.int32_ref) 1612 1613 value_i, value_x = self.evaluate(r) 1614 1615 self.assertEqual(100, value_i) 1616 self.assertEqual(0, value_x) 1617 1618 def testWhile_2(self): 1619 with self.cached_session(): 1620 s = constant_op.constant(0) 1621 r = isum(s) 1622 self.assertAllEqual(45, self.evaluate(r)) 1623 1624 def testWhileWithMaximumIterations(self): 1625 with self.cached_session(): 1626 s = constant_op.constant([1, 2, 3, 4, 5]) 1627 r = isum(s, maximum_iterations=3) 1628 self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], self.evaluate(r)) 1629 1630 @test_util.run_v1_only("b/120545219") 1631 def testWhileWithMaximumIterationsAndSingleArgument(self): 1632 with self.cached_session(): 1633 r = control_flow_ops.while_loop( 1634 lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1) 1635 self.assertEqual(1, self.evaluate(r)) 1636 1637 @test_util.run_v1_only("b/120545219") 1638 def testXLAGradInLoop(self): 1639 # We have an optimization that moves certain reduction ops, this test makes 1640 # sure we don't do that for XLA ops. 1641 1642 # Use dynamic inputs, which triggers the creation of "BroadcastGradientArgs" 1643 # and "Shape" op. 1644 input1 = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None]) 1645 input2 = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None]) 1646 def cond(i1, i2): 1647 return False 1648 1649 def body(i1, i2): 1650 return math_ops.add(i1, i2), math_ops.add(i1, i2) 1651 1652 xla_context = control_flow_ops.XLAControlFlowContext() 1653 xla_context.Enter() 1654 1655 out1, _ = control_flow_ops.while_loop( 1656 cond, body, (input1, input2), maximum_iterations=2) 1657 g = gradients_impl.gradients(out1, [input1]) 1658 1659 for op in out1.graph.get_operations(): 1660 # Test that the "Shape" is directly passed to BroadcastGradientArgs 1661 # instead of being pushed to the stack. 1662 if op.type == "BroadcastGradientArgs": 1663 self.assertEqual(op.inputs[0].op.type, "Shape") 1664 self.assertEqual(op.inputs[1].op.type, "Shape") 1665 xla_context.Exit() 1666 1667 1668 @test_util.disable_control_flow_v2("b/115776323 (max_iters)") 1669 @test_util.run_v1_only("b/120545219") 1670 def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self): 1671 v = constant_op.constant(1.0) 1672 1673 def training_loop_with_gradient(i): 1674 out = control_flow_ops.while_loop( 1675 lambda i_, _: i_ < 3, 1676 lambda i_, j: [i_ + 1, j * v], [0, 1.0], 1677 maximum_iterations=i) 1678 g = gradients_impl.gradients(out, v) 1679 with ops.control_dependencies(g): 1680 return i + 1 1681 1682 xla_context = control_flow_ops.XLAControlFlowContext() 1683 xla_context.Enter() 1684 # Create training loop, ensure we can call gradient() of 1685 # while_loop inside the training loop. 1686 loop = control_flow_ops.while_loop(lambda i: i < 3, 1687 training_loop_with_gradient, [0]) 1688 xla_context.Exit() 1689 1690 loop_execute = array_ops.identity(loop) # Because loop is not fetchable. 1691 1692 # Should execute without issue. 1693 self.assertEqual(3, self.evaluate(loop_execute)) 1694 1695 @test_util.run_v1_only("b/120545219") 1696 def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self): 1697 if control_flow_util.ENABLE_CONTROL_FLOW_V2: 1698 self.skipTest("WhileV2 does lazy evaluation of maximum_iterations") 1699 v = constant_op.constant(1.0) 1700 1701 def inner_body(i, x): 1702 out = control_flow_ops.while_loop( 1703 lambda i, _: i < 3, 1704 lambda i, j: [i + 1, j * v], [0, x], 1705 maximum_iterations=i) 1706 return out 1707 1708 def create_while_loop(maximum_iterations=None): 1709 return control_flow_ops.while_loop( 1710 lambda i, _: i < 3, 1711 inner_body, [0, 1.0], 1712 maximum_iterations=maximum_iterations) 1713 1714 loop_no_xla = create_while_loop(maximum_iterations=5) 1715 # maximum_iterations is fine outside of an XLA scope 1716 gs = gradients_impl.gradients(loop_no_xla, v) 1717 self.evaluate(gs) # This should execute without error. 1718 1719 xla_context = control_flow_ops.XLAControlFlowContext() 1720 xla_context.Enter() 1721 loop_no_maxiter = create_while_loop() 1722 loop_with_maxiter = create_while_loop(maximum_iterations=2) 1723 xla_context.Exit() 1724 1725 with self.assertRaisesRegex( 1726 ValueError, 1727 r"Cannot create a gradient accumulator for tensor '.+' inside " 1728 r"XLA while_loop because maximum_iterations was not passed to " 1729 r"the tf.while_loop call \('.+'\)."): 1730 _ = gradients_impl.gradients(loop_no_maxiter, v) 1731 1732 with self.assertRaisesRegex( 1733 ValueError, 1734 r"Cannot create a gradient accumulator for tensor '.+' inside XLA " 1735 r"while_loop. maximum_iterations tensor '.+' for while_loop context " 1736 r"'.+' must be statically known \(e.g. a constant value or known " 1737 r"shape dimension\), or be defined at or outside the while loop " 1738 r"context '.*' \(currently defined in '.*'\)"): 1739 _ = gradients_impl.gradients(loop_with_maxiter, v) 1740 1741 @test_util.run_v1_only("b/120545219") 1742 def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self): 1743 v = constant_op.constant(1.0) 1744 1745 def create_while_loop(): 1746 max_iter_holder = [] 1747 1748 def create_mi(): 1749 max_iter_holder.append(array_ops.placeholder(dtypes.int32, shape=())) 1750 return 1.0 1751 1752 _ = control_flow_ops.cond( 1753 constant_op.constant(True), create_mi, create_mi) 1754 1755 return control_flow_ops.while_loop( 1756 lambda i, _: i < 3, 1757 lambda i, x: (i + 1, v * x), (0, 1.0), 1758 maximum_iterations=max_iter_holder[0]) 1759 1760 if control_flow_util.ENABLE_CONTROL_FLOW_V2: 1761 xla_context = control_flow_ops.XLAControlFlowContext() 1762 xla_context.Enter() 1763 with self.assertRaisesRegex(ValueError, r"must be from the same graph.*"): 1764 loop = create_while_loop() 1765 xla_context.Exit() 1766 else: 1767 xla_context = control_flow_ops.XLAControlFlowContext() 1768 xla_context.Enter() 1769 loop = create_while_loop() 1770 xla_context.Exit() 1771 with self.assertRaisesRegex( 1772 ValueError, 1773 r"Cannot create a gradient accumulator for tensor '.+' inside XLA " 1774 r"while_loop. maximum_iterations tensor '.*Placeholder:0' for " 1775 r"while_loop context '.+' must be statically known \(e.g. a constant " 1776 r"value or known shape dimension\), or be defined at or outside the " 1777 r"while loop context '' \(currently defined in 'cond/.+'\)"): 1778 _ = gradients_impl.gradients(loop, v) 1779 1780 @test_util.run_v1_only("b/120545219") 1781 def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self): 1782 if test_util.is_gpu_available(): 1783 self.skipTest("b/128646372, b/128645947 fails in opensource build") 1784 1785 v = constant_op.constant(1.0) 1786 1787 p = array_ops.placeholder(dtype=dtypes.int32) 1788 1789 def mid_body_builder(iterations): 1790 1791 def mid_body(i, x): 1792 r = control_flow_ops.while_loop( 1793 lambda *_: True, 1794 lambda i, x: (i + 1, v * x), (0, x), 1795 maximum_iterations=iterations, 1796 name="inner") 1797 return (i + 1, gradients_impl.gradients(x + r[1], v)[0]) 1798 1799 return mid_body 1800 1801 def outer_body(i, x): 1802 iterations = array_ops.size(p, name="iterations") 1803 return (i + 1, x + control_flow_ops.while_loop( 1804 lambda *_: True, 1805 mid_body_builder(iterations), (0, x), 1806 maximum_iterations=iterations, 1807 name="mid")[1]) 1808 1809 def create_while_loop(): 1810 with ops.device("/cpu:0"): 1811 r = control_flow_ops.while_loop( 1812 lambda *_: True, 1813 outer_body, (0, 1.0), 1814 maximum_iterations=5, 1815 name="outer") 1816 return array_ops.identity(r[1]) 1817 1818 xla_context = control_flow_ops.XLAControlFlowContext() 1819 xla_context.Enter() 1820 final_with_xla_context = create_while_loop() 1821 xla_context.Exit() 1822 1823 final_without_xla_context = create_while_loop() 1824 1825 with self.session(use_gpu=False) as sess: 1826 opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE) 1827 run_metadata_without_xla_context = config_pb2.RunMetadata() 1828 run_metadata = config_pb2.RunMetadata() 1829 1830 final_value_without_xla_context = sess.run( 1831 final_without_xla_context, 1832 feed_dict={p: [0, 0, 0]}, 1833 options=opts, 1834 run_metadata=run_metadata_without_xla_context) 1835 1836 final_value_with_xla_context = sess.run( 1837 final_with_xla_context, 1838 feed_dict={p: [0, 0, 0]}, 1839 options=opts, 1840 run_metadata=run_metadata) 1841 1842 if control_flow_util.ENABLE_CONTROL_FLOW_V2: 1843 # With while_v2 on xla, run_metadata only contains the unlowered While 1844 # op so node_stats does not have statistics for the pushes. So as a 1845 # loose check we check the pushes in the lowered version. 1846 for dev in run_metadata_without_xla_context.step_stats.dev_stats: 1847 if "/device:CPU" in dev.device: 1848 node_stats = dev.node_stats 1849 stack_push_count = len([ 1850 x for x in node_stats 1851 if re.match(r".*TensorListPushBack_?\d*", x.node_name) 1852 ]) 1853 else: 1854 for dev in run_metadata.step_stats.dev_stats: 1855 if "/device:CPU" in dev.device: 1856 node_stats = dev.node_stats 1857 stack_push_op = "StackPushV2" 1858 stack_push_count = len( 1859 [x for x in node_stats if x.node_name.endswith("StackPushV2")]) 1860 # Pushes to the stack = product of maximum_iterations values; 1861 # the last two "3"s comes from size(p), when p == [0, 0, 0]. 1862 self.assertEqual(stack_push_count, 5 * 3 * 3, str(node_stats)) 1863 1864 self.assertAllClose(final_value_with_xla_context, 1865 final_value_without_xla_context) 1866 1867 # Have more than 10 parallel iterations and hence exercise k-bound 1868 # most of the time. 1869 @test_util.run_deprecated_v1 1870 def testWhile_3(self): 1871 with self.cached_session(): 1872 1873 def compute(i, m, c, o): 1874 m, c = [math_ops.add(m, 1), math_ops.add(c, 1)] 1875 o = math_ops.add(o, m) 1876 o = math_ops.add(o, c) 1877 i = math_ops.add(i, 1) 1878 return [i, m, c, o] 1879 1880 i = ops.convert_to_tensor(0) 1881 m = ops.convert_to_tensor(0) 1882 c = ops.convert_to_tensor(0) 1883 o = ops.convert_to_tensor(0) 1884 d = ops.convert_to_tensor(100) 1885 r = control_flow_ops.while_loop(lambda i, m, c, o: math_ops.less(i, d), 1886 compute, [i, m, c, o]) 1887 result = r[3] 1888 self.assertAllEqual(10100, result) 1889 1890 @test_util.run_deprecated_v1 1891 def testWhile_4(self): 1892 with self.cached_session(): 1893 1894 def compute(i, m, c, o): 1895 m, c = [array_ops.gather(x, i), array_ops.gather(x, i)] 1896 o = math_ops.add(o, m) 1897 o = math_ops.add(o, c) 1898 i = math_ops.add(i, 1) 1899 return [i, m, c, o] 1900 1901 i = ops.convert_to_tensor(0) 1902 m = ops.convert_to_tensor(0) 1903 c = ops.convert_to_tensor(0) 1904 o = ops.convert_to_tensor(0) 1905 x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6]) 1906 s = array_ops.size(x) 1907 r = control_flow_ops.while_loop(lambda i, m, c, o: math_ops.less(i, s), 1908 compute, [i, m, c, o]) 1909 result = r[3] 1910 self.assertAllEqual(42, result) 1911 1912 @test_util.run_v1_only("b/120545219") 1913 def testWhile_5(self): 1914 with self.cached_session(): 1915 1916 def compute(i, c, o): 1917 c = array_ops.strided_slice(x, array_ops.expand_dims(i, 0), 1918 [1] + array_ops.expand_dims(i, 0)) 1919 o = array_ops.concat([o, c], 0) 1920 i = math_ops.add(i, 1) 1921 return [i, c, o] 1922 1923 i = ops.convert_to_tensor(0) 1924 c = ops.convert_to_tensor([0]) 1925 o = ops.convert_to_tensor([0]) 1926 x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6]) 1927 s = array_ops.size(x) 1928 r = control_flow_ops.while_loop(lambda i, c, o: math_ops.less(i, s), 1929 compute, [i, c, o], [ 1930 i.get_shape(), 1931 tensor_shape.unknown_shape(), 1932 tensor_shape.unknown_shape() 1933 ]) 1934 result = r[2] 1935 self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result) 1936 1937 @test_util.run_gpu_only 1938 @test_util.run_deprecated_v1 1939 def testWhile_Device(self): 1940 1941 # Body function defined outside of device scope 1942 def body(x): 1943 return math_ops.exp(x) 1944 1945 with ops.device("CPU:0"): 1946 r = control_flow_ops.while_loop( 1947 lambda x: x < 10, body, [constant_op.constant(-10.)]) 1948 self.assertIn("cpu", r.device.lower()) 1949 1950 with session.Session() as sess: 1951 options = config_pb2.RunOptions(output_partition_graphs=True) 1952 run_metadata = config_pb2.RunMetadata() 1953 sess.run(r, options=options, run_metadata=run_metadata) 1954 # We expect that everything runs on CPU, even if GPU is available. 1955 self.assertEqual(len(run_metadata.partition_graphs), 1) 1956 1957 @test_util.disable_control_flow_v2("b/116338794 (buffer_reuse)") 1958 @test_util.run_v1_only("b/120545219") 1959 def testBufferForwarding(self): 1960 run_options = config_pb2.RunOptions( 1961 trace_level=config_pb2.RunOptions.FULL_TRACE) 1962 run_metadata = config_pb2.RunMetadata() 1963 1964 with self.cached_session() as sess: 1965 with ops.device("/cpu:0"): 1966 c = constant_op.constant(2) 1967 i0 = constant_op.constant(0) 1968 r = control_flow_ops.while_loop(lambda i: i < 1000, 1969 lambda i: math_ops.square(c) + i, [i0]) 1970 r_val = sess.run(r, options=run_options, run_metadata=run_metadata) 1971 self.assertEqual(1000, r_val) 1972 self.assertTrue(run_metadata.HasField("step_stats")) 1973 unique_allocs = set() 1974 for node_stat in run_metadata.step_stats.dev_stats[0].node_stats: 1975 for output in node_stat.output: 1976 unique_allocs.add( 1977 output.tensor_description.allocation_description.ptr) 1978 # Prior to cl/147536680, the number of unique allocations was about 1005. 1979 self.assertLess(len(unique_allocs), 756) 1980 1981 def _testWhile_Gpu_1(self, use_gpu): 1982 with self.cached_session(use_gpu=use_gpu): 1983 n = constant_op.constant(1.0) 1984 c = lambda x: math_ops.less(x, 10.0) 1985 b = lambda x: math_ops.add(x, 1.0) 1986 r = control_flow_ops.while_loop(c, b, [n]) 1987 self.assertAllClose(10.0, self.evaluate(r)) 1988 1989 def testWhile_Gpu_1(self): 1990 self._testWhile_Gpu_1(use_gpu=False) 1991 self._testWhile_Gpu_1(use_gpu=True) 1992 1993 def _testWhile_Gpu_2(self, use_gpu): 1994 with self.cached_session(use_gpu=use_gpu): 1995 n = constant_op.constant(1.0) 1996 c = lambda x: math_ops.less(x, 10.0) 1997 1998 def b(x): 1999 with ops.device("/cpu:0"): 2000 return math_ops.add(x, 1.0) 2001 2002 r = control_flow_ops.while_loop(c, b, [n]) 2003 self.assertAllClose(10.0, self.evaluate(r)) 2004 2005 def testWhile_Gpu_2(self): 2006 self._testWhile_Gpu_2(use_gpu=False) 2007 self._testWhile_Gpu_2(use_gpu=True) 2008 2009 def testWhileShape(self): 2010 with self.cached_session(): 2011 i = constant_op.constant(0) 2012 m = array_ops.ones([2, 2]) 2013 c = lambda i, j: math_ops.less(i, 2) 2014 2015 def _b(i, j): 2016 new_i = math_ops.add(i, 1) 2017 new_j = array_ops.tile(j, [2, 2]) 2018 return [new_i, new_j] 2019 2020 r = control_flow_ops.while_loop( 2021 c, _b, [i, m], 2022 [i.get_shape(), tensor_shape.unknown_shape()]) 2023 r = r[1] * array_ops.ones([8, 8]) 2024 self.assertAllEqual(np.ones((8, 8)), self.evaluate(r)) 2025 2026 @test_util.disable_control_flow_v2("b/131265085") 2027 @test_util.run_v1_only("b/131265085") 2028 def testWhileBadShape(self): 2029 x = constant_op.constant([2.0, 4.0], name="values") 2030 i = constant_op.constant(0) 2031 c = lambda i, _: math_ops.less(i, 10) 2032 b = lambda i, x: [i + 1, x + 1] 2033 with self.assertRaisesRegex(ValueError, "is not compatible with"): 2034 # Shape of x is [2], but we specify a shape of [5]. 2035 control_flow_ops.while_loop( 2036 c, b, [i, x], [i.shape, tensor_shape.TensorShape([5])]) 2037 2038 @test_util.run_in_graph_and_eager_modes 2039 def testWhileBadBodyReturn(self): 2040 x = constant_op.constant([2.0, 4.0], name="values") 2041 i = constant_op.constant(0) 2042 c = lambda i, *x: math_ops.less(i, 10) 2043 2044 # body accepts N values and returns N+1 values. 2045 b = lambda i, *x: (i, i) + x 2046 2047 with self.assertRaisesRegex( 2048 ValueError, "The two structures don't have the same nested structure."): 2049 control_flow_ops.while_loop(c, b, [i, x]) 2050 2051 @test_util.run_deprecated_v1 2052 def testWhileWithNonTensorInput_Scalar(self): 2053 with self.cached_session(): 2054 n = 0 2055 c = lambda x: x < 10000 2056 b = lambda x: x + 1 2057 r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20) 2058 self.assertEqual(10000, self.evaluate(r)) 2059 2060 def testWhileWithNonTensorInput_Vector(self): 2061 with self.cached_session(): 2062 n = np.array([0]) # Note, [0] would not work here; that is a list 2063 c = lambda x: x[0] < 10000 2064 b = lambda x: array_ops.stack([x[0] + 1]) 2065 r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20) 2066 self.assertEqual([10000], self.evaluate(r)) 2067 2068 def testWhileShapeInference(self): 2069 with self.cached_session(): 2070 i = constant_op.constant(0) 2071 m = array_ops.ones([2, 2]) 2072 c = lambda i, j: math_ops.less(i, 2) 2073 2074 def b(i, j): 2075 new_i = math_ops.add(i, 1) 2076 new_j = array_ops.concat([j, j], 0) 2077 return [new_i, new_j] 2078 2079 r = control_flow_ops.while_loop( 2080 c, b, [i, m], 2081 [i.get_shape(), tensor_shape.TensorShape([None, 2])]) 2082 self.assertTrue(r[1].shape.is_compatible_with([8, 2])) 2083 2084 @test_util.run_v1_only("b/120545219") 2085 def testWhileShapeInferenceBadShape(self): 2086 with self.cached_session(): 2087 i = constant_op.constant(0) 2088 m = array_ops.ones([2, 2]) 2089 c = lambda i, j: math_ops.less(i, 2) 2090 b = lambda i, j: [i + 1, array_ops.concat([j, j], 0)] 2091 with self.assertRaisesRegex( 2092 ValueError, 2093 r".*\(2, 2\).*\(4, 2\) after one iteration\. To allow the shape to " 2094 r"vary across iterations, use the `shape_invariants` argument of " 2095 r"tf.while_loop to specify a less-specific shape\."): 2096 control_flow_ops.while_loop(c, b, [i, m]) 2097 2098 def testWhileShapeInferenceSparseTensor(self): 2099 values = constant_op.constant([2.0, 4.0], name="values") 2100 indices = constant_op.constant([[0], [3]], 2101 dtype=dtypes.int64, 2102 name="indices") 2103 shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape") 2104 i = constant_op.constant(0) 2105 x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape) 2106 2107 def c(i, _): 2108 return i < 10 2109 2110 def b1(i, x): # modifies values. (shape of components is not changed.) 2111 return [ 2112 i + 1, 2113 sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape) 2114 ] 2115 2116 def b2(i, x): # adds new values. (shape of components is changed.) 2117 return [ 2118 i + 1, 2119 sparse_ops.sparse_add( 2120 x, 2121 sparse_tensor.SparseTensor( 2122 indices=math_ops.cast( 2123 array_ops.fill([1, 1], i), dtypes.int64), 2124 values=array_ops.fill([1], 1.0), 2125 dense_shape=x.dense_shape)) 2126 ] 2127 2128 def b3(i, x): # modifies rank. (shape of all components is changed.) 2129 return [ 2130 i + 1, 2131 sparse_tensor.SparseTensor( 2132 array_ops.concat([x.indices, [[i], [i]]], axis=1), x.values * 2.0, 2133 array_ops.concat([x.dense_shape, [10]], axis=0)) 2134 ] 2135 2136 def check_shapes(r, indices, values, dense_shape): 2137 self.assertTrue(r.indices.shape.is_compatible_with(indices)) 2138 self.assertTrue(r.values.shape.is_compatible_with(values)) 2139 self.assertTrue(r.dense_shape.shape.is_compatible_with(dense_shape)) 2140 2141 # Default shape invariant; b1 only modifies values. 2142 _, r = control_flow_ops.while_loop(c, b1, [i, x]) 2143 check_shapes(r, indices=[None, 1], values=[None], dense_shape=[1]) 2144 2145 # Default shape invariant; b2 adds new values 2146 _, r = control_flow_ops.while_loop(c, b2, [i, x]) 2147 check_shapes(r, indices=[None, 1], values=[None], dense_shape=[1]) 2148 2149 # Explicit shape invariant, allowing any rank; b1 only modifies values. 2150 _, r = control_flow_ops.while_loop( 2151 c, b1, [i, x], 2152 [i.get_shape(), tensor_shape.TensorShape([None])]) 2153 check_shapes(r, indices=[None, None], values=[None], dense_shape=[None]) 2154 2155 # Explicit shape invariant, allowing any rank; b3 modifies rank. 2156 _, r = control_flow_ops.while_loop( 2157 c, b3, [i, x], 2158 [i.get_shape(), tensor_shape.TensorShape([None])]) 2159 check_shapes(r, indices=[None, None], values=[None], dense_shape=[None]) 2160 2161 # Shape invariant with ndims=None. Technically, this isn't supported 2162 # according to the docs, but we support it for backwards compatibility. 2163 _, r = control_flow_ops.while_loop( 2164 c, b1, [i, x], 2165 [i.get_shape(), tensor_shape.TensorShape(None)]) 2166 check_shapes(r, indices=[None, None], values=[None], dense_shape=[None]) 2167 _, r = control_flow_ops.while_loop( 2168 c, b3, [i, x], 2169 [i.get_shape(), tensor_shape.TensorShape(None)]) 2170 check_shapes(r, indices=[None, None], values=[None], dense_shape=[None]) 2171 2172 @test_util.disable_control_flow_v2("b/131265085") 2173 @test_util.run_v1_only("b/131265085") 2174 def testWhileBadShapeSparseTensor(self): 2175 values = constant_op.constant([2.0, 4.0], name="values") 2176 indices = constant_op.constant([[0], [3]], 2177 dtype=dtypes.int64, 2178 name="indices") 2179 shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape") 2180 i = constant_op.constant(0) 2181 x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape) 2182 c = lambda i, _: i < 10 2183 b1 = lambda i, x: [i+1, x] 2184 def b2(i, x): # modifies rank. (shape of all components is changed.) 2185 return [ 2186 i + 1, 2187 sparse_tensor.SparseTensor( 2188 array_ops.concat([x.indices, [[i], [i]]], axis=1), x.values * 2.0, 2189 array_ops.concat([x.dense_shape, [10]], axis=0)) 2190 ] 2191 2192 # Explicit shape invariant, with a specific (incompatible) rank. 2193 with self.assertRaisesRegex(ValueError, "is not compatible with"): 2194 control_flow_ops.while_loop( 2195 c, b1, [i, x], 2196 [i.get_shape(), tensor_shape.TensorShape([5])]) 2197 2198 # Default shape invariant, but b2 modifies rank (which is not allowed). 2199 with self.assertRaises(ValueError): 2200 control_flow_ops.while_loop(c, b2, [i, x]) 2201 2202 def testWhileShapeInferenceIndexedSlices(self): 2203 with self.cached_session(): 2204 values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values") 2205 indices = constant_op.constant([0, 3], name="indices") 2206 shape = constant_op.constant([10, 2], name="dense_shape") 2207 i = constant_op.constant(0) 2208 x = indexed_slices.IndexedSlices(values, indices, dense_shape=shape) 2209 2210 def c(i, _): 2211 return i < 10 2212 2213 def b(i, x): 2214 return [ 2215 i + 1, 2216 indexed_slices.IndexedSlices(x.values * 2.0, x.indices, 2217 x.dense_shape) 2218 ] 2219 2220 _, r = control_flow_ops.while_loop(c, b, [i, x]) 2221 self.assertEqual(r.dense_shape.get_shape()[0], 2) 2222 self.assertEqual(r.values.get_shape(), tensor_shape.TensorShape([2, 2])) 2223 2224 _, r = control_flow_ops.while_loop( 2225 c, b, [i, x], 2226 [i.get_shape(), tensor_shape.TensorShape([None, 2])]) 2227 self.assertEqual(r.dense_shape.get_shape()[0], 2) 2228 self.assertTrue(r.values.get_shape().is_compatible_with([None, 2])) 2229 2230 @test_util.disable_control_flow_v2("b/131265085") 2231 @test_util.run_v1_only("b/131265085") 2232 def testWhileBadShapeIndexedSlices(self): 2233 values = constant_op.constant([2.0, 4.0], name="values") 2234 indices = constant_op.constant([[0], [3]], 2235 dtype=dtypes.int64, 2236 name="indices") 2237 shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape") 2238 i = constant_op.constant(0) 2239 x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape) 2240 c = lambda i, _: 10 2241 b = lambda i, x: [i+1, x] 2242 2243 # Explicit shape invariant, with a specific (incompatible) rank. 2244 with self.assertRaisesRegex(ValueError, "is not compatible with"): 2245 control_flow_ops.while_loop( 2246 c, b, [i, x], 2247 [i.get_shape(), tensor_shape.TensorShape([5])]) 2248 2249 def testWhileShapeInferenceRaggedTensor(self): 2250 i = constant_op.constant(0) 2251 x = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]]) 2252 c = lambda i, _: i < 10 2253 2254 def b1(i, x): # Adds new values to rows (but doesn't create new rows) 2255 return [ 2256 i + 1, 2257 array_ops.concat([x, x], axis=1) 2258 ] 2259 2260 def b2(i, x): # Adds new rows. 2261 return [ 2262 i + 1, 2263 array_ops.concat([x, x], axis=0) 2264 ] 2265 2266 def check_shapes(r, values, splits): 2267 self.assertTrue(r.values.shape.is_compatible_with(values)) 2268 self.assertTrue(r.row_splits.shape.is_compatible_with(splits)) 2269 2270 # Default shape invariant; b1 adds new values to rows. 2271 _, r = control_flow_ops.while_loop(c, b1, [i, x]) 2272 check_shapes(r, values=[None], splits=[4]) 2273 2274 # Default shape invariant; b2 adds new rows (not allowed). 2275 if not context.executing_eagerly(): 2276 with self.assertRaises(ValueError): 2277 _, r = control_flow_ops.while_loop(c, b2, [i, x]) 2278 2279 # Explicit shape invariant; b1 adds new values to rows. 2280 # (deprecated: use TensorShape instead of RaggedTensorSpec) 2281 _, r = control_flow_ops.while_loop( 2282 c, b1, [i, x], 2283 [i.get_shape(), tensor_shape.TensorShape([None, None])]) 2284 check_shapes(r, values=[None], splits=[None]) 2285 2286 # Explicit shape invariant; b1 adds new values to rows. 2287 _, r = control_flow_ops.while_loop( 2288 c, b1, [i, x], 2289 [i.get_shape(), ragged_tensor.RaggedTensorSpec([None, None], 2290 dtypes.int32)]) 2291 check_shapes(r, values=[None], splits=[None]) 2292 2293 # Explicit shape invariant; b2 adds new rows. 2294 _, r = control_flow_ops.while_loop( 2295 c, b2, [i, x], 2296 [i.get_shape(), ragged_tensor.RaggedTensorSpec([None, None], 2297 dtypes.int32)]) 2298 check_shapes(r, values=[None], splits=[None]) 2299 2300 def testWhileShapeInferenceRaggedTensorRaggedRank2(self): 2301 i = constant_op.constant(0) 2302 x = ragged_factory_ops.constant([[[1, 2], [3], [4, 5, 6]], 2303 [[], [8, 9, 10]]]) 2304 c = lambda i, _: i < 10 2305 def b(i, x): 2306 return [ 2307 i + 1, 2308 array_ops.concat([x, x[..., i:i+1]], axis=-1) 2309 ] 2310 _, r = control_flow_ops.while_loop(c, b, [i, x]) 2311 self.assertEqual(r.row_splits.shape.as_list(), [3]) 2312 self.assertIn(r.values.row_splits.shape.as_list(), ([6], [None])) 2313 self.assertIn(r.values.values.shape.as_list(), ([49], [None])) 2314 2315 def testWhileShapeInvariantTensorSpec(self): 2316 i = constant_op.constant(0) 2317 x = constant_op.constant([1]) 2318 c = lambda i, _: i < 10 2319 b = lambda i, x: (i + 1, array_ops.stack([x, x])) 2320 shape_invariants = [ 2321 tensor_spec.TensorSpec([], dtype=dtypes.int32), 2322 tensor_spec.TensorSpec(None, dtype=dtypes.int32)] 2323 control_flow_ops.while_loop(c, b, [i, x], shape_invariants) 2324 2325 # TODO(b/131265085) Remove this decorator when bug is fixed. 2326 @test_util.build_as_function_and_v1_graph 2327 def testWhileShapeInvariantWrongTypeSpecType(self): 2328 c = lambda i, _: i < 10 2329 b = lambda i, x: (i + 1, x) 2330 i = constant_op.constant(0) 2331 x = sparse_tensor.SparseTensor([[0]], [1.0], [10]) 2332 shape_invariants = [ 2333 tensor_spec.TensorSpec([], dtype=dtypes.int32), 2334 sparse_tensor.SparseTensorSpec([None])] 2335 control_flow_ops.while_loop(c, b, [i, x], shape_invariants) 2336 2337 x2 = constant_op.constant([1]) 2338 with self.assertRaises(TypeError): 2339 control_flow_ops.while_loop(c, b, [i, x2], shape_invariants) 2340 2341 x3 = ragged_factory_ops.constant([[1, 2], [3]]) 2342 with self.assertRaises(TypeError): 2343 control_flow_ops.while_loop(c, b, [i, x3], shape_invariants) 2344 2345 i2 = constant_op.constant(0.0) 2346 with self.assertRaises(TypeError): 2347 control_flow_ops.while_loop(c, b, [i2, x], shape_invariants) 2348 2349 # TODO(b/131265085) Remove this decorator when bug is fixed. 2350 @test_util.build_as_function_and_v1_graph 2351 def testWhileShapeInvariantBadType(self): 2352 i = constant_op.constant(0) 2353 x = constant_op.constant([1]) 2354 c = lambda i, _: i < 10 2355 b = lambda i, x: (i + 1, x) 2356 with self.assertRaises((ValueError, TypeError)): 2357 control_flow_ops.while_loop(c, b, [i, x], ["foo", "bar"]) 2358 2359 def _testNestedWhile_1(self, use_gpu): 2360 with self.cached_session(use_gpu=use_gpu): 2361 n = constant_op.constant(0) 2362 2363 def cpu_sum(s): 2364 c = lambda i, s: math_ops.less(i, 10) 2365 2366 def b(i, s): 2367 i1 = math_ops.add(i, 1) 2368 with ops.device("/cpu:0"): 2369 s1 = math_ops.add(i, s) 2370 return i1, s1 2371 2372 _, r_s = control_flow_ops.while_loop(c, b, [n, s]) 2373 return r_s 2374 2375 c = lambda x: math_ops.less(x, 200) 2376 b = lambda x: math_ops.add(x, cpu_sum(n)) 2377 r = control_flow_ops.while_loop(c, b, [n]) 2378 self.assertEqual(225, self.evaluate(r)) 2379 2380 def testNestedWhile_1(self): 2381 self._testNestedWhile_1(use_gpu=False) 2382 self._testNestedWhile_1(use_gpu=True) 2383 2384 def _testNestedWhile_2(self, use_gpu): 2385 # Test the cases that A -> Enter and Exit -> A are partitioned. 2386 with self.cached_session(use_gpu=use_gpu): 2387 s0 = constant_op.constant(2.0) 2388 2389 def inner_loop(s): 2390 c = lambda s: math_ops.less(s, 20.0) 2391 2392 def b(s): 2393 s1 = math_ops.add(s, s) 2394 return s1 2395 2396 r_s = control_flow_ops.while_loop(c, b, [s], parallel_iterations=1) 2397 return r_s 2398 2399 outer_c = lambda x: math_ops.less(x, 3000.0) 2400 2401 def outer_b(x): 2402 x = logging_ops.Print(x, [x]) # Edge "Print -> Enter" is partitioned 2403 x = inner_loop(x) 2404 with ops.device("/cpu:0"): 2405 x = math_ops.square(x) # Edge "Exit -> Square" is partitioned 2406 return x 2407 2408 r = control_flow_ops.while_loop( 2409 outer_c, outer_b, [s0], parallel_iterations=1) 2410 self.assertEqual(1048576.0, self.evaluate(r)) 2411 2412 def testNestedWhile_2(self): 2413 self._testNestedWhile_2(use_gpu=False) 2414 self._testNestedWhile_2(use_gpu=True) 2415 2416 @test_util.run_v1_only("b/120545219") 2417 def testWhileWithControl_1(self): 2418 with self.cached_session(): 2419 n = constant_op.constant(0) 2420 r = constant_op.constant(0) 2421 condition = lambda n_, r_: math_ops.less(n_, 10) 2422 2423 def body(n_, r_): 2424 n_ = math_ops.add(n_, 1) 2425 with r_.graph.control_dependencies([r_]): 2426 r_ = constant_op.constant(12) 2427 return [n_, r_] 2428 2429 res = control_flow_ops.while_loop( 2430 condition, body, [n, r], parallel_iterations=1) 2431 self.assertAllEqual(12, res[1]) 2432 2433 @test_util.run_deprecated_v1 2434 def testWhileWithControl_2(self): 2435 with self.cached_session(): 2436 r = constant_op.constant(0) 2437 condition = lambda r_: math_ops.less(r_, 10) 2438 2439 def body(r_): 2440 with r_.graph.control_dependencies([r_]): 2441 r_ = constant_op.constant(12) 2442 return [r_] 2443 2444 res = control_flow_ops.while_loop( 2445 condition, body, [r], parallel_iterations=1) 2446 self.assertAllEqual(12, self.evaluate(res)) 2447 2448 @test_util.run_v1_only("b/120545219") 2449 def testWhileWithControl_3(self): 2450 with self.cached_session() as sess: 2451 b = array_ops.placeholder(dtypes.bool) 2452 c = constant_op.constant(1) 2453 x0 = constant_op.constant(0) 2454 with ops.control_dependencies([b]): 2455 r = control_flow_ops.while_loop(lambda x: x < 10, lambda x: x + c, [x0]) 2456 self.assertEqual(10, sess.run(r, {b: True})) 2457 2458 @test_util.run_v1_only("b/120545219") 2459 def testWhileWithControl_4(self): 2460 with self.cached_session() as sess: 2461 b = array_ops.placeholder(dtypes.bool) 2462 c = constant_op.constant(1) 2463 x0 = constant_op.constant(0) 2464 with ops.control_dependencies([b]): 2465 r = control_flow_ops.while_loop( 2466 lambda x: x < 10, lambda x: x + array_ops.identity(c), [x0]) 2467 self.assertEqual(10, sess.run(r, {b: True})) 2468 2469 @test_util.run_v1_only("b/120545219") 2470 def testWhileWithControl_5(self): 2471 with self.cached_session() as sess: 2472 b = array_ops.placeholder(dtypes.bool) 2473 c = constant_op.constant(1) 2474 x0 = constant_op.constant(0) 2475 2476 def body(x): 2477 with ops.control_dependencies([b]): 2478 return x + c 2479 2480 r = control_flow_ops.while_loop(lambda x: x < 10, body, [x0]) 2481 self.assertEqual(10, sess.run(r, {b: True})) 2482 2483 def testWhileCondWithControl(self): 2484 # Ensure that no control edges by an outer control dependency context are 2485 # added to nodes inside cond/while contexts. 2486 with self.cached_session() as sess: 2487 const_true = lambda: constant_op.constant(True) 2488 const_false = lambda: constant_op.constant(False) 2489 cond = lambda i: control_flow_ops.cond(i > 0, const_true, const_false) 2490 body = lambda i: control_flow_ops.cond(i > 0, lambda: i - 1, lambda: i) 2491 2492 with ops.control_dependencies([control_flow_ops.no_op()]): 2493 loop = control_flow_ops.while_loop(cond, body, 2494 (constant_op.constant(5),)) 2495 self.assertEqual(0, self.evaluate(loop)) 2496 2497 @test_util.disable_control_flow_v2("b/113324949 (ref vars)") 2498 @test_util.run_v1_only("b/120545219") 2499 def testWhileCondWithControl_1(self): 2500 with self.cached_session(): 2501 v = variable_scope.get_variable( 2502 "v", [], initializer=init_ops.constant_initializer(2)) 2503 i0 = constant_op.constant(0) 2504 with ops.control_dependencies([i0]): 2505 2506 def loop_condition(i): 2507 return i < 4 2508 2509 def loop_body(i): 2510 some_cond = control_flow_ops.cond( 2511 constant_op.constant(True), 2512 lambda: state_ops.assign(v, math_ops.square(v)), lambda: v) 2513 with ops.control_dependencies([some_cond]): 2514 return i + 1 2515 2516 r = control_flow_ops.while_loop(loop_condition, loop_body, (i0,)) 2517 self.evaluate(variables.global_variables_initializer()) 2518 self.assertEqual(4, self.evaluate(r)) 2519 self.assertAllClose(65536.0, self.evaluate(v)) 2520 2521 @test_util.disable_control_flow_v2("b/113324949 (ref vars)") 2522 @test_util.run_v1_only("b/120545219") 2523 def testWhileCondExitControl(self): 2524 2525 with self.cached_session(): 2526 v = variables.Variable(1) 2527 2528 def false_branch(): 2529 cond = lambda i: i < 100 2530 2531 def body(i): 2532 x = state_ops.assign(v, i) 2533 return x + 1 2534 2535 loop = control_flow_ops.while_loop(cond, body, [0]) 2536 # Make sure to handle correctly control edge from Exit to a node. 2537 with ops.control_dependencies([loop]): 2538 return constant_op.constant(6.0) 2539 2540 r = control_flow_ops.cond( 2541 constant_op.constant(False), lambda: constant_op.constant(1.0), 2542 false_branch) 2543 self.evaluate(variables.global_variables_initializer()) 2544 self.assertEqual(6.0, self.evaluate(r)) 2545 self.assertEqual(99, self.evaluate(v)) 2546 2547 def testCondWhile_1(self): 2548 2549 with self.cached_session(): 2550 n = ops.convert_to_tensor(0, name="n") 2551 c = lambda x: math_ops.less(x, 10) 2552 b = lambda x: math_ops.add(x, 1) 2553 r = control_flow_ops.cond( 2554 math_ops.less(0, 1), lambda: control_flow_ops.while_loop(c, b, [n]), 2555 lambda: n) 2556 self.assertAllEqual(10, self.evaluate(r)) 2557 2558 def testCondWhile_2(self): 2559 2560 with self.cached_session(): 2561 n = ops.convert_to_tensor(0) 2562 c = lambda x: math_ops.less(x, 10) 2563 b = lambda x: math_ops.add(x, 1) 2564 r = control_flow_ops.cond( 2565 math_ops.less(1, 0), lambda: math_ops.add(n, 1), 2566 lambda: control_flow_ops.while_loop(c, b, [n])) 2567 self.assertAllEqual(10, self.evaluate(r)) 2568 2569 def _testCondWhile_3(self, use_gpu): 2570 with self.cached_session(use_gpu=use_gpu) as sess: 2571 p = array_ops.placeholder(dtypes.bool) 2572 n = constant_op.constant(0.0) 2573 2574 def c(x): 2575 return math_ops.less(x, 10.0) 2576 2577 def b(x): 2578 with ops.device("/cpu:0"): 2579 x1 = math_ops.add(x, 1.0) 2580 return x1 2581 2582 r = control_flow_ops.cond(p, 2583 lambda: control_flow_ops.while_loop(c, b, [n]), 2584 lambda: math_ops.multiply(n, 2.0)) 2585 r1 = gradients_impl.gradients(r, [n]) 2586 self.assertEqual(10., sess.run(r, {p: True})) 2587 self.assertEqual([1.0], sess.run(r1, {p: True})) 2588 self.assertEqual(0.0, sess.run(r, {p: False})) 2589 self.assertEqual([2.0], sess.run(r1, {p: False})) 2590 2591 @test_util.run_deprecated_v1 2592 def testCondWhile_3(self): 2593 self._testCondWhile_3(use_gpu=False) 2594 self._testCondWhile_3(use_gpu=True) 2595 2596 def testWhileCond_1(self): 2597 2598 with self.cached_session(): 2599 i = ops.convert_to_tensor(0, name="i") 2600 n = ops.convert_to_tensor(10, name="n") 2601 one = ops.convert_to_tensor(1, name="one") 2602 c = lambda x: math_ops.less(x, n) 2603 # pylint: disable=undefined-variable 2604 # for OSS build 2605 b = lambda x: control_flow_ops.cond( 2606 constant_op.constant(True), 2607 lambda: math_ops.add(x, one), lambda: math_ops.subtract(x, one)) 2608 # pylint: enable=undefined-variable 2609 r = control_flow_ops.while_loop(c, b, [i]) 2610 self.assertAllEqual(10, self.evaluate(r)) 2611 2612 def testWhileCond_2(self): 2613 2614 with self.cached_session(): 2615 n = ops.convert_to_tensor(0, name="n") 2616 c = lambda x: math_ops.less(x, 10) 2617 b = lambda x: control_flow_ops.cond(constant_op.constant(True), lambda: math_ops.add(x, 1), lambda: n) 2618 r = control_flow_ops.while_loop(c, b, [n]) 2619 self.assertAllEqual(10, self.evaluate(r)) 2620 2621 def testWhileCond_3(self): 2622 2623 with self.cached_session(): 2624 n = ops.convert_to_tensor(0) 2625 c = lambda x: math_ops.less(x, 10) 2626 # pylint: disable=undefined-variable 2627 # for OSS build 2628 b = lambda x: control_flow_ops.cond(math_ops.less(0, 1), 2629 lambda: math_ops.add(x, 1), 2630 lambda: math_ops.subtract(x, 1)) 2631 # pylint: enable=undefined-variable 2632 r = control_flow_ops.while_loop(c, b, [n]) 2633 self.assertAllEqual(10, self.evaluate(r)) 2634 2635 @test_util.run_deprecated_v1 2636 def testWhileCondGradMultiDevice(self): 2637 config = config_pb2.ConfigProto(device_count={"CPU": 2}, 2638 allow_soft_placement=True) 2639 with self.cached_session(config=config) as sess: 2640 pred = array_ops.placeholder(dtypes.bool, []) 2641 x_init = constant_op.constant(1.0) 2642 2643 with ops.device("/cpu:0"): 2644 z = control_flow_ops.while_loop( 2645 lambda i, _: i < 3, 2646 lambda i, x: (i + 1, control_flow_ops.cond( 2647 pred, lambda: x * 2.0, lambda: 10.0)), 2648 [0, x_init]) 2649 2650 with ops.device("/cpu:1"): 2651 grad = gradients_impl.gradients(z, x_init)[0] 2652 2653 with ops.device("/cpu:0"): 2654 grad_grad = gradients_impl.gradients(grad, x_init)[0] 2655 2656 self.assertEqual(sess.run(grad, {pred: True}), 8.0) 2657 self.assertEqual(sess.run(grad, {pred: False}), 0.0) 2658 2659 if not control_flow_util.ENABLE_CONTROL_FLOW_V2: 2660 return 2661 2662 self.assertEqual(sess.run(grad_grad, {pred: True}), 0.0) 2663 self.assertEqual(sess.run(grad_grad, {pred: False}), 0.0) 2664 2665 # NOTE: It is ok to have parallel_iterations > 1 2666 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 2667 @test_util.run_deprecated_v1 2668 def testWhileUpdateVariable_1(self): 2669 with self.cached_session(): 2670 select = variables.Variable([3.0, 4.0, 5.0]) 2671 n = constant_op.constant(0) 2672 2673 def loop_iterator(j): 2674 return math_ops.less(j, 3) 2675 2676 def loop_body(j): 2677 ns = state_ops.scatter_update(select, j, 10.0) 2678 nj = math_ops.add(j, 1) 2679 op = control_flow_ops.group(ns) 2680 nj = control_flow_ops.with_dependencies([op], nj) 2681 return [nj] 2682 2683 r = control_flow_ops.while_loop( 2684 loop_iterator, loop_body, [n], parallel_iterations=1) 2685 self.evaluate(variables.global_variables_initializer()) 2686 self.assertEqual(3, self.evaluate(r)) 2687 result = self.evaluate(select) 2688 self.assertAllClose(np.array([10.0, 10.0, 10.0]), result) 2689 2690 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 2691 @test_util.run_v1_only("b/120545219") 2692 def testWhileUpdateVariable_2(self): 2693 with self.cached_session(): 2694 select1 = variables.Variable([3.0, 4.0, 5.0]) 2695 select2 = variables.Variable([3.0, 4.0, 5.0]) 2696 n = constant_op.constant(0) 2697 2698 def loop_iterator(j): 2699 return math_ops.less(j, 3) 2700 2701 def loop_body(j): 2702 ns1 = state_ops.scatter_update(select1, j, 10.0) 2703 ns2 = state_ops.scatter_update(select2, j, 10.0) 2704 nj = math_ops.add(j, 1) 2705 op = control_flow_ops.group(ns1, ns2) 2706 nj = control_flow_ops.with_dependencies([op], nj) 2707 return [nj] 2708 2709 r = control_flow_ops.while_loop( 2710 loop_iterator, loop_body, [n], parallel_iterations=1) 2711 self.evaluate(variables.global_variables_initializer()) 2712 self.assertEqual(3, self.evaluate(r)) 2713 result1 = self.evaluate(select1) 2714 self.assertAllClose(np.array([10.0, 10.0, 10.0]), result1) 2715 result2 = self.evaluate(select2) 2716 self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2) 2717 2718 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 2719 @test_util.run_v1_only("b/120545219") 2720 def testWhileUpdateVariable_3(self): 2721 with self.cached_session(): 2722 select = variables.Variable([3.0, 4.0, 5.0]) 2723 n = constant_op.constant(0) 2724 2725 def loop_iterator(j, _): 2726 return math_ops.less(j, 3) 2727 2728 def loop_body(j, _): 2729 ns = state_ops.scatter_update(select, j, 10.0) 2730 nj = math_ops.add(j, 1) 2731 return [nj, ns] 2732 2733 r = control_flow_ops.while_loop( 2734 loop_iterator, 2735 loop_body, [n, array_ops.identity(select)], 2736 parallel_iterations=1) 2737 self.evaluate(variables.global_variables_initializer()) 2738 result = r[1] 2739 self.assertAllClose(np.array([10.0, 10.0, 10.0]), result) 2740 2741 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 2742 @test_util.run_v1_only("b/120545219") 2743 def testWhileUpdateVariable_4(self): 2744 with self.cached_session(): 2745 var_a = variables.Variable(0, name="a") 2746 var_b = variables.Variable(0, name="b") 2747 self.evaluate(variables.global_variables_initializer()) 2748 2749 c = constant_op.constant(0, name="c") 2750 asn1 = state_ops.assign_add(var_a, 1, name="a_add") 2751 2752 # Loop condition 2753 def pred(i): 2754 return math_ops.less(i, 10) 2755 2756 # Loop body 2757 def loop_body(i): 2758 asn2 = state_ops.assign_add(var_b, asn1, name="b_add") 2759 with ops.control_dependencies([asn2]): 2760 ni = math_ops.add(i, 1, name="i_add") 2761 return ni 2762 2763 lpa = control_flow_ops.while_loop( 2764 pred, loop_body, [c], parallel_iterations=1) 2765 2766 self.assertEqual(0, self.evaluate(var_b)) 2767 self.evaluate(lpa) # Run the loop 2768 self.assertEqual(10, self.evaluate(var_b)) 2769 2770 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 2771 @test_util.run_v1_only("b/120545219") 2772 def testWhileUpdateVariable_5(self): 2773 with self.cached_session(): 2774 # Create some variables. 2775 var_a = variables.Variable(0, name="a") 2776 var_b = variables.Variable(0, name="b") 2777 self.evaluate(variables.global_variables_initializer()) 2778 2779 # Change condition to check var_b 2780 def pred(_): 2781 return math_ops.less(var_b, 10) 2782 2783 # Change body to increment var_b 2784 def loop_body(i): 2785 asn1 = state_ops.assign_add( 2786 var_a, constant_op.constant(1), name="a_add") 2787 asn2 = state_ops.assign_add( 2788 var_b, constant_op.constant(1), name="b_add") 2789 with ops.control_dependencies([asn1, asn2]): 2790 inc_b = array_ops.identity(var_b) 2791 return inc_b 2792 2793 lpa = control_flow_ops.while_loop( 2794 pred, loop_body, [var_b], parallel_iterations=1, name="loop") 2795 2796 self.assertEqual(0, self.evaluate(var_b)) 2797 self.evaluate(lpa) # Run the loop 2798 self.assertEqual(10, self.evaluate(var_a)) 2799 self.assertEqual(10, self.evaluate(var_b)) 2800 2801 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 2802 @test_util.run_v1_only("b/120545219") 2803 def testWhileUpdateVariable_6(self): 2804 with self.cached_session(): 2805 # Create some variables. 2806 var_a = variables.Variable(0, name="a") 2807 var_b = variables.Variable(0, name="b") 2808 c = constant_op.constant(0) 2809 self.evaluate(variables.global_variables_initializer()) 2810 2811 # Loop condition 2812 def pred(i): 2813 return math_ops.less(i, 10) 2814 2815 # Loop body 2816 def loop_body(i): 2817 asn1 = state_ops.assign_add(var_a, 1, name="a_add") 2818 with ops.control_dependencies([asn1]): 2819 asn2 = state_ops.assign_add(var_b, var_a, name="b_add") 2820 with ops.control_dependencies([asn2]): 2821 ni = math_ops.add(i, 1, name="i_add") 2822 return ni 2823 2824 lpa = control_flow_ops.while_loop( 2825 pred, loop_body, [c], parallel_iterations=1, name="loop") 2826 2827 self.assertEqual(0, self.evaluate(var_b)) 2828 self.evaluate(lpa) # Run the loop 2829 self.assertEqual(55, self.evaluate(var_b)) 2830 self.assertEqual(10, self.evaluate(var_a)) 2831 2832 @test_util.run_v1_only("b/120545219") 2833 def testWhileQueue_1(self): 2834 with self.cached_session(): 2835 q = data_flow_ops.FIFOQueue(-1, dtypes.int32) 2836 i = constant_op.constant(0) 2837 2838 def c(i): 2839 return math_ops.less(i, 10) 2840 2841 def b(i): 2842 ni = math_ops.add(i, 1) 2843 ni = control_flow_ops.with_dependencies([q.enqueue((i,))], ni) 2844 return ni 2845 2846 r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1) 2847 self.assertEqual([10], self.evaluate(r)) 2848 for i in range(10): 2849 self.assertEqual([i], self.evaluate(q.dequeue())) 2850 2851 @test_util.run_v1_only("b/120545219") 2852 def testWhileTimeOut(self): 2853 run_options = config_pb2.RunOptions(timeout_in_ms=1) 2854 with self.cached_session() as sess: 2855 n = constant_op.constant(0) 2856 c = lambda x: True 2857 b = lambda x: math_ops.add(x, 1) 2858 r = control_flow_ops.while_loop(c, b, [n]) 2859 with self.assertRaises(errors_impl.DeadlineExceededError): 2860 sess.run(r, options=run_options) 2861 2862 @test_util.disable_control_flow_v2("b/117119329 (stack)") 2863 @test_util.run_v1_only("b/120545219") 2864 def testWhileStack_1(self): 2865 with self.cached_session(): 2866 s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo") 2867 i = constant_op.constant(0) 2868 2869 def c(i): 2870 return math_ops.less(i, 10) 2871 2872 def b(i): 2873 ni = math_ops.add(i, 1) 2874 ni = control_flow_ops.with_dependencies( 2875 [gen_data_flow_ops.stack_push_v2(s, i)], ni) 2876 return ni 2877 2878 r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1) 2879 2880 x = constant_op.constant(0) 2881 2882 def c1(i, _): 2883 return math_ops.greater(i, 0) 2884 2885 def b1(i, x): 2886 ni = math_ops.subtract(i, 1) 2887 nx = x + gen_data_flow_ops.stack_pop_v2(s, dtypes.int32) 2888 return [ni, nx] 2889 2890 _, rx = control_flow_ops.while_loop( 2891 c1, 2892 b1, [r, x], 2893 [r.get_shape(), tensor_shape.unknown_shape()], 2894 parallel_iterations=1) 2895 self.assertEqual(45, self.evaluate(rx)) 2896 2897 def _testWhileGrad_ColocateGradients(self, colocate): 2898 gpu_dev_name = test.gpu_device_name() if test.is_gpu_available( 2899 ) else "/device:CPU:0" 2900 2901 graph = ops.Graph() 2902 with graph.as_default(): 2903 v = constant_op.constant(2.0, name="v") 2904 c = lambda v: math_ops.less(v, 100.0) 2905 2906 def b(x): 2907 with ops.device(gpu_dev_name): 2908 return math_ops.square(x) 2909 2910 loop = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 2911 r = gradients_impl.gradients( 2912 loop, v, colocate_gradients_with_ops=colocate)[0] 2913 2914 r_ops = graph.get_operations() 2915 r_devices = [(op.name, op.device) for op in r_ops] 2916 2917 self.assertTrue(any("Square" in op.name for op in r_ops)) 2918 2919 for (name, dev) in r_devices: 2920 if not colocate and name.endswith("Square"): 2921 # Only forward graph contain gpu in Square device 2922 self.assertTrue(gpu_dev_name in dev) 2923 elif colocate and "Square" in name: 2924 # Forward and backward graphs contain gpu in Square/Square_grad devices 2925 self.assertTrue(gpu_dev_name in dev) 2926 else: 2927 self.assertFalse(gpu_dev_name in dev) 2928 2929 with self.session(graph=graph) as sess: 2930 self.assertAllClose(1024.0, self.evaluate(r)) 2931 2932 @test_util.disable_control_flow_v2("b/116351701 (colocation)") 2933 @test_util.run_v1_only("b/120545219") 2934 def testWhileGrad_ColocateGradients(self): 2935 self._testWhileGrad_ColocateGradients(colocate=False) 2936 self._testWhileGrad_ColocateGradients(colocate=True) 2937 2938 @test_util.run_v1_only("b/120545219") 2939 def testWhileGrad_Square(self): 2940 with self.cached_session(): 2941 v = constant_op.constant(2.0, name="v") 2942 c = lambda v: math_ops.less(v, 100.0) 2943 b = math_ops.square 2944 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 2945 r = control_flow_ops.cond(math_ops.less(1, 2), lambda: r, lambda: v) 2946 2947 r = gradients_impl.gradients(r, v)[0] 2948 self.assertAllClose(1024.0, self.evaluate(r)) 2949 2950 @test_util.run_v1_only("b/120545219") 2951 def testWhileGrad_Shape(self): 2952 with self.cached_session(): 2953 x = array_ops.placeholder(dtypes.float32, shape=[None]) 2954 v = constant_op.constant([2.0], name="v") 2955 n = constant_op.constant(0, name="n") 2956 c = lambda i, v: math_ops.less(i, 5) 2957 b = lambda i, v: [i + 1, math_ops.multiply(x, v)] 2958 r = control_flow_ops.while_loop( 2959 c, 2960 b, [n, v], 2961 [n.get_shape(), tensor_shape.unknown_shape()], 2962 parallel_iterations=1) 2963 2964 r = gradients_impl.gradients(r[1], x)[0] 2965 self.assertEqual([None], r.get_shape().as_list()) 2966 self.assertAllClose([810.0, 2560.0], r.eval(feed_dict={x: [3.0, 4.0]})) 2967 2968 @test_util.run_deprecated_v1 2969 def testWhileGrad_BaseShape(self): 2970 with self.cached_session() as sess: 2971 x = array_ops.placeholder(dtypes.float32, [None]) 2972 v0 = constant_op.constant([2.0, 2.0], name="v") 2973 c = lambda v: constant_op.constant(False) 2974 b = lambda v: math_ops.multiply(v, x) 2975 r = control_flow_ops.while_loop(c, b, [v0]) 2976 y = math_ops.square(x) 2977 2978 r = gradients_impl.gradients([r, y], x)[0] 2979 self.assertAllClose([2.0, 4.0], sess.run(r, feed_dict={x: [1.0, 2.0]})) 2980 2981 @test_util.run_deprecated_v1 2982 @test_util.enable_output_all_intermediates 2983 def testWhileGradAfterSessionRun(self): 2984 v0 = constant_op.constant(2.) 2985 r = control_flow_ops.while_loop( 2986 lambda _: True, lambda v: v * v, [v0], maximum_iterations=3) 2987 2988 self.assertAllEqual(r, 256.) 2989 grad = gradients_impl.gradients(r, v0)[0] 2990 self.assertAllClose(grad, 1024.) 2991 2992 @test_util.run_deprecated_v1 2993 @test_util.enable_output_all_intermediates 2994 def testNestedWhileGradAfterSessionRun(self): 2995 v0 = constant_op.constant(2.) 2996 2997 def body(v): 2998 inner_v0 = constant_op.constant(1.) 2999 return control_flow_ops.while_loop( 3000 lambda _: True, lambda x: x * v, [inner_v0], maximum_iterations=2) 3001 3002 r = control_flow_ops.while_loop( 3003 lambda _: True, body, [v0], maximum_iterations=3) 3004 3005 self.assertAllEqual(r, 256.) 3006 grad = gradients_impl.gradients(r, v0)[0] 3007 self.assertAllClose(grad, 1024.) 3008 3009 @test_util.run_v1_only("b/120545219") 3010 def testWhileGrad_MultipleUses(self): 3011 with self.cached_session(): 3012 v = constant_op.constant(2.0, name="v") 3013 c = lambda v: math_ops.less(v, 100.0) 3014 b = math_ops.square 3015 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 3016 r = math_ops.multiply(r, r) 3017 3018 r = gradients_impl.gradients(r, v)[0] 3019 self.assertEqual(524288.0, self.evaluate(r)) 3020 3021 @test_util.run_v1_only("b/120545219") 3022 def testWhileGrad_LoopAdd(self): 3023 with self.cached_session(): 3024 v = constant_op.constant(2.0, name="v") 3025 c = lambda v: math_ops.less(v, 100.0) 3026 b = math_ops.square 3027 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 3028 r = math_ops.add(r, r) 3029 3030 r = gradients_impl.gradients(r, v)[0] 3031 self.assertAllClose(2048.0, self.evaluate(r)) 3032 3033 def _testWhileGrad_Mul(self, use_gpu, p_iters): 3034 with self.cached_session(use_gpu=use_gpu) as sess: 3035 a = constant_op.constant(3.0, name="a") 3036 v = constant_op.constant(2.0, name="v") 3037 c = lambda v: math_ops.less(v, 100.0) 3038 b = lambda v: math_ops.multiply(v, a) 3039 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=p_iters) 3040 3041 grad_a, grad_v = gradients_impl.gradients(r, [a, v]) 3042 grad_a_val, grad_v_val = self.evaluate([grad_a, grad_v]) 3043 self.assertAllClose(216.0, grad_a_val) 3044 self.assertAllClose(81.0, grad_v_val) 3045 3046 @test_util.run_deprecated_v1 3047 def testWhileGrad_Mul(self): 3048 self._testWhileGrad_Mul(use_gpu=False, p_iters=1) 3049 self._testWhileGrad_Mul(use_gpu=False, p_iters=10) 3050 self._testWhileGrad_Mul(use_gpu=True, p_iters=1) 3051 self._testWhileGrad_Mul(use_gpu=True, p_iters=10) 3052 3053 def testWhileGradInControlDeps(self): 3054 3055 @def_function.function 3056 def f(): 3057 x_init = constant_op.constant(2.) 3058 loop_cond = lambda i, x: math_ops.less(i, 2) 3059 loop_body = lambda i, x: [i + 1, x**2] 3060 _, x = control_flow_ops.while_loop(loop_cond, loop_body, [0, x_init]) 3061 with ops.control_dependencies([x]): 3062 (grad,) = gradients_impl.gradients(x, x_init) 3063 return grad 3064 3065 self.assertAllEqual(f(), 4. * 2.**3) # 4 * x_init ^ 3 3066 3067 @test_util.run_deprecated_v1 3068 def testTfFunctionInV1WhileLoop(self): 3069 3070 # This test specifically tests that creating a Const node inside a 3071 # tf.function inside a v1 while_loop while inlining is turned on works. 3072 config = opt_cfg() 3073 assert config.graph_options.optimizer_options.do_function_inlining 3074 with session.Session(config=config): 3075 3076 @def_function.function 3077 def loop_body(i): 3078 # Here we create the const. 3079 return i + 1. 3080 3081 loop_cond = lambda i: True 3082 x = control_flow_ops.while_loop( 3083 loop_cond, loop_body, [0.], maximum_iterations=5) 3084 self.assertAllEqual(x, 5.) 3085 3086 def _testNestedWhileCondWhileGrad(self, use_gpu): 3087 3088 with self.cached_session(use_gpu=use_gpu): 3089 v = constant_op.constant(1.0) 3090 3091 def inner_loop(s): 3092 z = constant_op.constant(0) 3093 c = lambda i, x: math_ops.less(i, 4) 3094 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 3095 return control_flow_ops.while_loop(c, b, [z, s]) 3096 3097 c = lambda x: math_ops.less(x, 128.0) 3098 3099 def b(x): 3100 return control_flow_ops.cond( 3101 constant_op.constant(True), 3102 lambda: math_ops.square(inner_loop(x)[1]), 3103 lambda: math_ops.multiply(x, 2.0)) 3104 3105 r = control_flow_ops.while_loop(c, b, [v]) 3106 r = gradients_impl.gradients(r, v)[0] 3107 self.assertAllClose(512.0, self.evaluate(r)) 3108 3109 @test_util.run_deprecated_v1 3110 def testNestedWhileCondWhileGrad(self): 3111 self._testNestedWhileCondWhileGrad(use_gpu=False) 3112 3113 @test_util.run_deprecated_v1 3114 def testNestedWhileCondWhileGradGpu(self): 3115 self._testNestedWhileCondWhileGrad(use_gpu=True) 3116 3117 @test_util.run_v1_only("b/120545219") 3118 def testWhileGrad_Variable(self): 3119 with self.cached_session(): 3120 a = variables.Variable(3.0) 3121 v = constant_op.constant(2.0, name="v") 3122 c = lambda v: math_ops.less(v, 100.0) 3123 b = lambda v: math_ops.multiply(v, a) 3124 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 3125 3126 r = gradients_impl.gradients(r, a) 3127 self.evaluate(variables.global_variables_initializer()) 3128 self.assertAllClose(216.0, r[0]) 3129 3130 @test_util.run_deprecated_v1 3131 def testWhileGrad_ResourceVariable(self): 3132 with self.cached_session(): 3133 a = resource_variable_ops.ResourceVariable(3.0) 3134 v = constant_op.constant(2.0, name="v") 3135 c = lambda v: math_ops.less(v, 100.0) 3136 b = lambda v: math_ops.multiply(v, a) 3137 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 3138 3139 g = gradients_impl.gradients(r, a) 3140 self.evaluate(variables.global_variables_initializer()) 3141 self.assertAllClose(216.0, g[0]) 3142 3143 def testWhileGrad_EagerResourceVariable(self): 3144 with context.eager_mode(): 3145 a = resource_variable_ops.ResourceVariable( 3146 np.ones([2, 2], dtype=np.float32)) 3147 v = constant_op.constant(1.0) 3148 3149 @eager_function.defun 3150 def fn(): 3151 r = control_flow_ops.while_loop( 3152 lambda i, _: i < 2, 3153 lambda i, x: (i + 1, x * math_ops.reduce_sum(a) * v), 3154 [0, 1.0])[1] 3155 return gradients_impl.gradients(r, [v])[0] 3156 3157 self.assertEqual(self.evaluate(fn()), 32.) 3158 3159 def testWhileGrad_ResourceVarInFunctionCall(self): 3160 3161 @def_function.function 3162 def foo(x, var): 3163 return x + math_ops.reduce_sum(var.sparse_read([1, 3])) 3164 3165 @def_function.function 3166 def bar(var): 3167 r = control_flow_ops.while_loop( 3168 lambda i, _: i < 2, 3169 lambda i, x: (i + 1, foo(x, var)), 3170 [0, 0.0])[1] 3171 return gradients_impl.gradients(r, var)[0] 3172 3173 var = resource_variable_ops.ResourceVariable([1., 2., 3., 4.]) 3174 self.evaluate(variables.global_variables_initializer()) 3175 grad = self.evaluate(bar(var)) 3176 self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 2., 0., 2.]) 3177 3178 def testWhileGrad_ResourceVarInNestedFunctionCall(self): 3179 3180 @def_function.function 3181 def foo(x, var): 3182 return x + math_ops.reduce_sum(var.sparse_read([1, 3])) 3183 3184 @def_function.function 3185 def foo2(x, var): 3186 return foo(x, var) 3187 3188 @def_function.function 3189 def bar(var): 3190 r = control_flow_ops.while_loop( 3191 lambda i, _: i < 2, 3192 lambda i, x: (i + 1, foo2(x, var)), 3193 [0, 0.0])[1] 3194 return gradients_impl.gradients(r, var)[0] 3195 3196 var = resource_variable_ops.ResourceVariable([1., 1., 1., 1.]) 3197 self.evaluate(variables.global_variables_initializer()) 3198 grad = self.evaluate(bar(var)) 3199 self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 2., 0., 2.]) 3200 3201 def testWhileGrad_ResourceVarInLoopInFunctionCall(self): 3202 if test.is_gpu_available(): 3203 self.skipTest("b/128635252") 3204 3205 @def_function.function 3206 def foo(x, var): 3207 return control_flow_ops.while_loop( 3208 lambda j, _: j < 3, 3209 lambda j, y: (j + 1, 3210 y + math_ops.reduce_sum(var.sparse_read([1, 2]))), 3211 [0, x])[1] 3212 3213 @def_function.function 3214 def bar(var): 3215 r = control_flow_ops.while_loop( 3216 lambda i, _: i < 2, 3217 lambda i, x: (i + 1, foo(x, var)), 3218 [0, 0.0])[1] 3219 return gradients_impl.gradients(r, var)[0] 3220 3221 var = resource_variable_ops.ResourceVariable([1., 1., 1., 1.]) 3222 self.evaluate(variables.global_variables_initializer()) 3223 grad = self.evaluate(bar(var)) 3224 self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 6., 6., 0.]) 3225 3226 def testWhileCondGrad_ResourceVarInFunctionCall(self): 3227 3228 @def_function.function 3229 def foo(x, var): 3230 return x + var.sparse_read([1])[0] 3231 3232 def body(i, x): 3233 return (i + 1, control_flow_ops.cond( 3234 math_ops.equal(i % 2, 0), 3235 lambda: foo(x, var1), 3236 lambda: foo(x, var2))) 3237 3238 @def_function.function 3239 def bar(var1, var2): 3240 r = control_flow_ops.while_loop( 3241 lambda i, _: i < 4, body, [0, 0.0]) 3242 return gradients_impl.gradients(r, [var1, var2]) 3243 3244 var1 = resource_variable_ops.ResourceVariable([1., 2., 3.]) 3245 var2 = resource_variable_ops.ResourceVariable([4., 5.]) 3246 self.evaluate(variables.global_variables_initializer()) 3247 grads = self.evaluate(bar(var1, var2)) 3248 self.assertAllEqual(gradient_checker_v2._to_numpy(grads[0]), [0., 2., 0.]) 3249 self.assertAllEqual(gradient_checker_v2._to_numpy(grads[1]), [0., 2.]) 3250 3251 @test_util.run_deprecated_v1 3252 def testWhileGrad_ResourceVarSparseRead(self): 3253 # NOTE(skyewm): this test is interesting because the gradient is the 3254 # aggregation result of IndexedSlices and Tensors. 3255 var = resource_variable_ops.ResourceVariable(np.ones(5), 3256 dtype=dtypes.float32) 3257 r = control_flow_ops.while_loop( 3258 lambda i, _: i < 3, 3259 lambda i, x: (i + 1, x * math_ops.reduce_sum(var.sparse_read([1, 3]))), 3260 [0, constant_op.constant(1.0)])[1] 3261 grad = gradients_impl.gradients(r, var)[0] 3262 3263 self.evaluate(variables.global_variables_initializer()) 3264 grad_val = self.evaluate(grad) 3265 arr = gradient_checker_v2._to_numpy(grad_val) 3266 self.assertAllEqual(arr, [0., 12., 0., 12., 0.]) 3267 3268 @test_util.run_deprecated_v1 3269 def testWhileGrad_MultiResourceVarSparseRead(self): 3270 # NOTE(skyewm): this test is interesting because the gradient is the 3271 # aggregation result of IndexedSlices and Tensors. 3272 var1 = resource_variable_ops.ResourceVariable(np.ones(5), 3273 dtype=dtypes.float32) 3274 var2 = resource_variable_ops.ResourceVariable(np.ones(3), 3275 dtype=dtypes.float32) 3276 x1_init = constant_op.constant([0., 0.]) 3277 x2_init = constant_op.constant(1.) 3278 x3_init = constant_op.constant(1.) 3279 3280 def body(i, unused_x1, x2, x3): 3281 y1 = var1.sparse_read([1, 3]) 3282 y2 = x2 * 2 3283 y3 = x3 * math_ops.reduce_sum(var2.sparse_read([0])) 3284 return i + 1, y1, y2, y3 3285 3286 r = control_flow_ops.while_loop( 3287 lambda i, x1, x2, x3: i < 3, body, 3288 [0, x1_init, x2_init, x3_init])[1:] 3289 var1_grad, var2_grad = gradients_impl.gradients(r, [var1, var2]) 3290 3291 self.evaluate(variables.global_variables_initializer()) 3292 var1_grad_val = self.evaluate(var1_grad) 3293 var2_grad_val = self.evaluate(var2_grad) 3294 self.assertAllEqual(gradient_checker_v2._to_numpy(var1_grad_val), 3295 [0., 1., 0., 1., 0.]) 3296 self.assertAllEqual(gradient_checker_v2._to_numpy(var2_grad_val), 3297 [3., 0., 0.]) 3298 3299 def testWhileGrad_Gather(self): 3300 # NOTE(skyewm): this test is interesting because the gather gradient 3301 # function returns an IndexedSlices. 3302 @tf_function_in_tf2 3303 def fn(): 3304 x = constant_op.constant([1., 1., 1., 1., 1.]) 3305 y = control_flow_ops.while_loop( 3306 lambda i, _: i < 3, 3307 lambda i, x: (i + 1, x + array_ops.gather(x, [0])), 3308 [0, x[:1]])[1] 3309 z = y * 3.0 3310 grad = gradients_impl.gradients(z, x)[0] 3311 return y, grad 3312 y, grad = fn() 3313 self.assertEqual(self.evaluate(y), 8.) 3314 self.assertAllEqual(self.evaluate(grad), [24., 0., 0., 0., 0.]) 3315 3316 def testWhileGrad_GatherNoFanOut(self): 3317 # NOTE(skyewm): this test is interesting because the gather gradient 3318 # function returns an IndexedSlices. 3319 @tf_function_in_tf2 3320 def fn(): 3321 x = constant_op.constant([1., 1., 1., 1., 1.]) 3322 y = control_flow_ops.while_loop( 3323 lambda i, _: i < 3, 3324 lambda i, x: (i + 1, array_ops.gather(x, [0])), 3325 [0, x[:1]])[1] 3326 z = y * 3.0 3327 grad = gradients_impl.gradients(z, x)[0] 3328 return y, grad 3329 y, grad = fn() 3330 self.assertEqual(self.evaluate(y), 1.) 3331 self.assertAllEqual(self.evaluate(grad), [3., 0., 0., 0., 0.]) 3332 3333 @test_util.run_v1_only("b/120545219") 3334 def testWhileGradInCond(self): 3335 3336 with self.cached_session(): 3337 n = ops.convert_to_tensor(1.0, name="n") 3338 x = array_ops.placeholder(dtypes.float32, shape=None) 3339 c = lambda n: math_ops.less(n, 10.0) 3340 b = lambda n: math_ops.add(n, x) 3341 3342 def fn1(): 3343 r = control_flow_ops.while_loop(c, b, [n], 3344 [tensor_shape.unknown_shape()]) 3345 return gradients_impl.gradients(r, x)[0] 3346 3347 r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x) 3348 self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0})) 3349 3350 @test_util.disable_control_flow_v2("b/116340060") 3351 @test_util.run_v1_only("b/120545219") 3352 def testGradInWhileWrtInitialLoopVal(self): 3353 with self.cached_session(): 3354 x = array_ops.placeholder(dtypes.float32, shape=(), name="x") 3355 y = x + 1 3356 3357 def body(i, v): 3358 z = v * 2 3359 return i + 1, gradients_impl.gradients(z, x)[0] 3360 3361 with self.assertRaisesRegex( 3362 ValueError, 3363 "Cannot compute gradient inside while loop with respect to op 'x'. " 3364 "We do not support taking the gradient wrt or through the initial " 3365 "value of a loop variable. Gradients can be computed through " 3366 "loop invariants or wrt the input parameters to the loop body."): 3367 control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y]) 3368 3369 @test_util.run_v1_only("b/120545219") 3370 def testWhileGradInWhile(self): 3371 with self.cached_session(): 3372 n = ops.convert_to_tensor(1.0, name="n") 3373 x = array_ops.placeholder(dtypes.float32, shape=None) 3374 c = lambda n: math_ops.less(n, 10.0) 3375 b = lambda n: math_ops.add(n, x) 3376 3377 def b1(n): 3378 r = control_flow_ops.while_loop(c, b, [n], 3379 [tensor_shape.unknown_shape()]) 3380 return gradients_impl.gradients(r, x) 3381 3382 r = control_flow_ops.while_loop(lambda n: n < 6.0, b1, [n], 3383 [tensor_shape.unknown_shape()]) 3384 self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0})) 3385 3386 @test_util.run_v1_only("b/120545219") 3387 def testCondGradInNestedWhiles(self): 3388 3389 def outer_body(i, x): 3390 _, x = control_flow_ops.while_loop( 3391 lambda j, x: j < 3, inner_body, [0, 0.0]) 3392 return i + 1, x 3393 3394 def inner_body(j, x): 3395 y = control_flow_ops.cond(math_ops.less(x, 1), lambda: 2 * x, lambda: x) 3396 return j + 1, gradients_impl.gradients(y, x)[0] 3397 3398 i, x = control_flow_ops.while_loop(lambda i, x: i < 3, outer_body, [0, 0.0]) 3399 3400 with self.cached_session() as sess: 3401 i_val, x_val = self.evaluate([i, x]) 3402 self.assertEqual(i_val, 3) 3403 self.assertAllClose(x_val, 1.0) 3404 3405 @test_util.run_gpu_only 3406 def testGpuResourceAccess(self): 3407 with ops.device(test.gpu_device_name()): 3408 var = resource_variable_ops.ResourceVariable(constant_op.constant(3.0)) 3409 3410 @def_function.function 3411 def foo(): 3412 return control_flow_ops.while_loop( 3413 lambda i, _: i < 3, 3414 lambda i, x: (i + 1, control_flow_ops.cond( 3415 constant_op.constant(True), 3416 lambda: x + var, 3417 lambda: x)), 3418 [0, 0.0])[1] 3419 3420 self.evaluate(variables.global_variables_initializer()) 3421 self.assertEqual(self.evaluate(foo()), 9.0) 3422 3423 def testNestedResourceAccess(self): 3424 var = resource_variable_ops.ResourceVariable(constant_op.constant(3.0)) 3425 3426 @eager_function.defun 3427 def test_fn(): 3428 x = constant_op.constant(0.0) 3429 r = control_flow_ops.while_loop( 3430 # Outer loop condition 3431 lambda i, y: i < 2, 3432 # Outer loop body 3433 lambda i, y: (i + 1, y + control_flow_ops.cond( 3434 constant_op.constant(True), 3435 # True branch 3436 lambda: control_flow_ops.while_loop( 3437 # Inner loop condition 3438 lambda j, z: j < 3, 3439 # Inner loop body 3440 lambda j, z: (j + 1, z + math_ops.square(var)), 3441 # Inner initial loop value 3442 [0, y])[1], 3443 # False branch 3444 lambda: (0.0))), 3445 # Outer initial loop value 3446 [0, x])[1] 3447 3448 grad = gradients_impl.gradients(r, x)[0] 3449 return r, grad 3450 3451 self.evaluate(variables.global_variables_initializer()) 3452 r, grad = self.evaluate(test_fn()) 3453 # 2 * 3 * 3^2 3454 self.assertEqual(r, 81.0) 3455 # v1 control flow gets the wrong answer!!! 3456 # Gradient computation: 3457 # f(x) = x + 3^2 3458 # inner_loop(x) = f(f(f(x))) = x + 3*3^2 = x + 27 3459 # g(x) = x + inner_loop(x) = 2x + 27 3460 # outer_loop(x) = g(g(x)) = 4x + 81 3461 # outer_loop'(x) = 4 3462 # Note that v1 control flow gets 4.0 as well if the cond is removed. 3463 if control_flow_util.ENABLE_CONTROL_FLOW_V2: 3464 self.assertEqual(grad, 4.0) 3465 3466 def testWhile_NestedInput(self): 3467 with self.cached_session() as sess: 3468 named = collections.namedtuple("named", ("a", "b")) 3469 loop_vars = [ 3470 named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)), 3471 (constant_op.constant(2.0), constant_op.constant(3.0)), 3472 constant_op.constant(4.0) 3473 ] 3474 c = lambda lv0, _1, _2: lv0.a < 100.0 3475 3476 def b(lv0, lv1, lv2): 3477 lv0 = named(a=lv0.a + 1, b=lv0.b) 3478 lv1 = (lv1[0] + 1, lv1[1]) 3479 lv2 += 2 3480 return [lv0, lv1, lv2] 3481 3482 r = control_flow_ops.while_loop(c, b, loop_vars) 3483 3484 self.assertTrue(isinstance(r, list)) 3485 self.assertTrue(isinstance(r[0], named)) 3486 self.assertTrue(isinstance(r[1], tuple)) 3487 self.assertTrue(isinstance(r[2], ops.Tensor)) 3488 3489 r_flattened = nest.flatten(r) 3490 self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0], 3491 self.evaluate(r_flattened)) 3492 3493 @test_util.run_v1_only("b/120545219") 3494 def testWhile_NestedBadArityFails(self): 3495 with self.cached_session(): 3496 named = collections.namedtuple("named", ("a", "b")) 3497 loop_vars = [ 3498 named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)), 3499 (constant_op.constant(2.0), constant_op.constant(3.0)), 3500 constant_op.constant(4.0) 3501 ] 3502 c = lambda lv0, _1, _2: lv0.a < 100.0 3503 3504 def b(lv0, lv1, _): 3505 return [lv0, lv1] 3506 3507 with self.assertRaisesRegex(ValueError, "the same number of elements"): 3508 control_flow_ops.while_loop(c, b, loop_vars) 3509 3510 @test_util.run_v1_only("b/120545219") 3511 def testWhileGrad_ys_xs(self): 3512 with self.cached_session(): 3513 x = constant_op.constant(3.0, name="x") 3514 y = constant_op.constant(2.0, name="y") 3515 3516 c = lambda x, y: math_ops.less(x, 100.0) 3517 3518 def b(x, y): 3519 y1 = math_ops.add(x, y) 3520 x1 = math_ops.multiply(x, y1) 3521 return x1, y1 3522 3523 rx, ry = control_flow_ops.while_loop(c, b, [x, y], parallel_iterations=1) 3524 3525 r = gradients_impl.gradients([rx, ry], x) 3526 self.assertAllClose(304.0, r[0]) 3527 r = gradients_impl.gradients([rx, ry], y) 3528 self.assertAllClose(124.0, r[0]) 3529 r = gradients_impl.gradients([rx], x) 3530 self.assertAllClose(295.0, r[0]) 3531 r = gradients_impl.gradients([rx], y) 3532 self.assertAllClose(120.0, r[0]) 3533 3534 @test_util.run_deprecated_v1 3535 def testWhileGrad_Dependency(self): 3536 with self.cached_session(): 3537 i = constant_op.constant(0, name="i") 3538 x = constant_op.constant(2.0, name="x") 3539 3540 c = lambda i, x: math_ops.less(i, 10) 3541 3542 def b(i, x): 3543 x = math_ops.multiply(x, 2.0) 3544 i = math_ops.add(i, 1) 3545 return i, x 3546 3547 ri, rx = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1) 3548 3549 r = gradients_impl.gradients([ri, rx], x) 3550 self.assertAllClose(1024.0, r[0]) 3551 r = gradients_impl.gradients([rx], x) 3552 self.assertAllClose(1024.0, r[0]) 3553 3554 @test_util.run_v1_only("b/120545219") 3555 def testWhileGrad_NoGradient(self): 3556 with self.cached_session(): 3557 v = constant_op.constant(2.0, name="v") 3558 c = lambda v: math_ops.less(v, 100.0) 3559 b = math_ops.square 3560 r = control_flow_ops.while_loop(c, b, [v], back_prop=False) 3561 r = math_ops.add(r, v) 3562 r = gradients_impl.gradients(r, v) 3563 self.assertAllClose(1.0, r[0]) 3564 3565 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 3566 @test_util.run_v1_only("b/120545219") 3567 def testWhileGrad_NoDependency(self): 3568 with self.cached_session() as sess: 3569 variable = variables.Variable(array_ops.ones([2, 3])) 3570 duration = array_ops.zeros([], dtype=dtypes.int32) 3571 3572 def cond(duration, tensor, _): 3573 del tensor 3574 return duration < 10 3575 3576 def body(duration, tensor, _): 3577 return (duration + 1, tensor, tensor) 3578 3579 loop_vars = [duration, variable, variable] 3580 tensors = control_flow_ops.while_loop( 3581 cond=cond, body=body, loop_vars=loop_vars) 3582 cost = math_ops.reduce_sum(tensors[2]) 3583 grad = gradients_impl.gradients(cost, [variable]) 3584 self.evaluate(variables.global_variables_initializer()) 3585 self.assertAllClose(np.ones([2, 3]), sess.run(grad[0])) 3586 3587 @test_util.run_deprecated_v1 3588 def testWhileGrad_Const(self): 3589 with self.cached_session() as sess: 3590 c0 = constant_op.constant(0.0, name="c0") 3591 c1 = constant_op.constant(1.0, name="c1") 3592 duration = constant_op.constant(0, name="t") 3593 3594 def cond(duration, _): 3595 return duration < 1 3596 3597 def body(duration, _): 3598 return duration + 1, c1 3599 3600 loop_vars = [duration, c0] 3601 tensors = control_flow_ops.while_loop( 3602 cond=cond, body=body, loop_vars=loop_vars) 3603 cost = math_ops.reduce_sum(tensors[1]) 3604 grad = gradients_impl.gradients(cost, [c0]) 3605 self.assertAllClose(0.0, sess.run(grad[0])) 3606 3607 @test_util.run_v1_only("b/120545219") 3608 def testWhileGrad_SerialTwoLoops(self): 3609 with self.cached_session(): 3610 i = constant_op.constant(0, name="i") 3611 x = constant_op.constant(2.0, name="x") 3612 3613 c = lambda i, x: math_ops.less(i, 5) 3614 3615 def b(i, x): 3616 x = math_ops.multiply(x, 2.0) 3617 i = math_ops.add(i, 1) 3618 return i, x 3619 3620 _, rx = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1) 3621 _, rx = control_flow_ops.while_loop(c, b, [i, rx], parallel_iterations=1) 3622 3623 r = gradients_impl.gradients([rx], x) 3624 self.assertAllClose(1024.0, r[0]) 3625 3626 @test_util.run_v1_only("b/120545219") 3627 def testWhileGrad_ParallelTwoLoops(self): 3628 with self.cached_session(): 3629 i = constant_op.constant(0, name="i") 3630 x = constant_op.constant(2.0, name="x") 3631 3632 c = lambda i, x: math_ops.less(i, 5) 3633 3634 def b(i, x): 3635 x = math_ops.multiply(x, 2.0) 3636 i = math_ops.add(i, 1) 3637 return i, x 3638 3639 _, r1 = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1) 3640 _, r2 = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1) 3641 rx = math_ops.add(r1, r2) 3642 3643 r = gradients_impl.gradients([rx], x) 3644 self.assertAllClose(64.0, r[0]) 3645 3646 @test_util.run_v1_only("b/120545219") 3647 def testWhileGrad_OneOutputWithControlDependencyOnSecond(self): 3648 with self.cached_session(): 3649 i = constant_op.constant(0, name="i") 3650 x = constant_op.constant(1.0, name="x") 3651 y = constant_op.constant(1.0, name="y") 3652 c = lambda i, *_: math_ops.less(i, 1, name="cond_less") 3653 3654 def b(i, xi, yi): 3655 # return (i + 1, xi, xi + yi) 3656 return (math_ops.add(i, 1, name="inc"), array_ops.identity( 3657 xi, name="xi"), math_ops.add(xi, yi, name="xi_plus_yi")) 3658 3659 _, x_f, y_f = control_flow_ops.while_loop(c, b, [i, x, y]) 3660 with ops.control_dependencies([x_f]): 3661 y_f_d = array_ops.identity(y_f, name="y_f_d") 3662 3663 self.assertAllClose(2.0, self.evaluate(y_f_d)) # y_f_d = 1.0 + 1.0 3664 g = gradients_impl.gradients([y_f_d], [x])[0] 3665 self.assertTrue(g is not None) 3666 self.assertAllClose(1.0, 3667 self.evaluate(g)) # y_f_d = x + 1.0, dy_f_d/dx = 1.0 3668 3669 def _testNestedWhileGrad_Simple(self, use_gpu): 3670 with self.cached_session(use_gpu=use_gpu): 3671 v = constant_op.constant(1.0) 3672 3673 def inner_loop(s): 3674 c = lambda x: math_ops.less(x, 4.0) 3675 b = lambda x: math_ops.multiply(x, 2.0) 3676 return control_flow_ops.while_loop(c, b, [s]) 3677 3678 c = lambda x: math_ops.less(x, 2.0) 3679 b = lambda x: math_ops.multiply(inner_loop(x), 2.0) 3680 r = control_flow_ops.while_loop(c, b, [v]) 3681 3682 r = gradients_impl.gradients(r, v)[0] 3683 self.assertAllClose(8.0, self.evaluate(r)) 3684 3685 @test_util.run_deprecated_v1 3686 def testNestedWhileGrad_Simple(self): 3687 self._testNestedWhileGrad_Simple(use_gpu=False) 3688 self._testNestedWhileGrad_Simple(use_gpu=True) 3689 3690 @test_util.run_v1_only("b/120545219") 3691 def testNestedWhileGrad_SerialInner(self): 3692 with self.cached_session(): 3693 v = constant_op.constant(1.0) 3694 3695 def inner_loop1(s): 3696 z = constant_op.constant(0) 3697 c = lambda i, x: math_ops.less(i, 4) 3698 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 3699 return control_flow_ops.while_loop(c, b, [z, s]) 3700 3701 def inner_loop2(s): 3702 z = constant_op.constant(0) 3703 c = lambda i, x: math_ops.less(i, 4) 3704 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 3705 return control_flow_ops.while_loop(c, b, [z, s]) 3706 3707 c = lambda x: math_ops.less(x, 128.0) 3708 b = lambda x: inner_loop2(inner_loop1(x)[1])[1] 3709 r = control_flow_ops.while_loop(c, b, [v]) 3710 3711 r = gradients_impl.gradients(r, v)[0] 3712 self.assertAllClose(256.0, self.evaluate(r)) 3713 3714 @test_util.run_deprecated_v1 3715 def testNestedWhileGrad_ParallelInner(self): 3716 with self.cached_session(): 3717 v = constant_op.constant(1.0) 3718 3719 def inner_loop1(s): 3720 z = constant_op.constant(0) 3721 c = lambda i, x: math_ops.less(i, 4) 3722 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 3723 return control_flow_ops.while_loop(c, b, [z, s]) 3724 3725 def inner_loop2(s): 3726 z = constant_op.constant(0) 3727 c = lambda i, x: math_ops.less(i, 4) 3728 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 3729 return control_flow_ops.while_loop(c, b, [z, s]) 3730 3731 c = lambda x: math_ops.less(x, 128.0) 3732 b = lambda x: math_ops.multiply(inner_loop1(x)[1], inner_loop2(x)[1]) 3733 r = control_flow_ops.while_loop(c, b, [v]) 3734 3735 r = gradients_impl.gradients(r, v)[0] 3736 self.assertAllClose(512.0, self.evaluate(r)) 3737 3738 @test_util.run_v1_only("b/120545219") 3739 def testNestedWhileGrad_ParallelIterations(self): 3740 # Make sure the stack pushes and pops of an inner loop are executed in 3741 # the sequential order of the iterations of its outer loop. 3742 with self.cached_session() as sess: 3743 3744 def inner_loop(t): 3745 fn = lambda n: n + math_ops.square(var) 3746 return map_fn.map_fn(fn=fn, elems=t, parallel_iterations=10) 3747 3748 def outer_loop(inp): 3749 return map_fn.map_fn( 3750 fn=inner_loop, elems=inp, parallel_iterations=10) 3751 3752 var = variables.Variable(constant_op.constant(3.0)) 3753 inp = constant_op.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) 3754 res = outer_loop(inp) 3755 optimizer = adam.AdamOptimizer(learning_rate=0.001) 3756 train_op = optimizer.minimize(math_ops.reduce_mean(math_ops.square(res))) 3757 self.evaluate(variables.global_variables_initializer()) 3758 self.evaluate(train_op) 3759 self.assertAllClose(2.999, var.read_value()) 3760 3761 def _testWhileCondGrad_Simple(self, use_gpu): 3762 with self.cached_session(use_gpu=use_gpu): 3763 v = ops.convert_to_tensor(2.0, name="v") 3764 n = ops.convert_to_tensor(100.0, name="n") 3765 one = ops.convert_to_tensor(1.0, name="one") 3766 c = lambda x: math_ops.less(x, n) 3767 # pylint: disable=undefined-variable 3768 # for OSS build 3769 b = lambda x: control_flow_ops.cond(constant_op.constant(True), 3770 lambda: math_ops.square(x), 3771 lambda: math_ops.subtract(x, one)) 3772 # pylint: enable=undefined-variable 3773 r = control_flow_ops.while_loop(c, b, [v]) 3774 r = gradients_impl.gradients(r, v)[0] 3775 self.assertAllClose(1024.0, self.evaluate(r)) 3776 3777 @test_util.run_deprecated_v1 3778 def testWhileCondGrad_Simple(self): 3779 self._testWhileCondGrad_Simple(use_gpu=False) 3780 self._testWhileCondGrad_Simple(use_gpu=True) 3781 3782 @test_util.run_deprecated_v1 3783 def testWhileCondGrad_UnknownShape(self): 3784 with self.cached_session() as sess: 3785 v = array_ops.placeholder(dtypes.float32) 3786 n = ops.convert_to_tensor(100.0, name="n") 3787 one = ops.convert_to_tensor(1.0, name="one") 3788 c = lambda x: math_ops.less(x, n) 3789 # pylint: disable=undefined-variable 3790 # for OSS build 3791 b = lambda x: control_flow_ops.cond(constant_op.constant(True), 3792 lambda: math_ops.square(x), 3793 lambda: math_ops.subtract(x, one)) 3794 # pylint: enable=undefined-variable 3795 r = control_flow_ops.while_loop(c, b, [v]) 3796 r = gradients_impl.gradients(r, v)[0] 3797 r = sess.run(r, feed_dict={v: 2.0}) 3798 self.assertAllClose(1024.0, r) 3799 3800 @test_util.run_deprecated_v1 3801 def testWhileGrad_Concat(self): 3802 with self.cached_session() as sess: 3803 x = variable_scope.get_variable("x", initializer=[[1., 2.]]) 3804 i0 = constant_op.constant(0) 3805 h0 = array_ops.zeros([0, 2]) 3806 3807 def condition(i, _): 3808 return i < 2 3809 3810 def body(i, h): 3811 return i + 1, array_ops.concat([h, x], 0) 3812 3813 _, h = control_flow_ops.while_loop( 3814 condition, body, [i0, h0], 3815 [i0.get_shape(), tensor_shape.TensorShape([None, 2])]) 3816 s = math_ops.reduce_sum(h) 3817 3818 optimizer = gradient_descent.GradientDescentOptimizer(0.01) 3819 op = optimizer.minimize(s) 3820 3821 self.evaluate(variables.global_variables_initializer()) 3822 self.evaluate(op) 3823 self.assertAllClose([[0.98000002, 1.98000002]], self.evaluate(x)) 3824 3825 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 3826 @test_util.run_v1_only("b/120545219") 3827 def testWhileWithRefsWithGradients_1(self): 3828 with self.cached_session() as sess: 3829 x = variables.VariableV1(0.)._ref() # pylint: disable=protected-access 3830 i = constant_op.constant(0) 3831 c = lambda i, x: math_ops.less(i, 10) 3832 3833 self.assertEqual(x.dtype, dtypes.float32_ref) 3834 3835 def body(i, x): 3836 self.assertEqual(x.dtype, dtypes.float32_ref) 3837 return [i + 1, gen_array_ops.ref_identity(x)] 3838 3839 r = control_flow_ops.while_loop(c, body, [i, x], parallel_iterations=5) 3840 3841 grad_ys = [variables.VariableV1(73)._ref()] # pylint: disable=protected-access 3842 grad = gradients_impl.gradients([r[1]], [x], grad_ys=grad_ys) 3843 3844 self.evaluate(variables.global_variables_initializer()) 3845 3846 self.assertEqual(r[0].dtype, dtypes.int32) 3847 self.assertEqual(r[1].dtype, dtypes.float32_ref) 3848 3849 value_i, value_x, value_x_grad = sess.run(r + grad) 3850 3851 self.assertEqual(10, value_i) 3852 self.assertEqual(0, value_x) 3853 self.assertEqual(73, value_x_grad) 3854 3855 @test_util.deprecated_graph_mode_only 3856 def testWhileGrad_IndexedSlices(self): 3857 with self.cached_session(): 3858 values = constant_op.constant([2.0, 4.0], name="values") 3859 indices = constant_op.constant([0, 3], name="indices") 3860 shape = constant_op.constant([10], name="dense_shape") 3861 i = constant_op.constant(0) 3862 x = indexed_slices.IndexedSlices(values, indices, dense_shape=shape) 3863 3864 def c(i, _): 3865 return i < 10 3866 3867 def b(i, x): 3868 return [ 3869 i + 1, 3870 indexed_slices.IndexedSlices(x.values * 2.0, x.indices, 3871 x.dense_shape) 3872 ] 3873 3874 _, r = control_flow_ops.while_loop(c, b, [i, x]) 3875 r = gradients_impl.gradients(r.values, values)[0] 3876 self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r)) 3877 3878 @test_util.deprecated_graph_mode_only 3879 def testWhileGrad_SparseTensor(self): 3880 with self.cached_session(): 3881 values = constant_op.constant([2.0, 4.0], name="values") 3882 indices = constant_op.constant( 3883 [[0], [3]], dtype=dtypes.int64, name="indices") 3884 shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape") 3885 i = constant_op.constant(0) 3886 x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape) 3887 3888 def c(i, _): 3889 return i < 10 3890 3891 def b(i, x): 3892 return [ 3893 i + 1, 3894 sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape) 3895 ] 3896 3897 _, r = control_flow_ops.while_loop(c, b, [i, x]) 3898 r = gradients_impl.gradients(r.values, values)[0] 3899 self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r)) 3900 3901 @test_util.deprecated_graph_mode_only 3902 def testCallGradInLoop(self): 3903 with self.cached_session() as sess: 3904 i0 = constant_op.constant(0) 3905 params = constant_op.constant(5.0) 3906 params_1 = math_ops.square(params) 3907 3908 def c(i, _): 3909 return i < 10 3910 3911 def b(i, x): 3912 data = constant_op.constant([1.0, 2.0, 3.0]) 3913 data = math_ops.multiply(data, params_1) 3914 x1 = x + gradients_impl.gradients(data, params)[0] 3915 return i + 1, x1 3916 3917 output_grad = control_flow_ops.while_loop( 3918 c, b, [i0, constant_op.constant(0.0)]) 3919 self.assertAllClose(600.0, self.evaluate(output_grad)[1]) 3920 3921 @test_util.run_deprecated_v1 3922 def testWhileAndTensorArray(self): 3923 with self.cached_session() as sess: 3924 param = constant_op.constant(2.0) 3925 n0 = constant_op.constant(0) 3926 y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems") 3927 3928 def c(i, _): 3929 return i < 10 3930 3931 def b(i, y): 3932 return [ 3933 i + 1, 3934 map_fn.map_fn(lambda x: math_ops.multiply(x, param), y) 3935 ] 3936 3937 r = control_flow_ops.while_loop(c, b, [n0, y0], parallel_iterations=1) 3938 r = gradients_impl.gradients(r, param)[0] 3939 self.assertAllClose(107520.0, self.evaluate(r)) 3940 3941 @test_util.run_deprecated_v1 3942 def testNestedWhileAndTensorArray(self): 3943 n = constant_op.constant(3.0) 3944 3945 def Body(row, ta): 3946 3947 def InnerBody(row, col, ta): 3948 # Note: row and col are 1-based. 3949 ta = ta.write( 3950 math_ops.cast(n * (row - 1.) + col - 1., dtypes.int32), row * col) 3951 return row, col + 1., ta 3952 3953 ta = control_flow_ops.while_loop( 3954 lambda _, col, _1: col <= n, 3955 InnerBody, [row, constant_op.constant(1.), ta], 3956 return_same_structure=False)[2] 3957 return row + 1., ta 3958 3959 ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=9) 3960 ta = control_flow_ops.while_loop( 3961 lambda row, _: row <= n, 3962 Body, [constant_op.constant(1.), ta], 3963 return_same_structure=False)[1] 3964 3965 output = array_ops.reshape(ta.stack(), [3, 3]) 3966 self.assertAllEqual( 3967 self.evaluate(output), [[1., 2., 3.], [2., 4., 6.], [3., 6., 9.]]) 3968 # TODO(b/117675481): This does not work with current TA. Enable with new TA. 3969 # grad = gradients_impl.gradients(output, [n]) 3970 # self.assertEqual(self.evaluate(grad), 3.5) 3971 3972 @test_util.run_deprecated_v1 3973 def testWhileGrad_StopGrad(self): 3974 with self.cached_session(): 3975 x = constant_op.constant(3.0, name="x") 3976 y = constant_op.constant(2.0, name="y") 3977 3978 c = lambda x, y: math_ops.less(x, 100.0) 3979 3980 def b(x, y): 3981 y1 = math_ops.square(y) 3982 x1 = math_ops.add(math_ops.square(x), y1) 3983 return x1, y1 3984 3985 rx, ry = control_flow_ops.while_loop(c, b, [x, y]) 3986 3987 r = gradients_impl.gradients(rx, y)[0] 3988 self.assertEqual(136.0, self.evaluate(r)) 3989 r = gradients_impl.gradients(ry, y)[0] 3990 self.assertEqual(32.0, self.evaluate(r)) 3991 3992 r = gradients_impl.gradients(array_ops.stop_gradient(rx), y)[0] 3993 self.assertEqual(r, None) 3994 r = gradients_impl.gradients(array_ops.stop_gradient(ry), y)[0] 3995 self.assertEqual(r, None) 3996 3997 r = gradients_impl.gradients( 3998 array_ops.stop_gradient(math_ops.square(rx)), y)[0] 3999 self.assertEqual(r, None) 4000 r = gradients_impl.gradients( 4001 array_ops.stop_gradient(math_ops.add(rx, ry)), x)[0] 4002 self.assertEqual(r, None) 4003 r = gradients_impl.gradients( 4004 array_ops.stop_gradient(math_ops.add(rx, ry)), y)[0] 4005 self.assertEqual(r, None) 4006 4007 r = gradients_impl.gradients(math_ops.add(rx, ry), y)[0] 4008 self.assertEqual(168.0, self.evaluate(r)) 4009 r = gradients_impl.gradients( 4010 math_ops.add(rx, array_ops.stop_gradient(ry)), y)[0] 4011 self.assertEqual(136.0, self.evaluate(r)) 4012 r = gradients_impl.gradients( 4013 math_ops.add(array_ops.stop_gradient(rx), ry), y)[0] 4014 self.assertEqual(32.0, self.evaluate(r)) 4015 4016 @test_util.run_deprecated_v1 4017 def testWhileGrad_StopGradInside(self): 4018 with self.cached_session(): 4019 x = constant_op.constant(3.0, name="x") 4020 y = constant_op.constant(2.0, name="y") 4021 4022 c = lambda x, y: math_ops.less(x, 100.0) 4023 4024 def b(x, y): 4025 y1 = array_ops.stop_gradient(math_ops.square(y)) 4026 x1 = math_ops.add(math_ops.square(x), y1) 4027 return x1, y1 4028 4029 rx, _ = control_flow_ops.while_loop(c, b, [x, y]) 4030 4031 r = gradients_impl.gradients(rx, y)[0] 4032 self.assertAllClose(0.0, self.evaluate(r)) 4033 r = gradients_impl.gradients(rx, x)[0] 4034 self.assertAllClose(156.0, self.evaluate(r)) 4035 4036 @test_util.run_deprecated_v1 4037 def testWhileGrad_StopGradInsideNoShape(self): 4038 with self.cached_session() as sess: 4039 x = array_ops.placeholder(dtypes.float32) 4040 y = array_ops.placeholder(dtypes.float32) 4041 4042 c = lambda x, y: math_ops.less(math_ops.reduce_sum(x), 100.0) 4043 4044 def b(x, y): 4045 y1 = array_ops.stop_gradient(math_ops.square(y, name="stopped")) 4046 x1 = math_ops.add(math_ops.square(x), y1) 4047 return x1, y1 4048 4049 rx, _ = control_flow_ops.while_loop(c, b, [x, y]) 4050 4051 grad_y = gradients_impl.gradients(rx, y)[0] 4052 grad_x = gradients_impl.gradients(rx, x)[0] 4053 feed_dict = {x: [3.0, 4.0], y: [2.0, 3.0]} 4054 self.assertAllClose([0.0, 0.0], sess.run(grad_y, feed_dict=feed_dict)) 4055 self.assertAllClose([156.0, 400.0], sess.run(grad_x, feed_dict=feed_dict)) 4056 name = "gradients/while/stopped_grad" 4057 all_ops = x.graph.get_operations() 4058 self.assertFalse(any(name in op.name for op in all_ops)) 4059 4060 @test_util.run_deprecated_v1 4061 def testWhileGradGradFail(self): 4062 theta = variables.Variable(initial_value=1.) 4063 4064 def fn(prev, x): 4065 return prev + x * theta 4066 4067 result = functional_ops.scan(fn, np.array([1., 2., 3.], dtype=np.float32)) 4068 grad_theta = gradients_impl.gradients(result, theta) 4069 if not control_flow_util.ENABLE_CONTROL_FLOW_V2: 4070 with self.assertRaisesRegex(TypeError, "Second-order gradient"): 4071 gradients_impl.gradients(grad_theta, theta) 4072 grad_theta_stopped = array_ops.stop_gradient(grad_theta) 4073 gradients_impl.gradients(grad_theta_stopped, theta) 4074 4075 @test_util.run_deprecated_v1 4076 def testStopGradOnWhileGrad(self): 4077 with self.cached_session(): 4078 x = constant_op.constant(2.0, name="x") 4079 y = constant_op.constant(2.0, name="y") 4080 4081 c = lambda x: math_ops.less(x, 100.0) 4082 b = lambda x: math_ops.multiply(x, y) 4083 rx = control_flow_ops.while_loop(c, b, [x]) 4084 4085 rg = gradients_impl.gradients(rx, y)[0] 4086 rg = array_ops.stop_gradient(rg) 4087 r = math_ops.add(math_ops.square(y), rx) 4088 r = math_ops.add(r, rg) 4089 r = gradients_impl.gradients(r, y)[0] 4090 self.assertEqual(388.0, self.evaluate(r)) 4091 4092 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 4093 @test_util.run_deprecated_v1 4094 def testWhileGradientWithNontrainablePath1(self): 4095 q = variables.Variable([7., 8.]) 4096 4097 def cond(_, y): 4098 del y 4099 return False 4100 4101 def body(x, _): 4102 return x, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q) 4103 4104 _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.)) 4105 dy_dq, = gradients_impl.gradients(y, q) 4106 self.assertIsNotNone(dy_dq) 4107 with self.cached_session() as sess: 4108 self.evaluate(q.initializer) 4109 self.assertAllClose([0., 0.], self.evaluate(dy_dq)) 4110 4111 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 4112 @test_util.run_v1_only("b/120545219") 4113 def testWhileGradientWithNontrainablePath2(self): 4114 q = variables.Variable([7., 8.]) 4115 4116 def cond(_, y): 4117 return math_ops.equal(y, 0.) 4118 4119 def body(x, _): 4120 zero = constant_op.constant(0, dtype=dtypes.int64) 4121 return zero, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q) 4122 4123 _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.)) 4124 dy_dq, = gradients_impl.gradients(y, q) 4125 self.assertIsNotNone(dy_dq) 4126 with self.cached_session() as sess: 4127 self.evaluate(q.initializer) 4128 self.assertAllClose([1., 1.], self.evaluate(dy_dq)) 4129 4130 @test_util.run_v1_only("b/120545219") 4131 def testIssue16504(self): 4132 c = constant_op.constant(np.arange(100), dtype=dtypes.float32) 4133 w = variables.Variable( 4134 initial_value=np.ones(100), dtype=dtypes.float32) / 100 4135 k = variables.Variable(0, dtype=dtypes.int32) 4136 chg_w = constant_op.constant(np.inf, dtype=dtypes.float32) 4137 4138 def cond(k, _, chg_w): 4139 return math_ops.logical_and(k < 10, chg_w > 1e-3) 4140 4141 def body(k, w, chg_w): 4142 grad, = gradients_impl.gradients(-math_ops.reduce_sum(w * c), w) 4143 w_n = w * math_ops.exp(-0.1 * grad) 4144 w_n /= math_ops.reduce_sum(w_n) 4145 chg_w = ( 4146 math_ops.reduce_sum(math_ops.abs(w_n - w)) / math_ops.reduce_sum( 4147 math_ops.abs(w))) 4148 return k + 1, w_n, chg_w 4149 4150 _, w, _ = control_flow_ops.while_loop(cond, body, [k, w, chg_w]) 4151 grad, = gradients_impl.gradients(w, c) 4152 self.assertIsNotNone(grad) 4153 4154 @test_util.run_v1_only("b/120545219") 4155 def testStopGradMultiFlows(self): 4156 with self.cached_session(): 4157 4158 def body(i, y, r): 4159 x = variable_scope.get_variable( 4160 "x", 4161 shape=(), 4162 dtype=dtypes.float32, 4163 initializer=init_ops.ones_initializer()) 4164 y *= x 4165 return [i + 1, y, r + math_ops.reduce_sum(y)] 4166 4167 i0 = constant_op.constant(0) 4168 y0 = array_ops.ones(5) 4169 r0 = constant_op.constant(0.0) 4170 cond = lambda i, y, r: i < 1 4171 _, _, r = control_flow_ops.while_loop( 4172 cond, body, [i0, y0, r0], back_prop=True) 4173 4174 vars_ = variables.global_variables() 4175 grads = linalg_ops.norm(gradients_impl.gradients(r, vars_)[0]) 4176 z = math_ops.add(r, array_ops.stop_gradient(math_ops.reduce_sum(grads))) 4177 result = gradients_impl.gradients(z, vars_)[0] 4178 self.evaluate(variables.global_variables_initializer()) 4179 self.assertEqual(5.0, self.evaluate(result)) 4180 4181 @test_util.run_v1_only("b/120545219") 4182 def testOneValueCond(self): 4183 4184 with self.cached_session(): 4185 c = array_ops.placeholder(dtypes.int32, shape=[]) 4186 one = ops.convert_to_tensor(1, name="one") 4187 two = ops.convert_to_tensor(2, name="two") 4188 p = math_ops.greater_equal(c, 1) 4189 i = control_flow_ops.cond(p, lambda: one, lambda: two) 4190 self.assertTrue(isinstance(i, ops.Tensor)) 4191 4192 # True case: c = 2 is >= 1 4193 self.assertEqual([1], i.eval(feed_dict={c: 2})) 4194 4195 # False case: c = 0 is not >= 1 4196 self.assertEqual([2], i.eval(feed_dict={c: 0})) 4197 4198 @test_util.run_deprecated_v1 4199 def testExampleCond(self): 4200 4201 with self.cached_session(): 4202 x = ops.convert_to_tensor([-2.0, 2.0], name="x") 4203 d = array_ops.placeholder(dtypes.int32, shape=[]) 4204 4205 def l2(): 4206 return math_ops.sqrt(math_ops.reduce_sum(math_ops.square(x))) 4207 4208 def l1(): 4209 return math_ops.reduce_sum(math_ops.abs(x)) 4210 4211 i = control_flow_ops.cond(math_ops.equal(d, 2), l2, l1) 4212 self.assertAllClose(4.0, i.eval(feed_dict={d: 1})) 4213 self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2})) 4214 4215 @test_util.run_v1_only("b/120545219") 4216 def testCase(self): 4217 with self.cached_session(): 4218 x = constant_op.constant(1) 4219 y = constant_op.constant(2) 4220 z = constant_op.constant(3) 4221 f1 = lambda: constant_op.constant(17) 4222 f2 = lambda: constant_op.constant(23) 4223 f3 = lambda: constant_op.constant(-1) 4224 4225 r1 = control_flow_ops.case( 4226 { 4227 x < y: f1, 4228 x > z: f2 4229 }, default=f3, exclusive=True) 4230 self.assertAllEqual(r1, 17) 4231 4232 r2 = control_flow_ops.case([(y > z, f1), (y > x, f2)], default=f3) 4233 self.assertAllEqual(r2, 23) 4234 4235 # Duplicate events can happen, first one is selected 4236 r3 = control_flow_ops.case([(x < y, f1), (x < y, f2)], default=f3) 4237 self.assertAllEqual(r3, 17) 4238 4239 # Duplicate events cause an error if exclusive = True 4240 r4 = control_flow_ops.case( 4241 [(x < y, f1), (x < y, f2)], default=f3, exclusive=True) 4242 with self.assertRaisesOpError("Input error:"): 4243 self.evaluate(r4) 4244 4245 # Check that the default is called if none of the others are 4246 r5 = control_flow_ops.case({x > y: f1}, default=f3) 4247 self.assertAllEqual(r5, -1) 4248 4249 ran_once = [False, False, False] 4250 4251 def break_run_twice(ix): 4252 4253 def _break(): 4254 ran_once[ix] = True 4255 return constant_op.constant(ix) 4256 4257 return _break 4258 4259 # Should not fail - each conditional gets called exactly once 4260 # except default. Default gets called twice: once to create an 4261 # empty output and once for the actual cond switch. 4262 r6 = control_flow_ops.case( 4263 [(x < y, break_run_twice(0)), (x > y, break_run_twice(1))], 4264 default=lambda: constant_op.constant(2)) 4265 4266 self.assertAllEqual(r6, 0) 4267 4268 @test_util.run_v1_only("b/120545219") 4269 def testCaseSideEffects(self): 4270 with self.cached_session() as sess: 4271 v0 = variables.Variable(-1) 4272 v1 = variables.Variable(-1) 4273 v2 = variables.Variable(-1) 4274 4275 a = lambda: control_flow_ops.with_dependencies([state_ops.assign(v0, 0)], 0) 4276 b = lambda: control_flow_ops.with_dependencies([state_ops.assign(v1, 1)], 1) 4277 c = lambda: control_flow_ops.with_dependencies([state_ops.assign(v2, 2)], 2) 4278 4279 x = constant_op.constant(1) 4280 y = constant_op.constant(2) 4281 4282 r0 = control_flow_ops.case( 4283 ((x < y, a), (x > y, b)), default=c, exclusive=True) 4284 r1 = control_flow_ops.case( 4285 ((x > y, a), (x < y, b)), default=c, exclusive=True) 4286 r2 = control_flow_ops.case( 4287 ((x > y, a), (x > y, b)), default=c, exclusive=True) 4288 4289 self.evaluate(variables.global_variables_initializer()) 4290 self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3) 4291 self.assertEqual(2, self.evaluate(r2)) 4292 self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1, -1, 2]) 4293 4294 self.evaluate(variables.global_variables_initializer()) 4295 self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3) 4296 self.assertEqual(1, self.evaluate(r1)) 4297 self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1, 1, -1]) 4298 4299 self.evaluate(variables.global_variables_initializer()) 4300 self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3) 4301 self.assertEqual(0, self.evaluate(r0)) 4302 self.assertAllEqual(self.evaluate([v0, v1, v2]), [0, -1, -1]) 4303 4304 @test_util.disable_control_flow_v2("b/113324949 (ref vars)") 4305 @test_util.run_v1_only("b/120545219") 4306 def testOneOpCond(self): 4307 with self.cached_session(): 4308 v = variables.Variable(0) 4309 c = ops.convert_to_tensor(0) 4310 one = ops.convert_to_tensor(1) 4311 two = ops.convert_to_tensor(2) 4312 p = math_ops.greater_equal(c, 1) 4313 4314 def a(): 4315 return state_ops.assign(v, one) 4316 4317 def b(): 4318 return state_ops.assign(v, two) 4319 4320 i = control_flow_ops.cond(p, a, b) 4321 self.assertTrue(isinstance(i, ops.Tensor)) 4322 self.evaluate(variables.global_variables_initializer()) 4323 4324 self.assertEqual(0, self.evaluate(v)) 4325 4326 # True case: c = 2 is >= 1, v is set to 1. 4327 self.assertEqual(1, i.eval(feed_dict={c.name: 2})) 4328 self.assertEqual(1, self.evaluate(v)) 4329 4330 # False case: c = 0 is not >= 1, v is set to 2. 4331 self.assertEqual(2, i.eval(feed_dict={c.name: 0})) 4332 self.assertEqual(2, self.evaluate(v)) 4333 4334 @test_util.run_v1_only("b/120545219") 4335 def testWithOpsDependencies(self): 4336 with self.cached_session() as sess: 4337 v = variables.VariableV1(0.0) 4338 c = constant_op.constant(10) 4339 4340 # Fetching v directly will result in an uninitialized error 4341 with self.assertRaisesOpError("Attempting to use uninitialized value"): 4342 self.evaluate([c, v]) 4343 4344 # Use a control dependency to ensure init_variable is run 4345 # while asking for c 4346 real_v = control_flow_ops.with_dependencies( 4347 name="real_tensor", 4348 output_tensor=v._ref(), # pylint: disable=protected-access 4349 dependencies=[v.initializer]) 4350 c_val, real_v_val = self.evaluate([c, real_v]) 4351 4352 # Ensure the result of 'real_c' is the same as 'c' 4353 self.assertAllEqual(10, c_val) 4354 4355 # Ensure that 'v' is initialized 4356 self.assertAllClose(0.0, real_v_val) 4357 4358 @test_util.run_v1_only("b/120545219") 4359 def testWithTensorDependencies(self): 4360 with self.cached_session(): 4361 v = variables.VariableV1(0.0) 4362 c1 = constant_op.constant(10) 4363 c2 = constant_op.constant(20) 4364 4365 # c1_with_init_v depends on the init op for v 4366 c1_with_init_v = control_flow_ops.with_dependencies( 4367 name="c1_with_init_v", output_tensor=c1, dependencies=[v.initializer]) 4368 # c2_with_c1 depends on the value of c1_with_init_v 4369 c2_with_c1_dep = control_flow_ops.with_dependencies( 4370 name="c2_with_c1_dep", 4371 output_tensor=c2, 4372 dependencies=[c1_with_init_v]) 4373 4374 # Fetching v directly will result in an uninitialized error 4375 with self.assertRaisesOpError("Attempting to use uninitialized value"): 4376 self.evaluate(v) 4377 4378 # Get the value of 'c2_with_c1_dep', which should cause 'v' 4379 # to be initialized. 4380 self.assertAllEqual(20, self.evaluate(c2_with_c1_dep)) 4381 4382 # Ensure that 'v' is initialized 4383 self.assertAllClose(0.0, self.evaluate(v)) 4384 4385 @test_util.run_v1_only("b/120545219") 4386 def testWithIndexedSlicesDependencies(self): 4387 with self.cached_session(): 4388 v = variables.VariableV1( 4389 np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(np.float32)) 4390 v_at_1 = indexed_slices.IndexedSlices(v, constant_op.constant([1])) 4391 gather_v_at_1 = array_ops.gather(v_at_1.values, v_at_1.indices) 4392 v_at_1_after_init = control_flow_ops.with_dependencies([v.initializer], 4393 v_at_1) 4394 gather_v_at_1_after_init = array_ops.gather(v_at_1_after_init.values, 4395 v_at_1_after_init.indices) 4396 4397 # Fetching gather_v_at_1 will result in an uninitialized error 4398 with self.assertRaisesOpError("Attempting to use uninitialized value"): 4399 self.evaluate(gather_v_at_1) 4400 4401 # Getting gather_v_at_1_after_init will work, and initialize v. 4402 self.assertAllEqual([[10.0, 11.0]], 4403 self.evaluate(gather_v_at_1_after_init)) 4404 4405 # Double check that 'v' is initialized 4406 self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]], 4407 self.evaluate(v)) 4408 4409 def testDependenciesDevice(self): 4410 with ops.Graph().as_default(): 4411 # device set on tensor => same device on dep. 4412 with ops.device("/job:ps"): 4413 vd = variables.VariableV1([0.0]) 4414 with_vd_dep = control_flow_ops.with_dependencies([vd.initializer], vd) 4415 self.assertTrue("/job:ps" in with_vd_dep.device) 4416 4417 # No device set on tensor => no device on dep. 4418 vnod = variables.VariableV1([0.0]) 4419 with_vnod_dep = control_flow_ops.with_dependencies([vnod.initializer], 4420 vnod) 4421 self.assertDeviceEqual(None, with_vnod_dep.device) 4422 4423 # device set on tensor, default device on graph => default device on dep. 4424 vdef = variables.VariableV1([0.0], name="vdef") 4425 with ops.device("/job:worker/device:GPU:1"): 4426 with_vdef_dep = control_flow_ops.with_dependencies([vdef.initializer], 4427 vdef) 4428 # The device is empty, but the colocation constraint is set. 4429 self.assertDeviceEqual("", with_vdef_dep.device) 4430 self.assertEqual([b"loc:@vdef"], with_vdef_dep.op.colocation_groups()) 4431 4432 @test_util.run_v1_only("b/120545219") 4433 def testGroup(self): 4434 with self.cached_session() as sess: 4435 v1 = variables.VariableV1([0.0]) 4436 v2 = variables.VariableV1([1.0]) 4437 4438 # Group init1 and init2 and run. 4439 init = control_flow_ops.group(v1.initializer, v2.initializer) 4440 # Fetching v1 directly will result in an uninitialized error 4441 with self.assertRaisesOpError("Attempting to use uninitialized value"): 4442 self.evaluate(v1) 4443 4444 # Runs "init" before fetching v1 and v2. 4445 init.run() 4446 v1_val, v2_val = self.evaluate([v1, v2]) 4447 4448 # Ensure that v1 and v2 are initialized 4449 self.assertAllClose([0.0], v1_val) 4450 self.assertAllClose([1.0], v2_val) 4451 4452 @test_util.run_v1_only("b/120545219") 4453 def testGroupEmpty(self): 4454 op = control_flow_ops.group() 4455 self.assertEqual(op.type, "NoOp") 4456 self.assertEqual(op.control_inputs, []) 4457 4458 @test_util.run_deprecated_v1 4459 def testMergeShapes(self): 4460 # All inputs unknown. 4461 p1 = array_ops.placeholder(dtypes.float32) 4462 p2 = array_ops.placeholder(dtypes.float32) 4463 p3 = array_ops.placeholder(dtypes.float32) 4464 m, index = control_flow_ops.merge([p1, p2, p3]) 4465 self.assertIs(None, m.get_shape().ndims) 4466 self.assertEqual([], index.get_shape()) 4467 4468 # All inputs known with different ranks. 4469 p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2]) 4470 p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2, 3]) 4471 m, index = control_flow_ops.merge([p1, p2]) 4472 self.assertIs(None, m.get_shape().ndims) 4473 self.assertEqual([], index.get_shape()) 4474 4475 # All inputs known with some dimensions different. 4476 p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2]) 4477 p2 = array_ops.placeholder(dtypes.float32, shape=[2, 1]) 4478 m, index = control_flow_ops.merge([p1, p2]) 4479 self.assertEqual([None, None], m.get_shape().as_list()) 4480 self.assertEqual([], index.get_shape()) 4481 4482 p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2]) 4483 p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2]) 4484 m, index = control_flow_ops.merge([p1, p2]) 4485 self.assertEqual([None, 2], m.get_shape().as_list()) 4486 self.assertEqual([], index.get_shape()) 4487 4488 p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2]) 4489 p2 = array_ops.placeholder(dtypes.float32, shape=[2, 2]) 4490 m, index = control_flow_ops.merge([p1, p2]) 4491 self.assertEqual([None, 2], m.get_shape().as_list()) 4492 self.assertEqual([], index.get_shape()) 4493 4494 # All inputs known with same dimensions. 4495 p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2]) 4496 p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2]) 4497 m, index = control_flow_ops.merge([p1, p2]) 4498 self.assertEqual([1, 2], m.get_shape().as_list()) 4499 self.assertEqual([], index.get_shape()) 4500 4501 p1 = array_ops.placeholder(dtypes.float32, shape=[None, 2]) 4502 p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2]) 4503 m, index = control_flow_ops.merge([p1, p2]) 4504 self.assertEqual([None, 2], m.get_shape().as_list()) 4505 self.assertEqual([], index.get_shape()) 4506 4507 p1 = array_ops.placeholder(dtypes.float32, shape=[None, None]) 4508 p2 = array_ops.placeholder(dtypes.float32, shape=[None, None]) 4509 m, index = control_flow_ops.merge([p1, p2]) 4510 self.assertEqual([None, None], m.get_shape().as_list()) 4511 self.assertEqual([], index.get_shape()) 4512 4513 @test_util.run_v1_only("b/120545219") 4514 def testRefSelect(self): 4515 index = array_ops.placeholder(dtypes.int32) 4516 4517 # All inputs unknown. 4518 p1 = array_ops.placeholder(dtypes.float32) 4519 p2 = array_ops.placeholder(dtypes.float32) 4520 p3 = array_ops.placeholder(dtypes.float32) 4521 v1 = variables.VariableV1(p1, validate_shape=False) 4522 v2 = variables.VariableV1(p2, validate_shape=False) 4523 v3 = variables.VariableV1(p3, validate_shape=False) 4524 self.assertIs(None, v1.get_shape().ndims) 4525 s = control_flow_ops.ref_select(index, [v1, v2, v3]) 4526 self.assertIs(None, s.get_shape().ndims) 4527 4528 # All inputs known but different. 4529 v1 = variables.VariableV1([[1, 2]]) 4530 v2 = variables.VariableV1([[2], [1]]) 4531 s = control_flow_ops.ref_select(index, [v1, v2]) 4532 self.assertIs(None, s.get_shape().ndims) 4533 4534 # All inputs known and same. 4535 v1 = variables.VariableV1([[1, 2]]) 4536 v2 = variables.VariableV1([[1, 2]]) 4537 s = control_flow_ops.ref_select(index, [v1, v2]) 4538 self.assertEqual([1, 2], s.get_shape()) 4539 4540 # Possibly the same but not guaranteed. 4541 v1 = variables.VariableV1([[1., 2.]]) 4542 p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2]) 4543 v2 = variables.VariableV1(p2, validate_shape=False) 4544 s = control_flow_ops.ref_select(index, [v1, v2]) 4545 self.assertEqual(None, s.get_shape()) 4546 4547 @test_util.run_deprecated_v1 4548 def testRunLoopTensor(self): 4549 with self.cached_session() as sess: 4550 tensor_list = [] 4551 4552 def condition(t): 4553 return t < constant_op.constant(5) 4554 4555 def body(_): 4556 tensor_list.append(constant_op.constant(5)) 4557 return constant_op.constant(10) 4558 4559 result = control_flow_ops.while_loop(condition, body, 4560 [constant_op.constant(4)]) 4561 self.assertEqual(10, self.evaluate(result)) 4562 4563 # Ensure that we cannot run a tensor that escapes the loop body 4564 # accidentally. 4565 with self.assertRaises(ValueError): 4566 sess.run(tensor_list[0]) 4567 4568 @test_util.run_v1_only("b/120545219") 4569 def testWhilePyFuncBasic(self): 4570 4571 def func(x): 4572 return np.square(x) 4573 4574 with self.cached_session(): 4575 r = control_flow_ops.while_loop( 4576 lambda i, v: i < 4, 4577 lambda i, v: [i + 1, script_ops.py_func(func, [v], [dtypes.float32])[0]], 4578 [constant_op.constant(0), constant_op.constant(2.0, dtypes.float32)], 4579 [tensor_shape.unknown_shape(), tensor_shape.unknown_shape()]) 4580 self.assertEqual(self.evaluate(r[1]), 65536.0) 4581 4582 @test_util.run_v1_only("b/120545219") 4583 def testWhileFuncBasic(self): 4584 4585 @function.Defun(dtypes.float32) 4586 def func(x): 4587 return math_ops.square(math_ops.square(x)) 4588 4589 with self.cached_session(): 4590 x = constant_op.constant(2.0, dtypes.float32) 4591 r = control_flow_ops.while_loop( 4592 lambda i, v: i < 2, lambda i, v: [i + 1, func(v)], 4593 [constant_op.constant(0), x], 4594 [tensor_shape.unknown_shape(), 4595 tensor_shape.unknown_shape()]) 4596 grad = gradients_impl.gradients(r, x)[0] 4597 self.assertEqual(self.evaluate(r[1]), 65536.0) 4598 self.assertEqual(self.evaluate(grad), 524288.0) 4599 # while_v2 does not have stacks. 4600 if not control_flow_util.ENABLE_CONTROL_FLOW_V2: 4601 self.assertEqual( 4602 len([op for op in x.graph.get_operations() if op.type == "StackV2" 4603 ]), 1) 4604 4605 4606 @test_util.run_v1_only("b/120545219") 4607 def testQIntSwitchMerge(self): 4608 with self.cached_session(force_gpu=test.is_gpu_available()) as sess: 4609 constant_qint = constant_op.constant(np.array([42]), dtypes.qint8) 4610 cond = constant_op.constant(True, dtypes.bool) 4611 v_f, v_t = control_flow_ops.switch(constant_qint, cond) 4612 result = control_flow_ops.merge([v_f, v_t]) 4613 self.evaluate(result) 4614 4615 @test_util.run_v1_only("b/120545219") 4616 def testQIntRefSwitchMerge(self): 4617 with self.cached_session(use_gpu=test.is_gpu_available()) as sess: 4618 var_qint = gen_state_ops.variable( 4619 shape=[1], dtype=dtypes.qint8, name="v", container="", shared_name="") 4620 assign_op = state_ops.assign( 4621 var_qint, constant_op.constant(np.array([42]), dtypes.qint8)) 4622 self.evaluate(assign_op) 4623 4624 cond = constant_op.constant(True, dtypes.bool) 4625 v_f, v_t = control_flow_ops.ref_switch(var_qint, cond) 4626 result = control_flow_ops.ref_merge([v_f, v_t]) 4627 self.evaluate(result) 4628 4629 @test_util.run_v1_only("b/120545219") 4630 def testUInt64SwitchMerge(self): 4631 with self.cached_session(force_gpu=test.is_gpu_available()) as sess: 4632 constant_uint64 = constant_op.constant(np.array([42]), dtypes.uint64) 4633 cond = constant_op.constant(True, dtypes.bool) 4634 v_f, v_t = control_flow_ops.switch(constant_uint64, cond) 4635 result = control_flow_ops.merge([v_f, v_t]) 4636 self.evaluate(result) 4637 4638 def testSwitchEagerMode(self): 4639 if not context.executing_eagerly(): 4640 return 4641 input_data = [1, 2, 3, 4] 4642 vf, vt = control_flow_ops.switch(input_data, False) 4643 self.assertAllEqual(vf, input_data) 4644 self.assertAllEqual(vt, []) 4645 4646 @test_util.run_deprecated_v1 4647 def testQIntArgAndRet(self): 4648 4649 @function.Defun(dtypes.qint8) 4650 def func(x): 4651 return x 4652 4653 with self.cached_session(force_gpu=test.is_gpu_available()) as sess: 4654 qint = constant_op.constant(np.array([42]), dtypes.qint8) 4655 result = func(qint) 4656 self.evaluate(result) 4657 4658 def testSparseIdentity(self): 4659 st1 = sparse_tensor.SparseTensor([[0, 5]], ['x'], [10, 10]) 4660 st2 = control_flow_ops._Identity(st1) 4661 self.assertAllEqual(st1.indices, st2.indices) 4662 self.assertAllEqual(st1.values, st2.values) 4663 self.assertAllEqual(st1.dense_shape, st2.dense_shape) 4664 4665 def testSparseEnterExit(self): 4666 st1 = sparse_tensor.SparseTensor([[0, 5]], ['x'], [10, 10]) 4667 st2 = control_flow_ops._Enter(st1, "foo_1") 4668 st3 = control_flow_ops.exit(st2) 4669 self.assertAllEqual(st1.indices, st3.indices) 4670 self.assertAllEqual(st1.values, st3.values) 4671 self.assertAllEqual(st1.dense_shape, st3.dense_shape) 4672 4673 def _buildWhileWithShapeInvariants(self, shape_invariants): 4674 r = constant_op.constant([1, 2]) 4675 4676 def cond(_): 4677 return False 4678 4679 def body(_): 4680 return constant_op.constant([1]) 4681 4682 return control_flow_ops.while_loop( 4683 cond, body, [r], shape_invariants=shape_invariants) 4684 4685 def testWhileOutputShapeWithShapeInvariantsUnknownRank(self): 4686 @def_function.function 4687 def runTest(): 4688 while_output = self._buildWhileWithShapeInvariants( 4689 [tensor_shape.TensorShape(None)]) 4690 self.assertIsNone(while_output.shape.rank) 4691 runTest() 4692 4693 def testWhileOutputShapeWithShapeInvariantsPartialShape(self): 4694 @def_function.function 4695 def runTest(): 4696 while_output = self._buildWhileWithShapeInvariants( 4697 [tensor_shape.TensorShape([None])]) 4698 self.assertAllEqual(while_output.shape.as_list(), [None]) 4699 runTest() 4700 4701 def testFunctionInWhile(self): 4702 4703 @def_function.function 4704 def body(x): 4705 return x + 1 4706 4707 r = control_flow_ops.while_loop(lambda x: x < 5, body, [0]) 4708 self.assertAllEqual(r, 5.) 4709 4710 4711class ControlFlowContextCheckTest(test.TestCase): 4712 4713 def _getWhileTensor(self): 4714 """Creates and returns a tensor from a while context.""" 4715 tensor = [] 4716 4717 def body(i): 4718 if not tensor: 4719 tensor.append(constant_op.constant(1)) 4720 return i + tensor[0] 4721 4722 control_flow_ops.while_loop(lambda i: i < 10, body, [0]) 4723 return tensor[0] 4724 4725 def _getCondTensor(self): 4726 cond_tensor = [] 4727 4728 def true_fn(): 4729 if not cond_tensor: 4730 cond_tensor.append(constant_op.constant(1)) 4731 return cond_tensor[0] 4732 4733 control_flow_ops.cond( 4734 math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0)) 4735 return cond_tensor[0] 4736 4737 @test_util.run_v1_only("b/120545219") 4738 def testInvalidContext(self): 4739 # Accessing a while loop tensor outside of control flow is illegal. 4740 while_tensor = self._getWhileTensor() 4741 with self.assertRaisesRegex( 4742 ValueError, 4743 "Cannot use 'while/Const_1' as input to 'Add' because 'while/Const_1' " 4744 "is in a while loop. See info log for more details."): 4745 math_ops.add(1, while_tensor) 4746 4747 @test_util.run_v1_only("b/120545219") 4748 def testInvalidContextInCond(self): 4749 # Accessing a while loop tensor in cond is illegal. 4750 while_tensor = self._getWhileTensor() 4751 with self.assertRaisesRegex( 4752 ValueError, "Cannot use 'while/Const_1' as input to 'cond/Add' because " 4753 "'while/Const_1' is in a while loop. See info log for more details."): 4754 # TODO(skyewm): this passes if we return while_tensor directly instead 4755 # of using it as input to another op. 4756 control_flow_ops.cond( 4757 math_ops.less(1, 2), lambda: math_ops.add(1, while_tensor), 4758 lambda: constant_op.constant(0)) 4759 4760 @test_util.run_v1_only("b/120545219") 4761 def testInvalidContextInWhile(self): 4762 # Accessing a while loop tensor in a different while loop is illegal. 4763 while_tensor = self._getWhileTensor() 4764 with self.assertRaisesRegex( 4765 ValueError, 4766 "Cannot use 'while/Const_1' as input to 'while_1/Add' because they are " 4767 "in different while loops. See info log for more details."): 4768 control_flow_ops.while_loop(lambda i: i < 10, 4769 lambda x: math_ops.add(1, while_tensor), [0]) 4770 4771 with self.assertRaisesRegex( 4772 ValueError, 4773 "Cannot use 'while/Const_1' as input to 'while_2/NextIteration' " 4774 "because they are in different while loops. See info log for more " 4775 "details."): 4776 control_flow_ops.while_loop(lambda i: i < 10, lambda i: while_tensor, [0]) 4777 4778 def testValidCondContext(self): 4779 # Accessing a tensor from a cond context is OK (although dangerous). 4780 cond_tensor = self._getCondTensor() 4781 math_ops.add(1, cond_tensor) 4782 4783 def testValidCondContextBranches(self): 4784 # Accessing a tensor from a cond context from the other branch's cond 4785 # context is OK (although dangerous). 4786 cond_tensor = [] 4787 4788 def branch_fn(): 4789 if not cond_tensor: 4790 cond_tensor.append(constant_op.constant(1)) 4791 return cond_tensor[0] 4792 4793 control_flow_ops.cond(math_ops.less(1, 2), branch_fn, branch_fn) 4794 4795 @test_util.run_v1_only("b/120545219") 4796 def testValidWhileContext(self): 4797 # Accessing a tensor in a nested while is OK. 4798 def body(_): 4799 c = constant_op.constant(1) 4800 return control_flow_ops.while_loop(lambda i: i < 3, lambda i: i + c, [0]) 4801 4802 control_flow_ops.while_loop(lambda i: i < 5, body, [0]) 4803 4804 @test_util.run_v1_only("b/120545219") 4805 def testValidNestedContexts(self): 4806 # Accessing a tensor from a cond context in a while context, all inside an 4807 # outer while context, is OK. 4808 def body(_): 4809 cond_tensor = self._getCondTensor() 4810 # Create another cond containing the while loop for good measure 4811 return control_flow_ops.cond( 4812 math_ops.less(1, 2), 4813 lambda: control_flow_ops.while_loop(lambda i: i < 3, 4814 lambda i: i + cond_tensor, [0]), 4815 lambda: constant_op.constant(0)) 4816 4817 control_flow_ops.while_loop(lambda i: i < 5, body, [0]) 4818 4819 @test_util.run_v1_only("b/120545219") 4820 def testInvalidNestedContexts(self): 4821 # Accessing a tensor from a while context in a different while context, all 4822 # inside a cond context, is illegal. 4823 def true_fn(): 4824 while_tensor = self._getWhileTensor() 4825 return control_flow_ops.while_loop(lambda i: i < 3, 4826 lambda i: i + while_tensor, [0]) 4827 4828 with self.assertRaisesRegex( 4829 ValueError, 4830 "Cannot use 'cond/while/Const_1' as input to 'cond/while_1/add' because" 4831 " they are in different while loops. See info log for more details."): 4832 control_flow_ops.cond( 4833 math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0)) 4834 4835 4836class TupleTest(test.TestCase): 4837 4838 @test_util.run_v1_only("b/120545219") 4839 def testTensors(self): 4840 for v1_first in [True, False]: 4841 with self.cached_session(): 4842 v1 = variables.VariableV1([1.0]) 4843 add1 = math_ops.add( 4844 control_flow_ops.with_dependencies([v1.initializer], v1._ref()), # pylint: disable=protected-access 4845 2.0) 4846 v2 = variables.VariableV1([10.0]) 4847 add2 = math_ops.add( 4848 control_flow_ops.with_dependencies([v2.initializer], v2._ref()), # pylint: disable=protected-access 4849 20.0) 4850 t1, _, t2 = control_flow_ops.tuple([add1, None, add2]) 4851 4852 # v1 is not initialized. 4853 with self.assertRaisesOpError("Attempting to use uninitialized value"): 4854 self.evaluate(v1) 4855 4856 # v2 is not initialized. 4857 with self.assertRaisesOpError("Attempting to use uninitialized value"): 4858 self.evaluate(v2) 4859 4860 if v1_first: 4861 # Getting t1 initializes v2. 4862 self.assertAllClose([3.0], self.evaluate(t1)) 4863 self.assertAllClose([10.0], self.evaluate(v2)) 4864 else: 4865 # Getting t2 initializes v1. 4866 self.assertAllClose([30.0], self.evaluate(t2)) 4867 self.assertAllClose([1.0], self.evaluate(v1)) 4868 4869 @test_util.run_v1_only("b/120545219") 4870 def testIndexedSlices(self): 4871 for v1_first in [True, False]: 4872 with self.cached_session(): 4873 v1 = variables.VariableV1( 4874 np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype( 4875 np.float32)) 4876 v1_at_1 = indexed_slices.IndexedSlices( 4877 control_flow_ops.with_dependencies([v1.initializer], v1._ref()), # pylint: disable=protected-access 4878 constant_op.constant([1])) 4879 4880 v2 = variables.VariableV1( 4881 np.array([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]]).astype( 4882 np.float32)) 4883 v2_at_1 = indexed_slices.IndexedSlices( 4884 control_flow_ops.with_dependencies([v2.initializer], v2._ref()), # pylint: disable=protected-access 4885 constant_op.constant([1])) 4886 4887 st1, st2 = control_flow_ops.tuple([v1_at_1, v2_at_1]) 4888 g1 = array_ops.gather(st1.values, st1.indices) 4889 g2 = array_ops.gather(st2.values, st2.indices) 4890 4891 # v1 is not initialized. 4892 with self.assertRaisesOpError("Attempting to use uninitialized value"): 4893 self.evaluate(v1) 4894 4895 # v2 is not initialized. 4896 with self.assertRaisesOpError("Attempting to use uninitialized value"): 4897 self.evaluate(v2) 4898 4899 if v1_first: 4900 # Getting g1 initializes v2. 4901 self.assertAllClose([[10.0, 11.0]], self.evaluate(g1)) 4902 self.assertAllClose([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]], 4903 self.evaluate(v2)) 4904 else: 4905 # Getting g2 initializes v1. 4906 self.assertAllClose([[10.1, 11.1]], self.evaluate(g2)) 4907 self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]], 4908 self.evaluate(v1)) 4909 4910 def testAcceptTensorsAsControlInputs(self): 4911 with self.cached_session(): 4912 var = variables.VariableV1(0) 4913 assign = state_ops.assign(var, 1) 4914 t, = control_flow_ops.tuple( 4915 [constant_op.constant(0)], control_inputs=[assign]) 4916 4917 # Should trigger the assign. 4918 self.evaluate(t) 4919 4920 self.assertEqual(1, self.evaluate(var)) 4921 4922 4923class AssertTest(test.TestCase): 4924 4925 @test_util.run_deprecated_v1 4926 def testGuardedAssertDoesNotCopyWhenTrue(self): 4927 if test_util.is_gpu_available(): 4928 self.skipTest("b/128646478 fails in opensource") 4929 4930 with self.session() as sess: 4931 with ops.device(test.gpu_device_name()): 4932 value = constant_op.constant(1.0) 4933 with ops.device("/cpu:0"): 4934 true = constant_op.constant(True) 4935 guarded_assert = control_flow_ops.Assert(true, [value], name="guarded") 4936 unguarded_assert = gen_logging_ops._assert( 4937 true, [value], name="unguarded") 4938 opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE) 4939 guarded_metadata = config_pb2.RunMetadata() 4940 sess.run(guarded_assert, options=opts, run_metadata=guarded_metadata) 4941 unguarded_metadata = config_pb2.RunMetadata() 4942 sess.run(unguarded_assert, options=opts, run_metadata=unguarded_metadata) 4943 guarded_nodestat_names = [ 4944 n.node_name 4945 for d in guarded_metadata.step_stats.dev_stats 4946 for n in d.node_stats 4947 ] 4948 unguarded_nodestat_names = [ 4949 n.node_name 4950 for d in unguarded_metadata.step_stats.dev_stats 4951 for n in d.node_stats 4952 ] 4953 guarded_memcpy_nodestat_names = [ 4954 n for n in guarded_nodestat_names if "MEMCPYDtoH" in n 4955 ] 4956 unguarded_memcpy_nodestat_names = [ 4957 n for n in unguarded_nodestat_names if "MEMCPYDtoH" in n 4958 ] 4959 if "GPU" in [d.device_type for d in device_lib.list_local_devices()]: 4960 # A copy was performed for the unguarded assert 4961 self.assertLess(0, len(unguarded_memcpy_nodestat_names), 4962 str(unguarded_nodestat_names)) 4963 # No copy was performed for the guarded assert 4964 self.assertEqual([], guarded_memcpy_nodestat_names) 4965 4966 4967class WhileOpBenchmark(test.Benchmark): 4968 """Evaluate the performance of while_loop op.""" 4969 4970 def _getInitVariables(self): 4971 batch_size = 10 4972 image_size = 256 4973 kernel_size = 3 4974 depth = 16 4975 4976 init_step = constant_op.constant(-1) 4977 image = variable_scope.get_variable( 4978 "image", 4979 initializer=random_ops.random_normal( 4980 [batch_size, image_size, image_size, depth], 4981 dtype=dtypes.float32, 4982 stddev=1e-1)) 4983 kernel = variable_scope.get_variable( 4984 "weights", 4985 initializer=random_ops.truncated_normal( 4986 [kernel_size, kernel_size, depth, depth], 4987 dtype=dtypes.float32, 4988 stddev=1e-1)) 4989 return init_step, image, kernel 4990 4991 def _runOneBenchmark(self, 4992 default_device, 4993 num_iters=10, 4994 static_unroll=False, 4995 steps=10): 4996 """Evaluate the while loop performance. 4997 4998 Args: 4999 default_device: The default device to run all ops except the loop_body. 5000 loop_body is always run on GPU. 5001 num_iters: Number of iterations to run. 5002 static_unroll: If true, run unrolled version; otherwise, run while_loop. 5003 steps: Total number of repeated steps to run the loop. 5004 5005 Returns: 5006 The duration of the run in seconds. 5007 """ 5008 5009 def loop_body(i, x): 5010 with ops.device("/gpu:0"): 5011 # Always put loop body on GPU. 5012 nx = nn_ops.conv2d( 5013 input=x, 5014 filter=kernel, 5015 strides=[1, 1, 1, 1], 5016 padding="SAME", 5017 data_format="NHWC", 5018 name="conv2d") 5019 ni = math_ops.add(i, 1) 5020 return ni, nx 5021 5022 ops.reset_default_graph() 5023 with session.Session() as sess, ops.device(default_device): 5024 # Get the initial id i, input x, and kernel. 5025 i, x, kernel = self._getInitVariables() 5026 self.evaluate(variables.global_variables_initializer()) 5027 5028 if static_unroll: 5029 for _ in range(steps): 5030 i, x = loop_body(i, x) 5031 else: 5032 i, x = control_flow_ops.while_loop( 5033 lambda i, _: i < steps, 5034 loop_body, [i, x], 5035 parallel_iterations=steps, 5036 swap_memory=True) 5037 5038 r = math_ops.reduce_sum(x) 5039 dx, dk = gradients_impl.gradients(r, [x, kernel]) 5040 # Use group to avoid fetching back results. 5041 r = control_flow_ops.group(dx, dk) 5042 5043 for _ in range(3): 5044 # exclude warm up time 5045 self.evaluate(r) 5046 5047 start_time = time.time() 5048 for _ in range(num_iters): 5049 self.evaluate(r) 5050 return (time.time() - start_time) / num_iters 5051 5052 def benchmarkWhileOpCrossDevicePlacement(self): 5053 iters = 10 5054 # Run loop body on GPU, but other ops on CPU. 5055 duration = self._runOneBenchmark("cpu", iters, static_unroll=False) 5056 self.report_benchmark( 5057 name="while_op_cross_device", iters=iters, wall_time=duration) 5058 5059 def benchmarkWhileOpSameDevicePlacement(self): 5060 iters = 10 5061 # Run all ops on the same GPU device. 5062 duration = self._runOneBenchmark("gpu", iters, static_unroll=False) 5063 self.report_benchmark( 5064 name="while_op_same_device", iters=iters, wall_time=duration) 5065 5066 def benchmarkWhileOpUnrollCrossDevicePlacement(self): 5067 iters = 10 5068 # Run loop body on GPU, but other ops on CPU. 5069 duration = self._runOneBenchmark("cpu", iters, static_unroll=True) 5070 self.report_benchmark( 5071 name="unroll_cross_device_cpu", iters=iters, wall_time=duration) 5072 5073 def benchmarkWhileOpUnrollSameDevicePlacement(self): 5074 iters = 10 5075 # Run all ops on GPU. 5076 duration = self._runOneBenchmark("gpu", iters, static_unroll=True) 5077 self.report_benchmark( 5078 name="unroll_same_device", iters=iters, wall_time=duration) 5079 5080 5081@test_util.with_control_flow_v2 5082class EagerTest(test.TestCase): 5083 5084 def testCond(self): 5085 with context.eager_mode(): 5086 pred = math_ops.less(1, 2) 5087 fn1 = lambda: [constant_op.constant(10)] 5088 fn2 = lambda: [constant_op.constant(20)] 5089 r = control_flow_ops.cond(pred, fn1, fn2) 5090 5091 self.assertAllEqual(r.numpy(), 10) 5092 self.assertFalse(isinstance(r, list)) 5093 5094 # TODO(b/117279927): Re-enable once msan failure is fixed. 5095 def DISABLED_testCondInDefun(self): 5096 with context.eager_mode(): 5097 5098 @eager_function.defun 5099 def foo(pred): 5100 # TODO(b/111124878): this only needs to output one element. 5101 fn1 = lambda: (constant_op.constant(10), constant_op.constant(100)) 5102 fn2 = lambda: (constant_op.constant(20), constant_op.constant(200)) 5103 return control_flow_ops.cond(constant_op.constant(pred), fn1, fn2) 5104 5105 r = foo(True) 5106 self.assertAllEqual(r[0].numpy(), 10) 5107 self.assertNotIsInstance(r, list) 5108 5109 r = foo(False) 5110 self.assertAllEqual(r[0].numpy(), 20) 5111 self.assertFalse(isinstance(r, list)) 5112 5113 def testWhileLoop(self): 5114 with context.eager_mode(): 5115 tensor = constant_op.constant([1, 2, 3, 4, 5]) 5116 self.assertAllEqual(isum(tensor).numpy(), [46, 47, 48, 49, 50]) 5117 5118 def testWhileLoopWithMaxIterations(self): 5119 with context.eager_mode(): 5120 tensor = constant_op.constant([1, 2, 3, 4, 5]) 5121 self.assertAllEqual( 5122 isum(tensor, maximum_iterations=3).numpy(), 5123 [1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3]) 5124 5125 @test_util.run_v1_only("b/120545219") 5126 def testWhileWithMaximumIterationsAndSingleArgument(self): 5127 with context.eager_mode(): 5128 tensor = constant_op.constant(0) 5129 r = control_flow_ops.while_loop( 5130 lambda i: i < 3, lambda i: i + 1, [tensor], maximum_iterations=1) 5131 self.assertEqual(1, r.numpy()) 5132 5133 def testWithDependencies(self): 5134 with context.eager_mode(): 5135 t1 = constant_op.constant(1) 5136 t2 = constant_op.constant(2) 5137 t3 = control_flow_ops.with_dependencies(t1, t2) 5138 self.assertAllEqual(t2.numpy(), t3.numpy()) 5139 5140 def testTuple(self): 5141 with context.eager_mode(): 5142 t1 = constant_op.constant(1) 5143 t2 = constant_op.constant(2) 5144 tup1, tup2 = control_flow_ops.tuple([t1, t2]) 5145 self.assertAllEqual(t1.numpy(), tup1.numpy()) 5146 self.assertAllEqual(t2.numpy(), tup2.numpy()) 5147 5148 @test_util.run_v1_only("b/120545219") 5149 def testCase(self): 5150 with context.eager_mode(): 5151 x = constant_op.constant(1) 5152 y = constant_op.constant(2) 5153 z = constant_op.constant(3) 5154 f1 = lambda: constant_op.constant(17) 5155 f2 = lambda: constant_op.constant(23) 5156 f3 = lambda: constant_op.constant(-1) 5157 5158 r1 = control_flow_ops.case( 5159 [(x < y, f1), (x > z, f2)], default=f3, exclusive=True) 5160 self.assertAllEqual(r1.numpy(), 17) 5161 5162 5163if __name__ == "__main__": 5164 test.main() 5165