# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for control_flow module."""

import collections

import numpy as np

from tensorflow.python.autograph.converters import break_statements
from tensorflow.python.autograph.converters import continue_statements
from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import test
from tensorflow.python.util import nest


# Module-level globals written by the converted functions in the tests below.
# Each one is used by a single test so the tests do not share state.
for_unaffected_global = None
for_mixed_globals_nonglobals = None
for_test_global_local = None


class ControlFlowTestBase(converter_testing.TestCase):
  """Shared assertion helpers for the control_flow converter tests."""

  def assertValuesEqual(self, actual, expected):
    """Asserts that `actual` equals `expected`, evaluating TF values first.

    Args:
      actual: A (possibly nested) structure. Leaves for which
        `tensor_util.is_tf_type` is true are evaluated with `self.evaluate`;
        all other leaves are compared as-is.
      expected: The expected structure of plain values.
    """
    values = nest.map_structure(
        lambda x: self.evaluate(x) if tensor_util.is_tf_type(x) else x,
        actual)
    self.assertAllEqual(values, expected)

  def assertTransformedResult(self, f, inputs, expected):
    """Converts `f` with `control_flow`, calls it and checks the result.

    Args:
      f: The Python function to convert.
      inputs: Positional arguments for the converted function. A value that
        is not a tuple (including a list) is passed as the single argument.
      expected: The expected (possibly nested) return value.
    """
    if not isinstance(inputs, tuple):
      inputs = (inputs,)
    tr = self.transform(f, control_flow)
    returns = tr(*inputs)
    self.assertValuesEqual(returns, expected)


class NestedControlFlowTest(ControlFlowTestBase):
  """Tests for loops and conditionals nested inside each other."""

  def test_basic(self):

    def f(n):
      i = 0
      j = 0
      s = 0
      while i < n:
        while j < i:
          j += 3
        u = i + j  # 'u' is not defined within the inner loop
        s += u
        i += 1
        j = 0
      return s, i, j, n

    self.assertTransformedResult(f, constant_op.constant(5),
                                 (25, 5, 0, 5))

  def test_mixed_globals_nonglobals(self):

    def f(n):
      global for_mixed_globals_nonglobals
      i = 0
      j = 0
      for_mixed_globals_nonglobals = 0
      while i < n:
        while j < i:
          j += 3
        u = i + j  # 'u' is not defined within the inner loop
        for_mixed_globals_nonglobals += u
        i += 1
        j = 0
      return for_mixed_globals_nonglobals, i, j, n

    self.assertTransformedResult(f, constant_op.constant(5),
                                 (25, 5, 0, 5))

  def test_composite_state_complex(self):

    class TestClassX(object):

      def __init__(self, x):
        self.x = x

    class TestClassY(object):

      def __init__(self, y):
        self.y = y

    def f(n):
      tc = TestClassX(TestClassY({'z': TestClassX(n)}))
      if n > 0:
        while n > 0:
          if n < 2:
            tc.x.y['z'].x += 1
          n -= 1
      return n, tc

    tr = self.transform(f, control_flow)

    n, tc = tr(constant_op.constant(5))
    self.assertValuesEqual((n, tc.x.y['z'].x), (0, 6))


