# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for while_v2."""

from absl.testing import parameterized

from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import importer
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_v2
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
from tensorflow.python.platform import test


def random_gamma(shape):  # pylint: disable=invalid-name
  return random_ops.random_gamma(shape, 1.0)


def random_gamma_with_alpha_beta(shape):  # pylint: disable=invalid-name
  return random_ops.random_gamma(
      shape, alpha=[[1.], [3.], [5.], [6.]], beta=[[3., 4.]])


def random_poisson_v2(shape):  # pylint: disable=invalid-name
  return random_ops.random_poisson_v2(shape, 1.0)


def random_poisson_v2_with_lam(shape):  # pylint: disable=invalid-name
  return random_ops.random_poisson_v2(shape, [12.2, 3.3])


def fill(shape):  # pylint: disable=invalid-name
  return array_ops.fill(shape, 1.0)


class WhileV2Test(test.TestCase, parameterized.TestCase):

  @test_util.run_deprecated_v1
  def testSingleLoopVar(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_deprecated_v1
  def testSingleLoopVarBackPropFalse(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8.,
        lambda v: v * v, [x],
        return_same_structure=False,
        back_prop=False)
    grad = gradients_impl.gradients(ret, [x])
    self.assertEqual(grad, [None])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)

  @test_util.run_deprecated_v1
  def testCustomGradient(self):
    x = constant_op.constant(2.)
    n = constant_op.constant(1., name="const-n")
    m = variables.Variable(1.0)
    self.evaluate(variables.global_variables_initializer())

    def body_fn(v):  # pylint: disable=invalid-name

      @custom_gradient.custom_gradient
      def inner_fn(v):  # pylint: disable=invalid-name

        def grad_fn(dy, variables=None):  # pylint: disable=invalid-name, unused-argument, redefined-outer-name
          return dy * 2 * v * n * m, [v * v]

        return v * v * m, grad_fn

      return inner_fn(v)

    ret = while_loop_v2(
        lambda v: v < 8., body_fn, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_v1_only("b/120545219")
  def testReturnSameStructureTrue(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v * v, [x], return_same_structure=True)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session() as sess:
      eval_result = sess.run(ret)
      self.assertIsInstance(eval_result, list)
      self.assertLen(eval_result, 1)
      self.assertEqual(16., eval_result[0])
      self.assertSequenceEqual(sess.run(grad), [32.])

  def testVerifyInputOutputTypesMatch(self):

    @def_function.function
    def BuildWhile():
      x = constant_op.constant(1., dtypes.float32)

      def Body(x):
        return math_ops.cast(x, dtypes.float16) + 1

      while_loop_v2(lambda x: x < 10, Body, [x])

    with self.assertRaisesRegex(
        TypeError,
        r"Loop var Const:0 enters the loop with type <dtype: 'float32'> "
        r"but has type <dtype: 'float16'> after 1 iteration."):
      BuildWhile()

  @parameterized.parameters(dtypes.float32, dtypes.float64)
  def testGradientTapeResourceVariable(self, dtype):
    with context.eager_mode():
      v = variables.Variable(1., dtype=dtype)

      @def_function.function
      def fnWithLoop():  # pylint: disable=invalid-name
        with backprop.GradientTape() as tape:
          _, x = while_loop_v2(
              lambda i, _: i < 2,
              lambda i, x: (i + 1, x * v),
              [0, constant_op.constant(2., dtype=dtype)])
        return tape.gradient(x, v)

      self.assertAllEqual(fnWithLoop(), 4.0)

  def testDeferredCaptures(self):
    with context.eager_mode():
      c = constant_op.constant(10)

      @def_function.function
      def F():

        def Body(_):
          return ops.get_default_graph().capture_call_time_value(
              lambda: c, tensor_spec.TensorSpec([], dtypes.int32))

        x, = while_loop_v2(lambda i: True, Body, [0], maximum_iterations=1)
        return x

      self.assertAllEqual(F(), 10)

  def checkIteratedGradients(self, func):
    with context.eager_mode():

      def _Grad(f):
        def _GradFunction(primal):
          with backprop.GradientTape() as tape:
            tape.watch(primal)
            primal_out = f(primal)
          return tape.gradient(primal_out, primal)
        return _GradFunction

      f = func
      one = constant_op.constant(1.)

      for _ in range(3):
        theoretical, numerical = gradient_checker_v2.compute_gradient(
            def_function.function(f), [one])
        self.assertAllClose(theoretical, numerical, rtol=1e-3)
        f = _Grad(f)
        self.assertAllClose(array_ops.reshape(numerical, []),
                            def_function.function(f)(one),
                            rtol=1e-3)

  def testIteratedGradients(self):

    def _Func(x):
      _, z = while_loop_v2(
          lambda i, _: i < 2,
          lambda i, y: (i + 1, math_ops.cos(y)),
          [0, x])
      return z

    self.checkIteratedGradients(_Func)

  def testIteratedGradientsWithList(self):

    def _Func(x):
      results = list_ops.empty_tensor_list(
          element_shape=[], element_dtype=dtypes.float32)

      def _LoopBody(i, y, handle):
        return (i + 1, math_ops.cos(y),
                list_ops.tensor_list_push_back(handle, y))

      _, z, results = while_loop_v2(
          lambda i, _, h: i < 2, _LoopBody, [0, x, results])
      return z + math_ops.reduce_sum(list_ops.tensor_list_stack(
          results, dtypes.float32))

    self.checkIteratedGradients(_Func)

  def testGradWhileGradWhileWithVariable(self):
    with context.eager_mode():
      v = variables.Variable(1.)

      @def_function.function
      def _Func(x):

        def _Inner(a):
          with backprop.GradientTape() as tape:
            tape.watch(a)
            _, b = while_loop_v2(
                lambda i, _: i < 2,
                lambda i, y: (i + 1, math_ops.cos(v + y)),
                [0, a])
          return tape.gradient(b, a)

        _, z = while_loop_v2(
            lambda i, _: i < 2,
            lambda i, y: (i + 1, _Inner(y)),
            [0, x])
        return z

      with backprop.GradientTape(persistent=True) as tape:
        x = constant_op.constant(1.)
        tape.watch(x)
        y = _Func(x)
      dx, _ = tape.gradient(y, [x, v])
      theoretical, numerical = gradient_checker_v2.compute_gradient(
          _Func, [x])
      self.assertAllClose(numerical, theoretical, rtol=1e-3)
      self.assertAllClose(array_ops.reshape(numerical, []),
                          dx, rtol=1e-3)

  def testThreeNestWithLists(self):
    with context.eager_mode():
      def _WrapInWhile(f):
        def _Wrapped(x):
          results = list_ops.empty_tensor_list(
              element_shape=[], element_dtype=dtypes.float32)

          def _LoopBody(i, y, handle):
            return (i + 1, f(math_ops.cos(y)),
                    list_ops.tensor_list_push_back(handle, y))

          _, z, results = control_flow_ops.while_loop(
              lambda i, _, h: i < 2, _LoopBody, [0, x, results])
          return z + math_ops.reduce_sum(list_ops.tensor_list_stack(
              results, dtypes.float32))
        return _Wrapped

      f = math_ops.sin

      target_function = _WrapInWhile(_WrapInWhile(_WrapInWhile(f)))

      @def_function.function
      def _TapeFromGraphMode(x):
        with backprop.GradientTape(persistent=True) as tape:
          tape.watch(x)
          y = target_function(x)
        return tape.gradient(y, x)

      x = constant_op.constant(1.)
      dx = _TapeFromGraphMode(x)
      theoretical, numerical = gradient_checker_v2.compute_gradient(
          target_function, [x])
      self.assertAllClose(numerical, theoretical, rtol=3e-3)
      self.assertAllClose(array_ops.reshape(numerical, []), dx, rtol=3e-3)

  def testDeviceLabelsInherited(self):
    def _LoopBody(i, y):
      result = math_ops.cos(y)
      self.assertIn("CPU:10", result.device)
      with ops.device("CPU:11"):
        result = array_ops.identity(result)
      self.assertIn("CPU:11", result.device)
      return i + 1, result

    @def_function.function
    def _FunctionWithWhileLoop():
      x = constant_op.constant(1.)
      with ops.device("CPU:10"):
        _, z = while_loop_v2(
            lambda i, _: i < 2,
            _LoopBody,
            [0, x])
      return z
    # The test assertion runs at trace time.
    _FunctionWithWhileLoop.get_concrete_function()

  def testExternalControlDependencies(self):
    with ops.Graph().as_default(), self.test_session():
      v = variables.Variable(1.)
      self.evaluate(v.initializer)
      op = v.assign_add(1.)

      def body_fn(i):  # pylint: disable=invalid-name
        with ops.control_dependencies([op]):
          return i + 1

      loop = while_loop_v2(lambda i: i < 1, body_fn, [0])
      loop[0].op.run()
      self.assertAllEqual(self.evaluate(v), 2.0)

  @test_util.run_deprecated_v1
  def testMultipleLoopVarsBasic(self):
    x = constant_op.constant(5.)
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    ret = while_loop_v2(
        lambda v, _: v < 45.,
        lambda v, w: (v * w, w), [x, y],
        return_same_structure=False)
    # ret = [x*y^2, y]

    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
    grad = gradients_impl.gradients(ret, [x])  # [2*x*y]
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testMultipleLoopNonscalarCond(self):
    x = constant_op.constant([[5.]])
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    ret = while_loop_v2(
        lambda v, _: v < 45.,
        lambda v, w: (v * w, w), [x, y],
        return_same_structure=False)
    # ret == [x*y^2, y]

    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
    grad = gradients_impl.gradients(ret, [x])  # [2*x*y]
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testMultipleLoopVars(self):
    x = constant_op.constant(5.)
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    #   y = x + y
    ret = while_loop_v2(
        lambda v, _: v < 45.,
        lambda v, w: (v * w, v + w), [x, y],
        return_same_structure=False)
    # ret = [y*x**2 + x*y**2, x*y + x + y]

    gradx_0 = gradients_impl.gradients(ret[0], [x])  # [2*x*y + y**2]
    gradx_1 = gradients_impl.gradients(ret[1], [x])  # [y + 1]
    gradx_2 = gradients_impl.gradients(ret, [x])  # [2*x*y + y**2 + 2*y + 1]
    grady_0 = gradients_impl.gradients(ret[0], [y])  # [2*x*y + x**2]
    grady_1 = gradients_impl.gradients(ret[1], [y])  # [x + 1]
    grady_2 = gradients_impl.gradients(ret, [y])  # [2*x*y + x**2 + x + 1]
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(ret), [120., 23.])
      self.assertSequenceEqual(self.evaluate(gradx_0), [39.])
      self.assertSequenceEqual(self.evaluate(gradx_1), [4.])
      self.assertSequenceEqual(self.evaluate(gradx_2), [43.])
      self.assertSequenceEqual(self.evaluate(grady_0), [55.])
      self.assertSequenceEqual(self.evaluate(grady_1), [6.])
      self.assertSequenceEqual(self.evaluate(grady_2), [61.])

  @test_util.run_deprecated_v1
  def testGradientTape(self):
    with backprop.GradientTape() as t:
      x = constant_op.constant(2.)
      t.watch(x)
      ret = while_loop_v2(
          lambda v: v < 4., lambda v: v * v, [x],
          return_same_structure=False)  # x**2
    grad = t.gradient(ret, x)
    with self.cached_session() as sess:
      self.assertAllEqual(sess.run(grad), 4.0)

  @test_util.run_deprecated_v1
  def testMultipleWhileLoops(self):
    x = constant_op.constant(2.)
    ret1 = while_loop_v2(
        lambda v: v < 4., lambda v: v * v, [x],
        return_same_structure=False)  # x**2
    ret2 = while_loop_v2(
        lambda v: v < 16., lambda v: v * v, [ret1],
        return_same_structure=False)  # x**4
    grad = gradients_impl.gradients(ret2, [x])  # 4x**3
    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(grad), [32.])
      self.assertSequenceEqual(self.evaluate(grad_grad), [48.])

  def testMultipleWhileLoopsWithFunc(self):
    x = constant_op.constant(2.)

    @def_function.function
    def Fn():
      ret1 = while_loop_v2(
          lambda v: v < 4.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_1")  # x**2
      ret2 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_2")  # x**4
      return ret1, ret2

    concrete_fn = Fn.get_concrete_function()
    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
    self.assertEqual(while_1.type, "StatelessWhile")
    self.assertEqual(while_2.type, "StatelessWhile")
    self.assertEmpty(while_1.control_inputs)
    self.assertEmpty(while_2.control_inputs)

  def testMultipleWhileLoopsGradStateless(self):

    @def_function.function
    def Fn():
      x = constant_op.constant(2.)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        ret1 = while_loop_v2(
            lambda v: v < 4.,
            lambda v: v * v, [x],
            return_same_structure=False,
            name="while_1")  # x**2
        ret2 = while_loop_v2(
            lambda v: v < 16.,
            lambda v: v * v, [x],
            return_same_structure=False,
            name="while_2")  # x**4
        loss = ret1 + ret2
      return tape.gradient(loss, x)

    graph = Fn.get_concrete_function().graph
    while_ops = [op for op in graph.get_operations() if "While" in op.type]
    self.assertAllEqual([op.type for op in while_ops], ["StatelessWhile"] * 4,
                        "Must have exactly 4 StatelessWhile ops.")
    for op in while_ops:
      self.assertEmpty(op.control_inputs,
                       "{} should not have any control inputs".format(op.name))

  def testMultipleWhileLoopsWithDeps(self):
    x = variables.Variable(2.)
    c = constant_op.constant(2.)

    @def_function.function
    def Fn():

      def Body1(v):
        x.assign(x)
        return v * x

      ret1 = while_loop_v2(
          lambda v: v < 4.,
          Body1, [c],
          return_same_structure=False,
          name="while_1")  # 2x

      def Body2(v):
        x.assign(x)
        return v * x * x

      ret2 = while_loop_v2(
          lambda v: v < 16.,
          Body2, [c],
          return_same_structure=False,
          name="while_2")  # 4x
      return ret1, ret2

    concrete_fn = Fn.get_concrete_function()
    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
    self.assertEqual(while_1.type, "While")
    self.assertEqual(while_2.type, "While")
    self.assertEmpty(while_1.control_inputs)
    self.assertLen(while_2.control_inputs, 1)
    self.assertIs(while_2.control_inputs[0], while_1)

  def testMultipleWhileLoopsWithVarsDeps(self):
    x1 = variables.Variable(2.)
    x2 = variables.Variable(3.)
    c = constant_op.constant(2.)

    @def_function.function
    def Fn():

      def Body1(v):
        x1.assign(x1)
        return v * x1

      ret1 = while_loop_v2(
          lambda v: v < 4.,
          Body1, [c],
          return_same_structure=False,
          name="while_1")  # 2x

      def Body2(v):
        x1.assign(x1)
        return v * x1 * x1

      ret2 = while_loop_v2(
          lambda v: v < 16.,
          Body2, [c],
          return_same_structure=False,
          name="while_2")  # 4x

      def Body3(v):
        x2.assign(x2)
        return v * x2

      ret3 = while_loop_v2(
          lambda v: v < 4.,
          Body3, [c],
          return_same_structure=False,
          name="while_3")  # 3x

      def Body4(v):
        x2.assign(x2)
        return v * x2 * x2

      ret4 = while_loop_v2(
          lambda v: v < 16.,
          Body4, [c],
          return_same_structure=False,
          name="while_4")  # 9x
      ret5 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [c],
          return_same_structure=False,
          name="while_stateless")  # x**2
      return ret1, ret2, ret3, ret4, ret5

    concrete_fn = Fn.get_concrete_function()
    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
    while_3 = concrete_fn.graph.get_operation_by_name("while_3")
    while_4 = concrete_fn.graph.get_operation_by_name("while_4")
    while_stateless = concrete_fn.graph.get_operation_by_name(
        "while_stateless")
    self.assertEqual(while_1.type, "While")
    self.assertEqual(while_2.type, "While")
    self.assertEqual(while_3.type, "While")
    self.assertEqual(while_4.type, "While")
    self.assertEqual(while_stateless.type, "StatelessWhile")
    self.assertEmpty(while_1.control_inputs)
    self.assertLen(while_2.control_inputs, 1)
    self.assertIs(while_2.control_inputs[0], while_1)
    self.assertEmpty(while_3.control_inputs)
    self.assertLen(while_4.control_inputs, 1)
    self.assertIs(while_4.control_inputs[0], while_3)
    self.assertEmpty(while_stateless.control_inputs)

  @test_util.run_deprecated_v1
  def testDoubleDerivative(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v**2, [x],
        return_same_structure=False)  # x**4
    grad = gradients_impl.gradients(ret, [x])  # 4x**3
    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])
      self.assertSequenceEqual(self.evaluate(grad_grad), [48.])

  @test_util.run_v2_only
  def testMultipleWhileLoopsEager(self):

    @def_function.function
    def Func():
      x = constant_op.constant(2.)
      ret1 = while_loop_v2(
          lambda v: v < 4., lambda v: v * v, [x],
          return_same_structure=False)  # x**2
      ret2 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [ret1],
          return_same_structure=False)  # x**4
      grad = gradients_impl.gradients(ret2, [x])[0]  # 4x**3
      grad_grad = gradients_impl.gradients(grad, [x])[0]  # 12x**2
      return grad, grad_grad

    grad, grad_grad = Func()
    self.assertEqual(grad.numpy(), 32.)
    self.assertEqual(grad_grad.numpy(), 48.)

  @test_util.run_v2_only
  def testDoubleDerivativeEager(self):

    @def_function.function
    def Func():
      x = constant_op.constant(2.)
      ret = while_loop_v2(
          lambda v: v < 8., lambda v: v**2, [x],
          return_same_structure=False)  # x**4
      grad = gradients_impl.gradients(ret, [x])[0]  # 4x**3
      grad_grad = gradients_impl.gradients(grad, [x])[0]  # 12x**2
      return ret, grad, grad_grad

    ret, grad, grad_grad = Func()
    self.assertEqual(ret.numpy(), 16.)
    self.assertEqual(grad.numpy(), 32.)
    self.assertEqual(grad_grad.numpy(), 48.)

  def _testPruning(self):
    x = constant_op.constant(1)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=x.dtype, element_shape=x.shape)

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5

    def Body(x, tl):
      return x + 1, list_ops.tensor_list_push_back(tl, x)

    outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])

    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(outputs[0])

    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
    # away, causing an extra Enter node.
    enter_count = 2 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 1
    self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is pruned out.
    self.assertEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])

    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
    train_op.append(stack)
    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
    # away, causing an extra Enter node.
    enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
    self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is not pruned out.
    self.assertNotEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])

  @test_util.run_deprecated_v1
  def testPruningV1(self):
    self._testPruning()

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  def testPruningV2(self):
    self._testPruning()

  def _testDoNotAccumulateInvariants(self):
    push_op = ("TensorListPushBack"
               if control_flow_v2_toggles.control_flow_v2_enabled() else
               "StackPushV2")

    # Tests that loop invariants, i.e., tensors that are "captured" by the
    # while loop and not passed as loop variables are not accumulated in
    # gradient computation.
    v = constant_op.constant(5.0, name="v")

    r = control_flow_ops.while_loop(
        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)

    output = gradients_impl.gradients(r, v)[0]
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(output)

    g = GetOptimizedGraph()
    # The gradient for v * x requires the value of both v and x. Since v is a
    # loop invariant it is not accumulated so we have just one accumulator for
    # x.
    self.assertLen([n for n in g.node if n.op == push_op], 1)

  @test_util.run_deprecated_v1
  def testDoNotAccumulateInvariantsV1(self):
    self._testDoNotAccumulateInvariants()

  @test_util.run_deprecated_v1
  @test_util.enable_control_flow_v2
  def testDoNotAccumulateInvariantsV2(self):
    self._testDoNotAccumulateInvariants()

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testPruningNested(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    x = constant_op.constant(0)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=x.dtype, element_shape=x.shape)

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 25

    def Body(x, tl):

      def InnerCond(inner_x, unused_outer_x, unused_tl):
        return inner_x < 5

      def InnerBody(inner_x, outer_x, tl):
        return inner_x + 1, outer_x + 1, list_ops.tensor_list_push_back(tl, x)

      inner_x = constant_op.constant(0)
      return control_flow_ops.while_loop(InnerCond, InnerBody,
                                         [inner_x, x, tl])[1:]

    outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])

    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(outputs[0])

    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
    # away, causing an extra Enter node.
    # enter_count = 4 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
    # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is pruned out.
    self.assertEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])
    self.assertEmpty([n for n in g.node if n.op == "_While"])

    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
    train_op.append(stack)
    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
    # away, causing an extra Enter node.
    # enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
    # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is not pruned out.
    self.assertNotEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testPruningNested2(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    v = constant_op.constant(5.0, name="v")

    p = array_ops.placeholder(dtype=dtypes.int32)

    def MidBodyBuilder(iterations):

      def MidBody(i, x):
        r = control_flow_ops.while_loop(
            lambda *_: True,
            lambda i, x: (i + 1, math_ops.multiply(v, x, name="my_mul")),
            (0, x),
            maximum_iterations=iterations,
            name="inner")
        return (i + 1, gradients_impl.gradients(x + r[1], v)[0])

      return MidBody

    def OuterBody(i, x):
      iterations = array_ops.size(p, name="iterations")
      return (i + 1, x + control_flow_ops.while_loop(
          lambda *_: True,
          MidBodyBuilder(iterations), (0, x),
          maximum_iterations=iterations,
          name="mid")[1])

    def CreateWhileLoop():
      with ops.device("/cpu:0"):
        r = control_flow_ops.while_loop(
            lambda *_: True,
            OuterBody, (0, 1.0),
            maximum_iterations=5,
            name="outer")
        return array_ops.identity(r[1])

    output = CreateWhileLoop()
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(output)

    g = GetOptimizedGraph()
    self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testPruningNested3(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    v = constant_op.constant(5.0, name="v")

    def CreateWhileLoop():
      r = control_flow_ops.while_loop(
          lambda _: True,
          lambda x: math_ops.multiply(v, x, name="my_mul"), [1.0],
          maximum_iterations=5,
          name="outer")
      return array_ops.identity(r)

    r = CreateWhileLoop()
    output = gradients_impl.gradients(r, v)[0]
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(output)

    g = GetOptimizedGraph()
    self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)

  def _assertNotAccumulated(self, while_op, index):
    """Asserts that `while_op` input at `index` is not accumulated."""
    body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
    placeholder = body_graph.inputs[index]
    self.assertNotIn("TensorListPushBack",
                     [op.type for op in placeholder.consumers()])

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testDoNotOutputLoopCounterAsIntermediate(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    v = constant_op.constant(5.0, name="v")
    r = control_flow_ops.while_loop(
        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
    # Skip over Identity.
    while_op = r.op.inputs[0].op
    self._assertNotAccumulated(while_op, 0)

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testDoNotOutputLoopInvariantAsIntermediate(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE

    def GetInputIndex(op, tensor):
      for index, inp in enumerate(op.inputs):
        if inp is tensor:
          return index

    v = constant_op.constant(5.0, name="v")
    r = control_flow_ops.while_loop(
        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
    # Skip over Identity.
    while_op = r.op.inputs[0].op
    # We can't directly use while_op.inputs.index() because Tensors are not
    # hashable.
    index = GetInputIndex(while_op, v)
    self._assertNotAccumulated(while_op, index)

  @test_util.run_deprecated_v1
  def testCaptureExternalTensorInCond(self):
    x = constant_op.constant(2.)
    y = constant_op.constant(1.)
    ret = while_loop_v2(
        lambda v: v + y < 9.,
        lambda v: v * 3., [x],
        return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 18.)
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testCaptureExternalTensorInBody(self):
    x = constant_op.constant(2.)
    y = constant_op.constant(3.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v * y, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 18.)
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testLoopWithTensorListPushBack(self):
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      tl = list_ops.tensor_list_push_back(tl, x)
      tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.))
      return x**2., tl

    ret = while_loop_v2(
        Cond, Body, [x, tensor_list], return_same_structure=False)
    grad = gradients_impl.gradients(ret[0], x)
    with self.cached_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_deprecated_v1
  def testDuplicateAccumulator(self):
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      # There is an accumulator in the loop already so we should not add
      # another.
      tl = list_ops.tensor_list_push_back(tl, x)
      return x**2., tl

    ret = while_loop_v2(
        Cond, Body, [x, tensor_list], return_same_structure=False)

    for op in ops.get_default_graph().get_operations():
      if op.type == "While" or op.type == "StatelessWhile":
        while_op = op

    body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
    x_input_index = [i for i, inp in enumerate(while_op.inputs) if inp == x][0]
    x_input_t = body_graph.inputs[x_input_index]
    accumulator_count = len(
        [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
    self.assertEqual(accumulator_count, 1)

    grad = gradients_impl.gradients(ret[0], x)
    with self.cached_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @parameterized.named_parameters(
      ("UnknownShape", None),
      ("PartiallyDefinedShape", [None, 2]),
      ("FullyDefinedShape", [1, 2]),
  )
  @test_util.run_deprecated_v1
  def testAccumulatorElementShape(self, shape):

    def MatchShape(actual_tensor_shape):
      # Compare the shapes, treating None dimensions as equal. We do not
      # directly check actual_tensor_shape and tf.TensorShape(shape) for
      # equality because tf.Dimension.__eq__ returns None if either dimension is
      # None.
      if shape is None:
        self.assertIsNone(actual_tensor_shape.dims)
      else:
        self.assertListEqual(actual_tensor_shape.as_list(), shape)

    def GetAccumulatorForInputAtIndex(while_op, idx):
      body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
      y_input_t = body_graph.inputs[idx]
      push_back_node = [c for c in y_input_t.consumers()
                        if c.type == "TensorListPushBack"][0]
      output_idx = body_graph.outputs.index(push_back_node.outputs[0])
      return while_op.outputs[output_idx]

    x = array_ops.placeholder(dtype=dtypes.float32, shape=shape)
    y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)

    # Forward pass.
    ret = while_loop_v2(lambda v, u: v < 8.,
                        lambda v, u: (math_ops.pow(v, u), u),
                        [x, y],
                        return_same_structure=True)
    while_op = ret[0].op.inputs[0].op
    # Gradient pass.
    grad = gradients_impl.gradients(ret[0], x)
    # Note: There is an Identity b/w grad[0] and the While op.
    grad_while_op = grad[0].op.inputs[0].op

    # Get the TensorList output of While op containing the accumulated values
    # of y.
    x_input_index = [i for i, inp in enumerate(while_op.inputs) if x == inp][0]
    output = GetAccumulatorForInputAtIndex(while_op, x_input_index)
    _, val = list_ops.tensor_list_pop_back(output,
                                           element_dtype=dtypes.float32)
    MatchShape(val.shape)

    # Take second derivative to generate intermediate grad_while_op outputs
    gradients_impl.gradients(grad, x)

    # Get the TensorList output of gradient While op containing the accumulated
    # values of grad_x (note that grad_x is needed by the second derivative).
    # grad_while_op.inputs:
    grad_output_index = grad_while_op.outputs.index(grad[0].op.inputs[0])
    grad_output = GetAccumulatorForInputAtIndex(grad_while_op,
                                                grad_output_index)
    _, val = list_ops.tensor_list_pop_back(grad_output,
                                           element_dtype=dtypes.float32)
    MatchShape(val.shape)

  def _createWhile(self, name):
    """Helper function for testDefaultName."""
    output = while_v2.while_loop(
        lambda i: i < 3,
        lambda i: i + 1, [constant_op.constant(0)],
        return_same_structure=False)
    while_op = output.op.inputs[0].op
    self.assertEqual(while_op.type, "StatelessWhile")
    return while_op

  def testDefaultName(self):
    with ops.Graph().as_default():
      while_op = self._createWhile(None)
      self.assertEqual(while_op.name, "while")
      self.assertRegex(while_op.get_attr("cond").name, r"while_cond_\d*")
      self.assertRegex(while_op.get_attr("body").name, r"while_body_\d*")

    with ops.Graph().as_default():
      with ops.name_scope("foo"):
        while1_op = self._createWhile("")
        self.assertEqual(while1_op.name, "foo/while")
        self.assertRegex(while1_op.get_attr("cond").name, r"foo_while_cond_\d*")
        self.assertRegex(while1_op.get_attr("body").name, r"foo_while_body_\d*")

        while2_op = self._createWhile(None)
        self.assertEqual(while2_op.name, "foo/while_1")
        self.assertRegex(
            while2_op.get_attr("cond").name, r"foo_while_1_cond_\d*")
        self.assertRegex(
            while2_op.get_attr("body").name, r"foo_while_1_body_\d*")

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  def testWhileAndTensorArray(self):
    param = constant_op.constant(2.0)
    y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
    # map_fn uses TensorArray internally.
    r = map_fn.map_fn(lambda x: math_ops.multiply(x, param), y0)
    grad = gradients_impl.gradients(r, param)[0]
    self.assertAllClose([2.0, 4.0, 6.0, 8.0, 10.0, 12.0], self.evaluate(r))
    self.assertAllClose(21.0, self.evaluate(grad))

  @test_util.run_deprecated_v1
  def testNestedWhile(self):
    # Compute sum of geometric progression: n^0 + n^1 + ... + n^m
    # We compute the pow using a while loop.
    n = constant_op.constant(3.)
    m = constant_op.constant(5.)
    sum_of_powers = constant_op.constant(0.)

    def Body(i, previous_sum):
      prod = constant_op.constant(1.)
      return i - 1., previous_sum + while_loop_v2(
          lambda c, _: c > 0,
          lambda c, v: (c - 1., v * n), [i, prod],
          return_same_structure=False)[1]

    result = while_loop_v2(
        lambda i, _: i >= 0,
        Body, [m, sum_of_powers],
        return_same_structure=False)[1]
    grad = gradients_impl.gradients(result, [n])
    self.assertEqual(self.evaluate(result), 364.)
    self.assertSequenceEqual(self.evaluate(grad), [547.])

  @test_util.run_deprecated_v1
  def testNestedWhileWithLegacyDefun(self):
    n = constant_op.constant(3.)
    m = constant_op.constant(5.)
    sum_of_powers = constant_op.constant(0.)

    def Body(i, previous_sum):
      prod = constant_op.constant(1.)

      def InnerBodyWrapper(c, v):

        @function.Defun(dtypes.float32, dtypes.float32)
        def InnerBody(c, v):
          return c - 1., v * n

        results = InnerBody(c, v)
        results[0].set_shape([])
        results[1].set_shape([])
        return results

      return i - 1., previous_sum + while_loop_v2(
          lambda c, _: c > 0,
          InnerBodyWrapper, [i, prod],
          return_same_structure=False)[1]

    result = while_loop_v2(
        lambda i, _: i >= 0,
        Body, [m, sum_of_powers],
        return_same_structure=False)[1]
    grad = gradients_impl.gradients(result, [n])
    self.assertEqual(self.evaluate(result), 364.)
    self.assertSequenceEqual(self.evaluate(grad), [547.])

  @test_util.run_deprecated_v1
  def testIdentityNodeInBody(self):

    def Body(v):
      v = array_ops.identity(v)
      v = array_ops.identity(v)
      return v * v

    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., Body, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    self.assertEqual(self.evaluate(ret), 16.)
    self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_deprecated_v1
  def testForwardPassRewrite(self):
    x = constant_op.constant(1.0, name="x")
    output = while_v2.while_loop(lambda x: x < 10.0,
                                 lambda x: x * 2.0,
                                 [x])[0]
    while_op = output.op.inputs[0].op
    self.assertEqual(while_op.type, "StatelessWhile")
    # outputs = [loop_counter, max_iters, x]
    self.assertLen(while_op.outputs, 3)

    gradients_impl.gradients(output, x)
    # while_op should have been rewritten to output intermediates.
    # outputs = [loop_counter, max_iters, x, x_accumulator]
    self.assertLen(while_op.outputs, 4)

    gradients_impl.gradients(output, x)
    # Computing the gradient again shouldn't rewrite while_op again.
    self.assertLen(while_op.outputs, 4)

  @parameterized.named_parameters(
      ("RandomUniform", random_ops.random_uniform, [5, 3]),
      ("RandomNormal", random_ops.random_normal, [5, 3]),
      ("ParameterizedTruncatedNormal",
       random_ops.parameterized_truncated_normal, [5, 3]),
      ("TruncatedNormal", random_ops.truncated_normal, [5, 3]),
      ("RandomGamma", random_gamma, [5, 3]),
      ("RandomPoissonV2", random_poisson_v2, [5, 3]),
      ("RandomGammaWithAlphaBeta", random_gamma_with_alpha_beta, [5, 3, 4, 2]),
      ("RandomPoissonV2WithLam", random_poisson_v2_with_lam, [5, 3, 2]),
  )
  @test_util.run_deprecated_v1
  def testRandomOpsShape(self, random_fn, expected_shape):
    shape = constant_op.constant([3])

    def Body(i, u):
      shape_extended = array_ops.concat([[5], shape], axis=0)
      u = random_fn(shape_extended)
      assert u.shape.as_list() == expected_shape, str(u.shape.as_list())
      return i + 1, u

    _, _ = while_loop_v2(
        cond=lambda i, _: i < 3,
        body=Body,
        loop_vars=[
            0,
            array_ops.zeros(expected_shape, dtype=dtypes.float32),
        ])

  @test_util.run_deprecated_v1
  def testReshapeShape(self):
    shape = constant_op.constant([3, 4])

    def Body(i, u):
      shape_extended = array_ops.concat([[5], shape], axis=0)
      u = array_ops.reshape(u, [-1])
      assert u.shape.as_list() == [60], str(u.shape.as_list())
      u = array_ops.reshape(u, shape_extended)
      assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list())
      return i + 1, u

    _, _ = while_loop_v2(
        cond=lambda i, _: i < 3,
        body=Body,
        loop_vars=[
            0,
            array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
        ])

  @parameterized.named_parameters(
      ("Zeros", array_ops.zeros),
      ("Ones", array_ops.ones),
      ("Fill", fill),
  )
  @test_util.run_deprecated_v1
  def testFillOpsShape(self, fill_fn):
    shape = constant_op.constant([3, 4])

    def Body(i, u):
      shape_extended = array_ops.concat([[5], shape], axis=0)
      u = fill_fn(shape_extended)
      assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list())
      return i + 1, u

    _, _ = while_loop_v2(
        cond=lambda i, _: i < 3,
        body=Body,
        loop_vars=[
            0,
            array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
        ])

  @test_util.run_deprecated_v1
  def testExternalColocationGrad(self):
    external_t = constant_op.constant(2.)
    v0 = constant_op.constant(2.)

    def Body(v):
      with ops.colocate_with(external_t):
        return v * v

    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
    grad = gradients_impl.gradients(ret, [v0])[0]
    self.assertAllEqual(ret, 16.)
    self.assertAllEqual(grad, 32.)

  @test_util.run_deprecated_v1
  def testDoNotAccumulateConstNodes(self):

    def Body(v):
      return v * 2.0

    v0 = constant_op.constant(2.)
    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
    # Gradients computation has the side-effect of updating the forward op
    # which is what we want to test.
    unused_grad = gradients_impl.gradients(ret, [v0])[0]
    # ret is separated from the `While` op by an `Identity` so we skip over
    # that.
    forward_while_op = ret.op.inputs[0].op
    body_graph = while_v2._get_graph(forward_while_op, "body", "_body_graph")
    push_back_nodes = [
        o for o in body_graph.get_operations() if o.type == "TensorListPushBack"
    ]
    # Gradient of `Mul` requires accumulating both its inputs. But since one
    # of those is a Const (2.0), we should have just one accumulator.
    self.assertLen(push_back_nodes, 1)

  def testDoNotAccumulateForwardTensorsForReductionOps(self):

    @def_function.function
    def Fn():
      with backprop.GradientTape() as tape:
        x = constant_op.constant(2.)
        tape.watch(x)

        def Body(i, x):
          forward_graph = ops.get_default_graph()

          @custom_gradient.custom_gradient
          def SquaredWithZeroGrad(x):

            def Grad(unused_g, variables=None):  # pylint: disable=redefined-outer-name
              del variables
              gradient_graph = ops.get_default_graph()
              shape = gen_array_ops.shape(x)
              assert shape.graph is forward_graph
              rank = gen_array_ops.rank(x)
              assert rank.graph is forward_graph
              size = gen_array_ops.size(x)
              assert size.graph is forward_graph
              zeros = array_ops.zeros(shape)
              assert zeros.graph is gradient_graph
              return zeros

            return x * 2, Grad

          return i + 1, SquaredWithZeroGrad(x)

        _, result = while_loop_v2(lambda i, _: i < 2, Body, [0, x])
        grad = tape.gradient(result, x)
        return grad

    Fn()

  def testDoNotAccumulateForwardTensorsForTensorListReductionOps(self):

    @def_function.function
    def Fn():
      with backprop.GradientTape() as tape:
        e = constant_op.constant(2.)
        x = list_ops.empty_tensor_list(
            element_dtype=dtypes.float32, element_shape=e.shape)
        x = list_ops.tensor_list_push_back(x, e)
        tape.watch(x)

        def Body(i, x):
          forward_graph = ops.get_default_graph()

          @custom_gradient.custom_gradient
          def IdentityWithZeroGrad(x):

            def Grad(unused_g, variables=None):  # pylint: disable=redefined-outer-name
              del variables
              gradient_graph = ops.get_default_graph()
              shape = gen_list_ops.tensor_list_element_shape(
                  x, shape_type=dtypes.int32)
              assert shape.graph is forward_graph
              size = gen_list_ops.tensor_list_length(x)
              assert size.graph is forward_graph
              zeros = gen_list_ops.tensor_list_reserve(shape, size,
                                                       dtypes.float32)
              assert zeros.graph is gradient_graph
              return zeros

            return x, Grad

          return i + 1, IdentityWithZeroGrad(x)

        _, result = while_loop_v2(lambda i, _: i < 2, Body, [0, x])
        ones_like = list_ops.tensor_list_from_tensor(
            array_ops.ones_like(
                list_ops.tensor_list_stack(result, element_dtype=dtypes.float32)),
            element_shape=tensor_shape.TensorShape([]))
        grad = tape.gradient(result, x, output_gradients=[ones_like])
        return grad

    Fn()

  @test_util.run_v2_only
  def testInheritParentNameScope(self):

    @def_function.function
    def F():
      with ops.name_scope("foo"):

        def Cond(unused_i):
          with ops.name_scope("cond"):
            actual_name_scope = ops.get_name_scope()
            expected_name_scope = "foo/while/cond"
            assert actual_name_scope == expected_name_scope, (
                "%s does not match %s" %
                (actual_name_scope, expected_name_scope))
          return False

        def Body(i):
          with ops.name_scope("body"):
            actual_name_scope = ops.get_name_scope()
            expected_name_scope = "foo/while/body"
            assert actual_name_scope == expected_name_scope, (
                "%s does not match %s" %
                (actual_name_scope, expected_name_scope))
          return i

        return while_v2.while_loop(Cond, Body, [0.])

    F()

  @test_util.run_deprecated_v1  # Need to pass RunMetadata.
  def testDisableLowering(self):
    old = control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE
    control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = True
    with self.session() as sess:
      x = constant_op.constant(2.)
      ret = while_loop_v2(
          lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False)

      opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
      run_metadata = config_pb2.RunMetadata()
      self.assertEqual(sess.run(ret, options=opts, run_metadata=run_metadata),
                       16)
      for dev_stat in run_metadata.step_stats.dev_stats:
        for ns in dev_stat.node_stats:
          self.assertNotIn("switch", ns.node_name)
    control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = old

  def _runBasicWithConfig(self, config):
    with ops.device("/cpu:0"):
      x = constant_op.constant(0)
      ret, = while_loop_v2(lambda x: x < 1000, lambda x: x + 1, [x])
    with self.cached_session(config=config):
      self.assertEqual(1000, self.evaluate(ret))

  @test_util.run_deprecated_v1
  def testRunKernelsInline(self):
    config = config_pb2.ConfigProto()
    config.inter_op_parallelism_threads = -1
    self._runBasicWithConfig(config)

  @test_util.run_deprecated_v1
  def testSingleThreadedExecution(self):
    config = config_pb2.ConfigProto()
    config.experimental.executor_type = "SINGLE_THREADED_EXECUTOR"
    self._runBasicWithConfig(config)

  def testIsControlFlowGraph(self):
    x = constant_op.constant(0)

    @def_function.function
    def F(c):

      def Cond(i):
        self.assertTrue(i.graph.is_control_flow_graph)
        return i < 2

      def Body(i):
        i = i + 1
        self.assertTrue(i.graph.is_control_flow_graph)
        return i

      return while_loop_v2(Cond, Body, [c])

    ret, = F(x)
    self.assertEqual(2, self.evaluate(ret))

  def testImportFromSerializedWithFunctionInBody(self):
    serialized = """node {
      name: "Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_FLOAT
            tensor_shape {
            }
            float_val: 1.0
          }
        }
      }
    }
    node {
      name: "while/maximum_iterations"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: -1
          }
        }
      }
    }
    node {
      name: "while/loop_counter"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
    }
    node {
      name: "while"
      op: "StatelessWhile"
      input: "while/loop_counter"
      input: "while/maximum_iterations"
      input: "Const"
      attr {
        key: "T"
        value {
          list {
            type: DT_INT32
            type: DT_INT32
            type: DT_FLOAT
          }
        }
      }
      attr {
        key: "_lower_using_switch_merge"
        value {
          b: true
        }
      }
      attr {
        key: "_num_original_outputs"
        value {
          i: 3
        }
      }
      attr {
        key: "_read_only_resource_inputs"
        value {
          list {
          }
        }
      }
      attr {
        key: "body"
        value {
          func {
            name: "while_body_822"
          }
        }
      }
      attr {
        key: "cond"
        value {
          func {
            name: "while_cond_821"
          }
        }
      }
      attr {
        key: "output_shapes"
        value {
          list {
            shape {
            }
            shape {
            }
            shape {
            }
          }
        }
      }
      attr {
        key: "parallel_iterations"
        value {
          i: 10
        }
      }
    }
    node {
      name: "while/Identity"
      op: "Identity"
      input: "while"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    node {
      name: "while/Identity_1"
      op: "Identity"
      input: "while:1"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    node {
      name: "while/Identity_2"
      op: "Identity"
      input: "while:2"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    library {
      function {
        signature {
          name: "while_body_822"
          input_arg {
            name: "while_loop_counter"
            type: DT_INT32
          }
          input_arg {
            name: "while_maximum_iterations_0"
            type: DT_INT32
          }
          input_arg {
            name: "placeholder"
            type: DT_FLOAT
          }
          output_arg {
            name: "add"
            type: DT_INT32
          }
          output_arg {
            name: "while_maximum_iterations"
            type: DT_INT32
          }
          output_arg {
            name: "partitionedcall"
            type: DT_FLOAT
          }
        }
        node_def {
          name: "PartitionedCall"
          op: "PartitionedCall"
          input: "placeholder"
          attr {
            key: "Tin"
            value {
              list {
                type: DT_FLOAT
              }
            }
          }
          attr {
            key: "Tout"
            value {
              list {
                type: DT_FLOAT
              }
            }
          }
          attr {
            key: "_collective_manager_ids"
            value {
              list {
              }
            }
          }
          attr {
            key: "_read_only_resource_inputs"
            value {
              list {
              }
            }
          }
          attr {
            key: "config"
            value {
              s: ""
            }
          }
          attr {
            key: "config_proto"
            value {
              s: ""
            }
          }
          attr {
            key: "executor_type"
            value {
              s: ""
            }
          }
          attr {
            key: "f"
            value {
              func {
                name: "__inference_f_841"
              }
            }
          }
          experimental_debug_info {
            original_node_names: "PartitionedCall"
          }
        }
        node_def {
          name: "add/y"
          op: "Const"
          attr {
            key: "dtype"
            value {
              type: DT_INT32
            }
          }
          attr {
            key: "value"
            value {
              tensor {
                dtype: DT_INT32
                tensor_shape {
                }
                int_val: 1
              }
            }
          }
          experimental_debug_info {
            original_node_names: "add/y"
          }
        }
        node_def {
          name: "add_0"
          op: "AddV2"
          input: "while_loop_counter"
          input: "add/y:output:0"
          attr {
            key: "T"
            value {
              type: DT_INT32
            }
          }
          experimental_debug_info {
            original_node_names: "add"
          }
        }
        ret {
          key: "add"
          value: "add_0:z:0"
        }
        ret {
          key: "partitionedcall"
          value: "PartitionedCall:output:0"
        }
        ret {
          key: "while_maximum_iterations"
          value: "while_maximum_iterations_0"
        }
        arg_attr {
          key: 0
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
        arg_attr {
          key: 1
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
        arg_attr {
          key: 2
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
      }
      function {
        signature {
          name: "while_cond_821"
          input_arg {
            name: "while_loop_counter"
            type: DT_INT32
          }
          input_arg {
            name: "while_maximum_iterations"
            type: DT_INT32
          }
          input_arg {
            name: "placeholder"
            type: DT_FLOAT
          }
          output_arg {
            name: "less"
            type: DT_BOOL
          }
        }
        node_def {
          name: "Less/y"
          op: "Const"
          attr {
            key: "dtype"
            value {
              type: DT_FLOAT
            }
          }
          attr {
            key: "value"
            value {
              tensor {
                dtype: DT_FLOAT
                tensor_shape {
                }
                float_val: 5.0
              }
            }
          }
          experimental_debug_info {
            original_node_names: "Less/y"
          }
        }
        node_def {
          name: "Less"
          op: "Less"
          input: "placeholder"
          input: "Less/y:output:0"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
          experimental_debug_info {
            original_node_names: "Less"
          }
        }
        ret {
          key: "less"
          value: "Less:z:0"
        }
        arg_attr {
          key: 0
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
        arg_attr {
          key: 1
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
        arg_attr {
          key: 2
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
      }
      function {
        signature {
          name: "__inference_f_841"
          input_arg {
            name: "mul_placeholder"
            type: DT_FLOAT
          }
          output_arg {
            name: "identity"
            type: DT_FLOAT
          }
        }
        node_def {
          name: "mul/y"
          op: "Const"
          attr {
            key: "dtype"
            value {
              type: DT_FLOAT
            }
          }
          attr {
            key: "value"
            value {
              tensor {
                dtype: DT_FLOAT
                tensor_shape {
                }
                float_val: 2.0
              }
            }
          }
          experimental_debug_info {
            original_node_names: "mul/y"
          }
        }
        node_def {
          name: "mul"
          op: "Mul"
          input: "mul_placeholder"
          input: "mul/y:output:0"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
          experimental_debug_info {
            original_node_names: "mul"
          }
        }
        node_def {
          name: "Identity"
          op: "Identity"
          input: "mul:z:0"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
          experimental_debug_info {
            original_node_names: "Identity"
          }
        }
        ret {
          key: "identity"
          value: "Identity:output:0"
        }
        arg_attr {
          key: 0
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
      }
    }
    versions {
      producer: 399
      min_consumer: 12
    }
    """
    # Code for generating above graph:
    #
    # def Body(i):
    #   @tf.function
    #   def f():
    #     return i * 2
    #   return f()
    # tf.while_loop(lambda i: i < 5., Body, [tf.constant(1.)])
    graph_def = graph_pb2.GraphDef()
    text_format.Parse(serialized, graph_def)
    @def_function.function
    def F():
      x, y = importer.import_graph_def(
          graph_def, return_elements=["Const:0", "while:2"])
      grad_out, = gradients_impl.gradients(y, x)
      return grad_out
    self.assertAllEqual(F(), 8.0)

  def testIndexedSlicesInIncomingGrads(self):
    @def_function.function
    def F():
      x = constant_op.constant([2.])
      # Computes x^4
      ret = while_loop_v2(
          lambda _: True, lambda v: v * v, [x], return_same_structure=False,
          maximum_iterations=2)
      v = array_ops.gather(ret, [0])
      return gradients_impl.gradients(v, [x])[0]  # 4*x^3
    self.assertAllEqual(self.evaluate(F()), [32.])

  def testShapeInvariantsRaggedTensor(self):

    @def_function.function
    def TestFn(x):
      _, ret = while_loop_v2(
          lambda i, _: i < 1,
          lambda i, y: (i + 1, array_ops.concat([y, y], axis=0)),
          [0, x],
          shape_invariants=[
              tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32),
              ragged_tensor.RaggedTensorSpec(shape=[None, None])],
      )
      return ret

    x = ragged_factory_ops.constant([[1., 2.], [3.]])
    result = TestFn(x)
    expected_result = [[1., 2.], [3.], [1., 2.], [3.]]
    self.assertAllEqual(result, expected_result)


def ScalarShape():
  return ops.convert_to_tensor([], dtype=dtypes.int32)


def GetOptimizedGraph():
  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  config = config_pb2.ConfigProto()
  config.graph_options.rewrite_options.CopyFrom(
      rewriter_config_pb2.RewriterConfig(
          constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
  return tf_optimizer.OptimizeGraph(config, mg)


if __name__ == "__main__":
  test.main()