xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/converters/control_flow_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for control_flow module."""
16
17import collections
18
19import numpy as np
20
21from tensorflow.python.autograph.converters import break_statements
22from tensorflow.python.autograph.converters import continue_statements
23from tensorflow.python.autograph.converters import control_flow
24from tensorflow.python.autograph.core import converter_testing
25from tensorflow.python.eager import def_function
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import errors
29from tensorflow.python.framework import sparse_tensor
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.platform import test
32from tensorflow.python.util import nest
33
34
# Module-level globals that tests in this file rebind through `global`
# statements. All three start out as None.
(for_unaffected_global,
 for_mixed_globals_nonglobals,
 for_test_global_local) = (None, None, None)
38
39
class ControlFlowTestBase(converter_testing.TestCase):
  """Assertion helpers shared by the control_flow converter tests."""

  def assertValuesEqual(self, actual, expected):
    """Asserts that `actual` equals `expected` after evaluating TF values.

    Args:
      actual: a (possibly nested) structure; every leaf for which
        `tensor_util.is_tf_type` is true is replaced by `self.evaluate(leaf)`.
      expected: the structure to compare against with `assertAllEqual`.
    """

    def _materialize(leaf):
      if tensor_util.is_tf_type(leaf):
        return self.evaluate(leaf)
      return leaf

    self.assertAllEqual(nest.map_structure(_materialize, actual), expected)

  def assertTransformedResult(self, f, inputs, expected):
    """Converts `f` with `control_flow`, calls it and checks the result.

    Args:
      f: the Python function to convert.
      inputs: a tuple of positional arguments, or a single non-tuple argument.
      expected: the expected return value, compared via `assertValuesEqual`.
    """
    call_args = inputs if isinstance(inputs, tuple) else (inputs,)
    converted_fn = self.transform(f, control_flow)
    self.assertValuesEqual(converted_fn(*call_args), expected)
54
55
class NestedControlFlowTest(ControlFlowTestBase):
  """Tests for while loops and conditionals nested inside one another."""

  def test_basic(self):
    """Nested while loops sharing the counters `i` and `j`."""
    # Over i = 0..4 the value of `u` is 0, 4, 5, 6, 10, so s == 25.

    def f(n):
      i = 0
      j = 0
      s = 0
      while i < n:
        while j < i:
          j += 3
        u = i + j  # 'u' is not defined within the inner loop
        s += u
        i += 1
        j = 0
      return s, i, j, n

    self.assertTransformedResult(f, constant_op.constant(5),
                                 (25, 5, 0, 5))

  def test_mixed_globals_nonglobals(self):
    """Same nested loops as test_basic, with a module global accumulator."""

    def f(n):
      global for_mixed_globals_nonglobals
      i = 0
      j = 0
      for_mixed_globals_nonglobals = 0
      while i < n:
        while j < i:
          j += 3
        u = i + j  # 'u' is not defined within the inner loop
        for_mixed_globals_nonglobals += u
        i += 1
        j = 0
      return for_mixed_globals_nonglobals, i, j, n

    self.assertTransformedResult(f, constant_op.constant(5),
                                 (25, 5, 0, 5))

  def test_composite_state_complex(self):
    """A deep attribute/key path is modified inside nested if/while."""

    class TestClassX(object):

      def __init__(self, x):
        self.x = x

    class TestClassY(object):

      def __init__(self, y):
        self.y = y

    def f(n):
      tc = TestClassX(TestClassY({'z': TestClassX(n)}))
      if n > 0:
        while n > 0:
          if n < 2:
            tc.x.y['z'].x += 1
          n -= 1
      return n, tc

    tr = self.transform(f, control_flow)

    # tc.x.y['z'].x starts at n == 5 and is incremented once (when n == 1).
    n, tc = tr(constant_op.constant(5))
    self.assertValuesEqual((n, tc.x.y['z'].x), (0, 6))
120
121
class WhileStatementTest(ControlFlowTestBase):
  """Tests for converting `while` statements."""

  def test_basic(self):
    """Accumulates 0..n-1 in a while loop."""

    def f(n):
      i = 0
      s = 0
      while i < n:
        s += i
        i += 1
      return s, i, n

    self.assertTransformedResult(f, constant_op.constant(5), (10, 5, 5))

  def test_single_output(self):
    """The loop modifies a single variable."""

    def f(n):
      while n > 0:
        n -= 1
      return n

    self.assertTransformedResult(f, constant_op.constant(5), 0)

  def test_composite_state_attr(self):
    """An object attribute (`tc.x`) is modified inside the loop."""

    class TestClass(object):

      def __init__(self):
        self.x = constant_op.constant(3)

    def f(n):
      tc = TestClass()
      while n > 0:
        tc.x += 1
        n -= 1
      return n

    self.assertTransformedResult(f, constant_op.constant(5), 0)

  def test_composite_state_slice(self):
    """A dict entry addressed by a variable key is modified in the loop."""

    def f(n):
      d = {'a': n}
      k = 'a'
      while n > 0:
        d[k] += 1
        n -= 1
      return d[k], n

    self.assertTransformedResult(f, constant_op.constant(5), (10, 0))

  def test_composite_state_literal_slice(self):
    """A dict entry addressed by a literal key is modified in the loop."""

    def f(n):
      d = {'a': n}
      while n > 0:
        d['a'] += 1
        n -= 1
      return d['a'], n

    self.assertTransformedResult(f, constant_op.constant(5), (10, 0))

  def test_composite_state_attr_initialized_in_loop(self):
    """An attribute first assigned inside the loop body.

    With a Python int counter the loop is not staged and the result is
    computed normally; with a tensor counter the converter raises ValueError.
    """

    class TestClass(object):
      pass

    def f(n, x):
      tc = TestClass()
      while n < 5:
        if n == 0:
          tc.subattr = x
        else:
          tc.subattr = tc.subattr + 1
        n += 1
      return tc.subattr

    self.assertTransformedResult(f, (0, constant_op.constant(10)), 14)
    tr = self.transform(f, control_flow)
    with self.assertRaisesRegex(
        ValueError, "'tc.subattr' must be defined before the loop"):
      tr(constant_op.constant(0), 0)

  def test_composite_state_slice_initialized_in_loop(self):
    """A dict entry (variable key) first assigned inside the loop body."""

    def f(n, x):
      d = {}
      k = 'subkey'
      while n < 5:
        if n == 0:
          d[k] = x
        else:
          d[k] = d[k] + 1
        n += 1
      return d

    self.assertTransformedResult(f, (0, constant_op.constant(10)),
                                 {'subkey': 14})
    tr = self.transform(f, control_flow)
    with self.assertRaisesRegex(
        ValueError, r"'d\[k\]' must be defined before the loop"):
      tr(constant_op.constant(0), 0)

  def test_composite_state_literal_slice_initialized_in_loop(self):
    """A dict entry (literal key) first assigned inside the loop body."""

    def f(n, x):
      d = {}
      while n < 5:
        if n == 0:
          d['subkey'] = x
        else:
          d['subkey'] = d['subkey'] + 1
        n += 1
      return d

    self.assertTransformedResult(f, (0, constant_op.constant(10)),
                                 {'subkey': 14})
    tr = self.transform(f, control_flow)
    with self.assertRaisesRegex(
        ValueError, r"'d\['subkey'\]' must be defined before the loop"):
      tr(constant_op.constant(0), 0)

  def test_composite_state_slice_aliased_to_local(self):
    """The dict key is a loop-local variable; staging fails when executed."""

    def f(n, x):
      d = {}
      while n < 5:
        k = 'subkey'
        d[k] = x + 1
        n += 1
      return d

    self.assertTransformedResult(f, (0, constant_op.constant(10)),
                                 {'subkey': 11})
    tr = self.transform(f, control_flow)
    # TODO(b/136999953): Better error message.
    # Note that this error happens at execution time.
    with self.assertRaises(errors.InaccessibleTensorError):
      graph_fn = def_function.function(tr, autograph=False)
      self.evaluate(
          graph_fn(constant_op.constant(0), constant_op.constant(5)))

  def test_local_composite_attr(self):
    """An object created and modified only within the loop body."""

    class TestClass(object):

      def __init__(self):
        self.x = constant_op.constant(3)

    def f(n):
      while n > 0:
        tc = TestClass()
        tc.x = tc.x
        n -= 1
      return n

    self.assertTransformedResult(f, constant_op.constant(5), 0)

  def test_local_composite_slice(self):
    """A dict created and modified (variable key) only within the loop."""

    def f(n):
      while n > 0:
        d = {'x': n}
        k = 'x'
        d[k] = d[k]
        n -= 1
      return n

    self.assertTransformedResult(f, constant_op.constant(5), 0)

  def test_local_composite_literal_slice(self):
    """A dict created and modified (literal key) only within the loop."""

    def f(n):
      while n > 0:
        d = {'x': n}
        d['x'] = d['x']
        n -= 1
      return n

    self.assertTransformedResult(f, constant_op.constant(5), 0)

  def test_non_tensor_state(self):
    """A namedtuple holding a list of tensors is carried as loop state."""

    # This class is ok to be in a tf.while's state.
    # Note the trailing comma: `('x')` would be the plain string 'x', not a
    # one-element tuple of field names.
    class TestClass(collections.namedtuple('TestClass', ('x',))):
      pass

    def f(n):
      tc = TestClass([constant_op.constant(0)])
      while n > 0:
        tc = TestClass([constant_op.constant(3)])
        tc.x[0] = tc.x[0] + 1
        n -= 1
      return tc.x[0]

    self.assertTransformedResult(f, constant_op.constant(5), 4)

  def test_non_tensor_state_illegal_type(self):
    """A plain class instance created in the loop cannot be loop state."""

    class TestClass(object):

      def __init__(self):
        self.x = [constant_op.constant(3)]

    def f(n):
      while n > 0:
        tc = TestClass()
        tc.x[0] = tc.x[0] + 1
        n -= 1
      return tc.x[0]

    tr = self.transform(f, control_flow)

    # The tested function would require `tc` to become part of the while loop
    # state, but TensorFlow doesn't support classes at the moment.
    with self.assertRaisesRegex(
        ValueError, 'tc.*must be defined before the loop'):
      tr(constant_op.constant(5))

  def test_dispatches_by_cond_only(self):
    """Whether the loop is staged depends only on the loop condition."""

    class TensorIncompatibleNumeric(object):
      """Works in arithmetic expression, but errors out with TF ops."""

      def __init__(self, val):
        self.val = val

      def __add__(self, other):
        return TensorIncompatibleNumeric(self.val + other)

    def f(n, s):
      while n > 0:
        n -= 1
        s += n
      return s

    self.assertTransformedResult(f, (constant_op.constant(5), 0), 10)
    tr = self.transform(f, control_flow)
    # n alone controls the staging. When the loop is not staged, Python
    # knows how to add the two objects. But when staged, tf.while will
    # not know how to deal with the TensorIncompatibleNumeric object.
    self.assertEqual(tr(5, TensorIncompatibleNumeric(0)).val, 10)
    with self.assertRaises(TypeError):
      tr(constant_op.constant(5), TensorIncompatibleNumeric(0))
366
367
class IfStatementTest(ControlFlowTestBase):
  """Tests for converting `if` statements."""

  def test_basic(self):
    """Each branch modifies a different variable."""

    def f(n):
      a = 0
      b = 0
      if n > 0:
        a = -n
      else:
        b = 2 * n
      return a, b

    self.assertTransformedResult(f, constant_op.constant(1), (-1, 0))
    self.assertTransformedResult(f, constant_op.constant(-1), (0, -2))

  def test_sparse_tensor(self):
    """The condition is a SparseTensor or None rather than a boolean."""

    def f(cond, a):
      if cond:
        a = -a
      return a

    st = sparse_tensor.SparseTensor(
        indices=((0,),), values=(0,), dense_shape=(1,))
    self.assertTransformedResult(f, (st, constant_op.constant(1)), -1)
    self.assertTransformedResult(f, (None, constant_op.constant(1)), 1)

  def test_complex_outputs(self):
    """The branches modify attributes of an argument object."""

    class TestClass(object):

      def __init__(self, a, b):
        self.a = a
        self.b = b

    def f(n, obj):
      obj.a = 0
      obj.b = 0
      if n > 0:
        obj.a = -n
      else:
        obj.b = 2 * n
      return obj

    tr = self.transform(f, control_flow)

    res_obj = tr(constant_op.constant(1), TestClass(0, 0))
    self.assertValuesEqual((res_obj.a, res_obj.b), (-1, 0))
    res_obj = tr(constant_op.constant(-1), TestClass(0, 0))
    self.assertValuesEqual((res_obj.a, res_obj.b), (0, -2))

  def test_single_output(self):
    """The branch modifies a single variable."""

    def f(n):
      if n > 0:
        n = -n
      return n

    self.assertTransformedResult(f, constant_op.constant(1), -1)

  def test_unbalanced(self):
    """`n` is assigned in the `if` branch only (there is no `else`)."""

    def f(n):
      if n > 0:
        n = 3
      return n

    self.assertTransformedResult(f, constant_op.constant(2), 3)
    self.assertTransformedResult(f, constant_op.constant(-3), -3)

  def test_unbalanced_raising(self):
    """The `if` branch raises; with a Python int the raise propagates."""

    def f(n):
      if n > 0:
        n = n + 1
        raise ValueError()
      return n

    self.assertTransformedResult(f, -3, -3)

    tr = self.transform(f, control_flow)

    with self.assertRaises(ValueError):
      tr(1)

  def test_local_var(self):
    """`b` is assigned only inside the branch and used there."""

    def f(n):
      if n > 0:
        b = 4
        n = b + 1
      return n

    self.assertTransformedResult(f, constant_op.constant(1), 5)
    self.assertTransformedResult(f, constant_op.constant(-1), -1)

  def test_local_remains_local(self):
    """A branch-local variable does not need to be defined elsewhere."""
    # NOTE(review): this body is identical to test_local_var; consider
    # making it exercise something distinct.

    def f(n):
      if n > 0:
        b = 4
        n = b + 1
      return n

    self.assertTransformedResult(f, constant_op.constant(1), 5)
    self.assertTransformedResult(f, constant_op.constant(-1), -1)

  def test_global_local(self):
    """A module global is declared and modified inside the branch."""

    def f(n):
      if n > 0:
        global for_test_global_local
        if for_test_global_local is None:
          for_test_global_local = 1
        else:
          for_test_global_local += 1
        n += for_test_global_local
      return n

    def reset_global():
      global for_test_global_local
      for_test_global_local = None

    # Restore the module global afterwards; otherwise a second run of this
    # test in the same process fails the None precondition below.
    self.addCleanup(reset_global)

    tr = self.transform(f, control_flow)
    # Use a TestCase assertion: a bare `assert` is stripped under `python -O`.
    self.assertIsNone(for_test_global_local)
    self.assertEqual(tr(1), 2)
    self.assertEqual(for_test_global_local, 1)

  def test_no_outputs(self):
    """The branch only assigns a variable that is never read afterwards."""

    def f(n):
      if n > 0:
        b = 4  # pylint:disable=unused-variable
      return n

    self.assertTransformedResult(f, constant_op.constant(1), 1)
    self.assertTransformedResult(f, constant_op.constant(-1), -1)

  def test_created_outputs(self):
    """`result` is first defined by the branches, not before the `if`."""

    def f(i):
      if i == 0:
        result = i - 1
      else:
        result = i + 1
      return result

    self.assertTransformedResult(f, 0, -1)
    self.assertTransformedResult(f, 1, 2)

  def test_created_loop_local_outputs(self):
    """Like test_created_outputs, inside the body of a for loop."""

    def f(n, x):
      for i in n:
        if i == 0:
          result = i - 1
        else:
          result = i + 1
        if result > 0:
          x += 1
      return x

    self.assertTransformedResult(f, (range(5), 10), 14)

  def test_created_loop_variable(self):
    """`result` is set on the first iteration and read on later ones."""

    def f(n, x):
      for i in n:
        if i == 0:
          result = i - 1
        if i > 0:  # Using the result from previous iteration.
          if result < 0:
            x += 1
      return x

    self.assertTransformedResult(f, (range(5), 10), 14)

  def test_unaffected_global(self):
    """A global is assigned in the branch only when the condition holds."""

    global for_unaffected_global
    for_unaffected_global = 3

    def f(i):
      global for_unaffected_global
      if i == 0:
        for_unaffected_global = i - 1
      return for_unaffected_global

    self.assertTransformedResult(f, 1, 3)
    self.assertTransformedResult(f, 0, -1)
    self.assertEqual(for_unaffected_global, -1)

  def test_unaffected_nonlocal(self):
    """A nonlocal is assigned in the branch only when the condition holds."""

    def f(i):
      def inner_fn():
        nonlocal n
        if i == 0:
          n = i - 1

      n = 3
      inner_fn()
      return n

    self.assertTransformedResult(f, 1, 3)
    self.assertTransformedResult(f, 0, -1)

  def test_output_defined_in_prior_except(self):
    """The branch output was first defined in an earlier `except` block."""

    def f(i):
      try:
        raise ValueError()
      except ValueError:
        x = 1
      if i == 0:
        x = i - 1
      return x

    self.assertTransformedResult(f, 1, 1)
    self.assertTransformedResult(f, 0, -1)

  def test_unbalanced_multiple_composites(self):
    """Several attributes plus a local are set only in the `if` branch."""

    class Foo(object):

      def __init__(self):
        self.b = 2
        self.c = 3

    def f(x, condition):

      z = 5
      if condition:
        x.b = 7
        x.c = 11
        z = 13

      return x.b, x.c, z

    self.assertTransformedResult(f, (Foo(), constant_op.constant(True)),
                                 (7, 11, 13))
    self.assertTransformedResult(f, (Foo(), constant_op.constant(False)),
                                 (2, 3, 5))

  def test_unbalanced_composite(self):
    """One attribute plus a local are set only in the `if` branch."""

    class Foo(object):

      def __init__(self):
        self.b = 2

    def f(x, condition):

      z = 5
      if condition:
        x.b = 7
        z = 13

      return x.b, z

    self.assertTransformedResult(f, (Foo(), constant_op.constant(True)),
                                 (7, 13))
    self.assertTransformedResult(f, (Foo(), constant_op.constant(False)),
                                 (2, 5))
629
630
class ForStatementTest(ControlFlowTestBase):
  """Tests for converting `for` statements."""

  def test_basic(self):
    """Two accumulators over a tensor, including an empty tensor."""

    def f(l):
      s1 = 0
      s2 = 0
      for e in l:
        s1 += e
        s2 += e * e
      return s1, s2

    self.assertTransformedResult(f, constant_op.constant([1, 3]), (4, 10))
    empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
    self.assertTransformedResult(f, empty_vector, (0, 0))

  def test_single_output(self):
    """A single accumulator over a tensor, including an empty tensor."""

    def f(l):
      s = 0
      for e in l:
        s += e
      return s

    self.assertTransformedResult(f, constant_op.constant([1, 3]), 4)
    empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
    self.assertTransformedResult(f, empty_vector, 0)

  def test_iterated_expression(self):
    """The iterated expression is evaluated exactly once."""

    eval_count = [0]

    def count_evals(x):
      eval_count[0] += 1
      return x

    def f(n):
      s = 0
      for e in count_evals(range(n)):
        s += e
      return s

    tr = self.transform(f, control_flow)

    self.assertEqual(tr(5), 10)
    self.assertEqual(eval_count[0], 1)

  def test_composite_state_initialized_in_loop(self):
    """An attribute first assigned in the loop; staging raises ValueError."""

    class TestClass(object):
      pass

    def f(n, x):
      tc = TestClass()
      for i in n:
        if i == 0:
          tc.x = x
        else:
          tc.x = tc.x + i
      return tc.x

    self.assertTransformedResult(f, (range(5), constant_op.constant(10)), 20)
    tr = self.transform(f, control_flow)

    with self.assertRaisesRegex(
        ValueError, "'tc.x' must be defined before the loop"):
      tr(constant_op.constant(list(range(5))), 0)

  def test_tuple_unpacking(self):
    """The loop target is a tuple unpacked from `enumerate`."""

    def f(x_list):
      z = constant_op.constant(0)
      for i, x in enumerate(x_list):
        z = z + x + i
      return z

    self.assertTransformedResult(f, [3, 3], 7)

  def test_with_comprehension_in_body(self):
    """The loop body contains a list comprehension."""

    def f(l, n):
      s = constant_op.constant(list(range(n)))
      for _ in l:
        s += constant_op.constant([a for a in range(n)])
      return s

    self.assertTransformedResult(f, (constant_op.constant([1, 2, 3]), 5),
                                 np.array(range(5)) * 4)
719
720
class AdvancedControlFlowTest(ControlFlowTestBase):
  """Tests combining loops with `else` clauses and `break`."""

  def assertTransformedEquivalent(self, f, *inputs):
    """Asserts the converted `f` returns the same value as plain `f`.

    `f` is converted with break_statements, continue_statements and
    control_flow, and both versions are called with `*inputs`.
    """
    tr = self.transform(
        f, (break_statements, continue_statements, control_flow))
    self.assertEqual(f(*inputs), tr(*inputs))

  def test_while_with_else(self):
    """A while loop with an `else` clause and no `break`."""

    def f(x):
      while x > 2:
        x /= 2
      else:
        x += 1
      return x

    self.assertTransformedEquivalent(f, 4)
    self.assertTransformedEquivalent(f, 2)

  def test_while_with_else_and_break(self):
    """A while/else whose body may `break`, skipping the `else` clause."""

    def f(cond1):
      x = 8
      while x > 2:
        x /= 2
        if cond1:
          break
      else:
        x += 1
      return x

    self.assertTransformedEquivalent(f, True)
    self.assertTransformedEquivalent(f, False)

  def test_for_with_else(self):
    """A for loop with an `else` clause, over empty and non-empty lists."""

    def f(l):
      res = 0
      for x in l:
        res += x
      else:
        res += 1
      return res

    self.assertTransformedEquivalent(f, [])
    self.assertTransformedEquivalent(f, [1, 2])

  def test_for_with_else_and_break(self):
    """A for/else whose body may `break`, skipping the `else` clause."""

    def f(flag):
      l = [1, 2, 3]
      res = 0
      for x in l:
        res += x
        if flag:
          break
      else:
        res += 1
      return res

    self.assertTransformedEquivalent(f, True)
    self.assertTransformedEquivalent(f, False)
783
784
if __name__ == '__main__':
  # Entry point when executed as a script: hand off to TensorFlow's test
  # runner (tensorflow.python.platform.test).
  test.main()
787