class WhileStatementTest(ControlFlowTestBase):
  """Tests for the conversion of `while` statements."""

  def test_basic(self):

    def f(n):
      i = 0
      s = 0
      while i < n:
        s += i
        i += 1
      return s, i, n

    self.assertTransformedResult(f, constant_op.constant(5), (10, 5, 5))

  def test_single_output(self):

    def f(n):
      while n > 0:
        n -= 1
      return n

    self.assertTransformedResult(f, constant_op.constant(5), 0)

  def test_composite_state_attr(self):

    class TestClass(object):

      def __init__(self):
        self.x = constant_op.constant(3)

    def f(n):
      tc = TestClass()
      while n > 0:
        tc.x += 1
        n -= 1
      return n

    self.assertTransformedResult(f, constant_op.constant(5), 0)

  def test_composite_state_slice(self):

    def f(n):
      d = {'a': n}
      k = 'a'
      while n > 0:
        d[k] += 1
        n -= 1
      return d[k], n

    self.assertTransformedResult(f, constant_op.constant(5), (10, 0))

  def test_composite_state_literal_slice(self):

    def f(n):
      d = {'a': n}
      while n > 0:
        d['a'] += 1
        n -= 1
      return d['a'], n

    self.assertTransformedResult(f, constant_op.constant(5), (10, 0))

  def test_composite_state_attr_initialized_in_loop(self):

    class TestClass(object):
      pass

    def f(n, x):
      tc = TestClass()
      while n < 5:
        if n == 0:
          tc.subattr = x
        else:
          tc.subattr = tc.subattr + 1
        n += 1
      return tc.subattr

    self.assertTransformedResult(f, (0, constant_op.constant(10)), 14)
    tr = self.transform(f, control_flow)
    with self.assertRaisesRegex(
        ValueError, "'tc.subattr' must be defined before the loop"):
      tr(constant_op.constant(0), 0)

  def test_composite_state_slice_initialized_in_loop(self):

    def f(n, x):
      d = {}
      k = 'subkey'
      while n < 5:
        if n == 0:
          d[k] = x
        else:
          d[k] = d[k] + 1
        n += 1
      return d

    self.assertTransformedResult(f, (0, constant_op.constant(10)),
                                 {'subkey': 14})
    tr = self.transform(f, control_flow)
    with self.assertRaisesRegex(
        ValueError, r"'d\[k\]' must be defined before the loop"):
      tr(constant_op.constant(0), 0)

  def test_composite_state_literal_slice_initialized_in_loop(self):

    def f(n, x):
      d = {}
      while n < 5:
        if n == 0:
          d['subkey'] = x
        else:
          d['subkey'] = d['subkey'] + 1
        n += 1
      return d

    self.assertTransformedResult(f, (0, constant_op.constant(10)),
                                 {'subkey': 14})
    tr = self.transform(f, control_flow)
    with self.assertRaisesRegex(
        ValueError, r"'d\['subkey'\]' must be defined before the loop"):
      tr(constant_op.constant(0), 0)

  def test_composite_state_slice_aliased_to_local(self):

    def f(n, x):
      d = {}
      while n < 5:
        k = 'subkey'
        d[k] = x + 1
        n += 1
      return d

    self.assertTransformedResult(f, (0, constant_op.constant(10)),
                                 {'subkey': 11})
    tr = self.transform(f, control_flow)
    # TODO(b/136999953): Better error message.
    # Note that this error happens at execution time.
    with self.assertRaises(errors.InaccessibleTensorError):
      graph_fn = def_function.function(tr, autograph=False)
      self.evaluate(
          graph_fn(constant_op.constant(0), constant_op.constant(5)))

  def test_local_composite_attr(self):

    class TestClass(object):

      def __init__(self):
        self.x = constant_op.constant(3)

    def f(n):
      while n > 0:
        tc = TestClass()
        tc.x = tc.x
        n -= 1
      return n

    self.assertTransformedResult(f, constant_op.constant(5), 0)

  def test_local_composite_slice(self):

    def f(n):
      while n > 0:
        d = {'x': n}
        k = 'x'
        d[k] = d[k]
        n -= 1
      return n

    self.assertTransformedResult(f, constant_op.constant(5), 0)

  def test_local_composite_literal_slice(self):

    def f(n):
      while n > 0:
        d = {'x': n}
        d['x'] = d['x']
        n -= 1
      return n

    self.assertTransformedResult(f, constant_op.constant(5), 0)

  def test_non_tensor_state(self):

    # This class is ok to be in a tf.while's state.
    # Note the trailing comma: the field names are a one-element tuple.
    class TestClass(collections.namedtuple('TestClass', ('x',))):
      pass

    def f(n):
      tc = TestClass([constant_op.constant(0)])
      while n > 0:
        tc = TestClass([constant_op.constant(3)])
        tc.x[0] = tc.x[0] + 1
        n -= 1
      return tc.x[0]

    self.assertTransformedResult(f, constant_op.constant(5), 4)

  def test_non_tensor_state_illegal_type(self):

    class TestClass(object):

      def __init__(self):
        self.x = [constant_op.constant(3)]

    def f(n):
      while n > 0:
        tc = TestClass()
        tc.x[0] = tc.x[0] + 1
        n -= 1
      return tc.x[0]

    tr = self.transform(f, control_flow)

    # The tested function would require `tc` to become part of the while loop
    # state, but TensorFlow doesn't support classes at the moment.
    with self.assertRaisesRegex(
        ValueError, 'tc.*must be defined before the loop'):
      tr(constant_op.constant(5))

  def test_dispatches_by_cond_only(self):

    class TensorIncompatibleNumeric(object):
      """Works in arithmetic expression, but errors out with TF ops."""

      def __init__(self, val):
        self.val = val

      def __add__(self, other):
        return TensorIncompatibleNumeric(self.val + other)

    def f(n, s):
      while n > 0:
        n -= 1
        s += n
      return s

    self.assertTransformedResult(f, (constant_op.constant(5), 0), 10)
    tr = self.transform(f, control_flow)
    # n alone controls the staging. When the loop is not staged, Python
    # knows how to add the two objects. But when staged, tf.while will
    # not know how to deal with the TensorIncompatibleNumeric object.
    self.assertEqual(tr(5, TensorIncompatibleNumeric(0)).val, 10)
    with self.assertRaises(TypeError):
      tr(constant_op.constant(5), TensorIncompatibleNumeric(0))


class IfStatementTest(ControlFlowTestBase):
  """Tests for the conversion of `if` statements."""

  def test_basic(self):

    def f(n):
      a = 0
      b = 0
      if n > 0:
        a = -n
      else:
        b = 2 * n
      return a, b

    self.assertTransformedResult(f, constant_op.constant(1), (-1, 0))
    self.assertTransformedResult(f, constant_op.constant(-1), (0, -2))

  def test_sparse_tensor(self):

    def f(cond, a):
      if cond:
        a = -a
      return a

    st = sparse_tensor.SparseTensor(
        indices=((0,),), values=(0,), dense_shape=(1,))
    self.assertTransformedResult(f, (st, constant_op.constant(1)), -1)
    self.assertTransformedResult(f, (None, constant_op.constant(1)), 1)

  def test_complex_outputs(self):

    class TestClass(object):

      def __init__(self, a, b):
        self.a = a
        self.b = b

    def f(n, obj):
      obj.a = 0
      obj.b = 0
      if n > 0:
        obj.a = -n
      else:
        obj.b = 2 * n
      return obj

    tr = self.transform(f, control_flow)

    res_obj = tr(constant_op.constant(1), TestClass(0, 0))
    self.assertValuesEqual((res_obj.a, res_obj.b), (-1, 0))
    res_obj = tr(constant_op.constant(-1), TestClass(0, 0))
    self.assertValuesEqual((res_obj.a, res_obj.b), (0, -2))

  def test_single_output(self):

    def f(n):
      if n > 0:
        n = -n
      return n

    self.assertTransformedResult(f, constant_op.constant(1), -1)

  def test_unbalanced(self):

    def f(n):
      if n > 0:
        n = 3
      return n

    self.assertTransformedResult(f, constant_op.constant(2), 3)
    self.assertTransformedResult(f, constant_op.constant(-3), -3)

  def test_unbalanced_raising(self):

    def f(n):
      if n > 0:
        n = n + 1
        raise ValueError()
      return n

    self.assertTransformedResult(f, -3, -3)

    tr = self.transform(f, control_flow)

    with self.assertRaises(ValueError):
      tr(1)

  def test_local_var(self):

    def f(n):
      if n > 0:
        b = 4
        n = b + 1
      return n

    self.assertTransformedResult(f, constant_op.constant(1), 5)
    self.assertTransformedResult(f, constant_op.constant(-1), -1)

  def test_local_remains_local(self):
    # NOTE(review): this test is currently identical to test_local_var;
    # consider making it exercise a distinct scenario.

    def f(n):
      if n > 0:
        b = 4
        n = b + 1
      return n

    self.assertTransformedResult(f, constant_op.constant(1), 5)
    self.assertTransformedResult(f, constant_op.constant(-1), -1)

  def test_global_local(self):
    global for_test_global_local
    # Start from a known state instead of asserting on it, so the test does
    # not depend on whether it already ran in this process.
    for_test_global_local = None

    def f(n):
      if n > 0:
        global for_test_global_local
        if for_test_global_local is None:
          for_test_global_local = 1
        else:
          for_test_global_local += 1
      n += for_test_global_local
      return n

    tr = self.transform(f, control_flow)
    self.assertEqual(tr(1), 2)
    self.assertEqual(for_test_global_local, 1)

  def test_no_outputs(self):

    def f(n):
      if n > 0:
        b = 4  # pylint:disable=unused-variable
      return n

    self.assertTransformedResult(f, constant_op.constant(1), 1)
    self.assertTransformedResult(f, constant_op.constant(-1), -1)

  def test_created_outputs(self):

    def f(i):
      if i == 0:
        result = i - 1
      else:
        result = i + 1
      return result

    self.assertTransformedResult(f, 0, -1)
    self.assertTransformedResult(f, 1, 2)

  def test_created_loop_local_outputs(self):

    def f(n, x):
      for i in n:
        if i == 0:
          result = i - 1
        else:
          result = i + 1
        if result > 0:
          x += 1
      return x

    self.assertTransformedResult(f, (range(5), 10), 14)

  def test_created_loop_variable(self):

    def f(n, x):
      for i in n:
        if i == 0:
          result = i - 1
        if i > 0:  # Using the result from previous iteration.
          if result < 0:
            x += 1
      return x

    self.assertTransformedResult(f, (range(5), 10), 14)

  def test_unaffected_global(self):

    global for_unaffected_global
    for_unaffected_global = 3

    def f(i):
      global for_unaffected_global
      if i == 0:
        for_unaffected_global = i - 1
      return for_unaffected_global

    self.assertTransformedResult(f, 1, 3)
    self.assertTransformedResult(f, 0, -1)
    self.assertEqual(for_unaffected_global, -1)

  def test_unaffected_nonlocal(self):

    def f(i):
      def inner_fn():
        nonlocal n
        if i == 0:
          n = i - 1

      n = 3
      inner_fn()
      return n

    self.assertTransformedResult(f, 1, 3)
    self.assertTransformedResult(f, 0, -1)

  def test_output_defined_in_prior_except(self):

    def f(i):
      try:
        raise ValueError()
      except ValueError:
        x = 1
      if i == 0:
        x = i - 1
      return x

    self.assertTransformedResult(f, 1, 1)
    self.assertTransformedResult(f, 0, -1)

  def test_unbalanced_multiple_composites(self):

    class Foo(object):

      def __init__(self):
        self.b = 2
        self.c = 3

    def f(x, condition):

      z = 5
      if condition:
        x.b = 7
        x.c = 11
        z = 13

      return x.b, x.c, z

    self.assertTransformedResult(f, (Foo(), constant_op.constant(True)),
                                 (7, 11, 13))
    self.assertTransformedResult(f, (Foo(), constant_op.constant(False)),
                                 (2, 3, 5))

  def test_unbalanced_composite(self):

    class Foo(object):

      def __init__(self):
        self.b = 2

    def f(x, condition):

      z = 5
      if condition:
        x.b = 7
        z = 13

      return x.b, z

    self.assertTransformedResult(f, (Foo(), constant_op.constant(True)),
                                 (7, 13))
    self.assertTransformedResult(f, (Foo(), constant_op.constant(False)),
                                 (2, 5))


class ForStatementTest(ControlFlowTestBase):
  """Tests for the conversion of `for` statements."""

  def test_basic(self):

    def f(l):
      s1 = 0
      s2 = 0
      for e in l:
        s1 += e
        s2 += e * e
      return s1, s2

    self.assertTransformedResult(f, constant_op.constant([1, 3]), (4, 10))
    empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
    self.assertTransformedResult(f, empty_vector, (0, 0))

  def test_single_output(self):

    def f(l):
      s = 0
      for e in l:
        s += e
      return s

    self.assertTransformedResult(f, constant_op.constant([1, 3]), 4)
    empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
    self.assertTransformedResult(f, empty_vector, 0)

  def test_iterated_expression(self):

    eval_count = [0]

    def count_evals(x):
      eval_count[0] += 1
      return x

    def f(n):
      s = 0
      for e in count_evals(range(n)):
        s += e
      return s

    tr = self.transform(f, control_flow)

    self.assertEqual(tr(5), 10)
    # The iterated expression must be evaluated only once.
    self.assertEqual(eval_count[0], 1)

  def test_composite_state_initialized_in_loop(self):

    class TestClass(object):
      pass

    def f(n, x):
      tc = TestClass()
      for i in n:
        if i == 0:
          tc.x = x
        else:
          tc.x = tc.x + i
      return tc.x

    self.assertTransformedResult(f, (range(5), constant_op.constant(10)), 20)
    tr = self.transform(f, control_flow)

    with self.assertRaisesRegex(
        ValueError, "'tc.x' must be defined before the loop"):
      tr(constant_op.constant(list(range(5))), 0)

  def test_tuple_unpacking(self):

    def f(x_list):
      z = constant_op.constant(0)
      for i, x in enumerate(x_list):
        z = z + x + i
      return z

    # A list is not a tuple, so it is passed as the single argument.
    self.assertTransformedResult(f, [3, 3], 7)

  def test_with_comprehension_in_body(self):

    def f(l, n):
      s = constant_op.constant(list(range(n)))
      for _ in l:
        s += constant_op.constant([a for a in range(n)])
      return s

    self.assertTransformedResult(f, (constant_op.constant([1, 2, 3]), 5),
                                 np.array(range(5)) * 4)


class AdvancedControlFlowTest(ControlFlowTestBase):
  """Tests for `while`/`for` loops with `else` clauses, with and without break."""

  def assertTransformedEquivalent(self, f, *inputs):
    """Asserts `f` returns the same value before and after conversion.

    The conversion applies the break_statements, continue_statements and
    control_flow converters.

    Args:
      f: The Python function to convert.
      *inputs: Positional arguments passed to both the original and the
        converted function.
    """
    tr = self.transform(
        f, (break_statements, continue_statements, control_flow))
    self.assertEqual(f(*inputs), tr(*inputs))

  def test_while_with_else(self):

    def f(x):
      while x > 2:
        x /= 2
      else:
        x += 1
      return x

    self.assertTransformedEquivalent(f, 4)
    self.assertTransformedEquivalent(f, 2)

  def test_while_with_else_and_break(self):

    def f(cond1):
      x = 8
      while x > 2:
        x /= 2
        if cond1:
          break
      else:
        x += 1
      return x

    self.assertTransformedEquivalent(f, True)
    self.assertTransformedEquivalent(f, False)

  def test_for_with_else(self):

    def f(l):
      res = 0
      for x in l:
        res += x
      else:
        res += 1
      return res

    self.assertTransformedEquivalent(f, [])
    self.assertTransformedEquivalent(f, [1, 2])

  def test_for_with_else_and_break(self):

    def f(flag):
      l = [1, 2, 3]
      res = 0
      for x in l:
        res += x
        if flag:
          break
      else:
        res += 1
      return res

    self.assertTransformedEquivalent(f, True)
    self.assertTransformedEquivalent(f, False)


if __name__ == '__main__':
  test.main()