# Owner(s): ["module: dynamo"]
from unittest.mock import patch

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import unsupported
from torch._dynamo.utils import ifdynstaticdefault


globalmod = torch.nn.ReLU()


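# `unsupported` (from torch._dynamo.testing) is deliberately opaque to Dynamo,
# so every call to it forces a graph break.  Wrapping it in this helper adds
# one more Python frame, which is why the "indirect" tests below expect an
# extra compiled frame compared to their direct counterparts.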
def indirectly_unsupported(a, b):
    c = a + b
    return unsupported(a, c)


class SubGraphTests(torch._dynamo.test_case.TestCase):
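    # Shared harness: run `fn` eagerly and under Dynamo with a CompileCounter
    # backend, check the outputs match, and assert that exactly `frame_count`
    # graphs were compiled containing `op_count` FX ops in total.  Each graph
    # break yields a resume frame; frames that produce no tensor ops typically
    # aren't sent to the backend and so don't count.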
    def _common(self, fn, frame_count, op_count):
        torch._dynamo.reset()
        v1 = torch.ones(10)
        v2 = torch.ones(10) * -2.0
        correct1 = fn(v1, v2)
        correct2 = fn(v2, v1)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        r1 = opt_fn(v1, v2)
        r2 = opt_fn(v2, v1)
        self.assertTrue(torch._dynamo.testing.same(r1, correct1))
        self.assertTrue(torch._dynamo.testing.same(r2, correct2))
        self.assertEqual(
            cnt.frame_count,
            frame_count,
            f"actual {cnt.frame_count} != expected {frame_count}",
        )
        self.assertEqual(cnt.op_count, op_count)

    def test_control_flow1(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            if c1.sum() > c2.sum():
                return c1
            else:
                return c2

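        # One graph (sub, sub, sum, sum, gt): the data-dependent `if` ends
        # tracing, and neither branch body contains tensor ops, so no further
        # frames are compiled.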
        self._common(fn, 1, 5)

    def test_control_flow2(self):
        def fn(a, b):
            if a.sum() > b.sum():
                return 1
            else:
                return 2

        self._common(fn, 1, 3)

    def test_control_flow3(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            m = globalmod
            if c1.sum() > c2.sum():
                return m(c1)
            else:
                return m(c2)

        self._common(fn, 3, 7)

    def test_control_flow4(self):
        def fn(a, b):
            tmp1 = a.sum() > b.sum() and a.sum() > 0
            if tmp1:
                return 1
            else:
                return 2

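        # `and` between tensors needs a bool() conversion at runtime, so the
        # short-circuit point is an extra graph break on top of the
        # data-dependent `if`.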
        self._common(fn, 3, 5)

    def test_control_flow5(self):
        def fn(a, b):
            tmp1 = a.sum() > b.sum() and a.sum() > 0
            tmp2 = a.sum() < b.sum() or b.sum() > 0
            if tmp1 and tmp2:
                return 1, tmp1, tmp2
            else:
                return 2, tmp1, tmp2

        self._common(fn, 6, 13)

    def test_capi_call1(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            return unsupported(c1, c2)

        self._common(fn, 1, 2)

    def test_capi_call2(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            return a - (b - unsupported(c1, c2))

        self._common(fn, 2, 4)

    def test_capi_call3(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            return torch._dynamo.testing.unsupported(c1, c2)

        self._common(fn, 1, 2)

    def test_indirect_unsupported1(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            return indirectly_unsupported(c1, c2)

        self._common(fn, 2, 3)

    def test_indirect_unsupported2(self):
        def fn(a, b):
            local_const1 = 7
            local_const2 = 22
            c1 = a - b
            c2 = b - a
            return local_const1 / (local_const2 - indirectly_unsupported(c1, c2))

        self._common(fn, 3, 5)

    def test_indirect_unsupported3(self):
        def fn(a, b):
            args = [a - b, b - a]
            return indirectly_unsupported(*args)

        self._common(fn, 2, 3)

    def test_stack_state1(self):
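        # `t1` and `t2` are still live (on the value stack, mid-expression)
        # when the graph break inside the return expression hits; Dynamo must
        # save that state and restore it in the generated resume function.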
        def fn(a, b):
            t1 = 1.23 * a
            t2 = 4.56 * a
            c1 = a - b
            c2 = b - a
            return t1 / (t2 - unsupported(c1, c2))

        self._common(fn, 2, 6)

    def test_stack_state2(self):
        def fn(a, b):
            t1 = 1.23 * a
            t2 = 4.56 * a
            c1 = a - b
            c2 = b - a
            return t1 / (t2 - indirectly_unsupported(c1, c2))

        self._common(fn, 3, 7)

    def test_multigraph(self):
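        # Two graphs: the prefix (add, div, sum, lt) and the `x * -1.0`
        # branch.  The fall-through `return x` has no ops, so it does not
        # produce a third graph.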
        def fn(a, b):
            x = a + b
            x = x / 2.0
            if x.sum() < 0:
                return x * -1.0
            return x

        self._common(fn, 2, 5)

    def test_extended_args(self):
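        # The 512 operands joined by `+` make the generated bytecode long
        # enough that arguments and jump targets exceed one byte, exercising
        # EXTENDED_ARG handling in Dynamo's bytecode rewriting.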
        too_many_adds = "+".join(["a", "b"] * 256)
        source = (
            f"lambda a, b: ({too_many_adds}+a if a.sum() > 0 else {too_many_adds} - b)"
        )
        self._common(eval(source), 3, 1026)

    def test_resume1(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            x = unsupported(x, a)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 2, 6)

    def test_resume2(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            x = indirectly_unsupported(x, a)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 3, 7)

    def test_resume3(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            x = indirectly_unsupported(x, b=a)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 3, 7)

    def test_resume4(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            x = indirectly_unsupported(a=x, b=a)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 3, 7)

    def test_resume5(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            print(x)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 2, 6)

    def test_start1(self):
        def fn(a, b):
            print(a)
            x = a + b
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 1, 3)

    def test_start2(self):
        def fn(a, b):
            x = indirectly_unsupported(a, b)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 2, 4)

    def test_start3(self):
        def fn(a, b):
            x = unsupported(a, b)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 1, 3)

    def test_start4(self):
        def fn(a, b, check):
            if check:
                return a + b + 10
            else:
                return a + b - 10

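        # `check` is a tensor, so `if check:` is data-dependent and breaks the
        # graph right at function entry; each truth value then compiles its
        # own path.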
        v1 = torch.randn(10)
        v2 = torch.randn(10)
        f = torch.zeros(1, dtype=torch.int32)
        t = torch.ones(1, dtype=torch.int32)
        correct1 = fn(v1, v2, t)
        correct2 = fn(v1, v2, f)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        r1 = opt_fn(v1, v2, t)
        r2 = opt_fn(v1, v2, f)
        self.assertTrue(torch._dynamo.testing.same(r1, correct1))
        self.assertTrue(torch._dynamo.testing.same(r2, correct2))
        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 4)

    def test_resume_freevars(self):
        c1 = torch.randn(10)
        c2 = torch.randn(10)

        def fn(a, b):
            x = a + b + (c1 - c2)
            x = unsupported(x, x)
            return x + (c1 - c2)

        self._common(fn, 2, 5)

    def test_restore_state(self):
        def fn(a, b):
            len_ = len
            x = a + b
            x = torch.add(unsupported(x, x), 1)
            return a * x + len_(b)

        self._common(fn, 2, 4)

    def test_restore_range(self):
        def fn(a, b):
            x = a + b
            rng = range(3, 8, 2)
            x = unsupported(x, x)
            for i in rng:
                x = x + i
            return x

        # We don't specialize on range with dynamic shapes, which
        # means we fail to unroll the loop.
        # TODO: Consider forcing specialization when we iterate over
        # the loop
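        # (ifdynstaticdefault(static, dynamic) picks the first value when
        # assume_static_by_default is set and the second otherwise.)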
        self._common(fn, ifdynstaticdefault(2, 1), ifdynstaticdefault(4, 1))

    def test_restore_range_iter(self):
        def fn(a, b):
            x = a + b
            rng = iter(range(3, 8, 2))
            x = unsupported(x, x)
            x += next(rng)
            return x, list(rng)

        self._common(fn, 2, 2)

    def test_pop_after_resume(self):
        def fn(a, b):
            tmp = [a + 1, b + 2, a + b]
            x = a
            x = unsupported(x, x)
            for i in range(3):
                x += tmp.pop(-1)
            return x

        self._common(fn, 2, 6)

    @patch("torch._dynamo.config.assume_static_by_default", False)
    def test_dynamic_getitem(self):
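        # With assume_static_by_default off, sizes are symbolic from the first
        # compile, so all nine input lengths share a single graph.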
        def fn(a, b):
            return a[b.size(0) - 1]

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        for i in range(3, 12):
            opt_fn(torch.randn(i), torch.randn(i))
        # just one graph
        self.assertEqual(cnt.frame_count, 1)

    def test_dynamic_kwarg(self):
        def fn(a, b):
            return a - b * 10

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
        start = 2
        end = 12
        steps = end - start
        for i in range(start, end):
            opt_fn(torch.randn(i), torch.randn(i))

        self.assertEqual(cnt_dynamic.frame_count, 1)

    def test_dynamic_duck_size(self):
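        # "Duck sizing": inputs that first appear with equal sizes share one
        # symbolic dimension, which specializes the `==` branch; calling with
        # unequal sizes then needs a second graph.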
        def fn(a, b):
            if a.size(0) == b.size(0):
                return a + b
            else:
                return a.sum() + b.sum()

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
        x = torch.randn(2)
        y = torch.randn(3)
        self.assertEqual(opt_fn(x, x), fn(x, x))
        self.assertEqual(opt_fn(x, y), fn(x, y))
        self.assertEqual(cnt_dynamic.frame_count, 2)

    def test_dynamic_order_dependence(self):
        def fn(a, b):
            return a.sum() + b.sum()

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
        x = torch.randn(2)
        y = torch.randn(3)
        self.assertEqual(opt_fn(x, y), fn(x, y))
        self.assertEqual(opt_fn(x, x), fn(x, x))
        # NB: This COULD validly be 2, but we don't test disjointness in the
        # guards for when x and y didn't duck size together, so we end up
        # with a generic graph that also works when x and y happen to duck
        # size together.
        self.assertEqual(cnt_dynamic.frame_count, 2)

        torch._dynamo.reset()
        cnt_dynamic.frame_count = 0
        self.assertEqual(opt_fn(x, x), fn(x, x))  # this overspecializes!
        self.assertEqual(opt_fn(x, y), fn(x, y))
        self.assertEqual(cnt_dynamic.frame_count, 2)

    def test_dynamic_zero_inference(self):
        def fn(a):
            if a.size(0) != 0:
                return a * 2
            else:
                return a + 1

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
        x = torch.randn(0)
        y = torch.randn(2)
        self.assertEqual(opt_fn(y), fn(y))
        self.assertEqual(opt_fn(x), fn(x))
        self.assertEqual(cnt_dynamic.frame_count, 2)

    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_no_graph_break_on_item(self):
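        # With capture_scalar_outputs=True, `x.item()` is traced into the
        # graph instead of breaking it; since its result is unused, it is
        # dead-code-eliminated, leaving one graph of five ops.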
        def fn(a, b):
            x = a + b - 1.5
            x = x.sum()
            x.item()
            x = x / (a + b)
            return x

        self._common(fn, 1, 5)  # item gets DCE'd

    @patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
    def test_graph_break_on_item(self):
        def fn(a, b):
            x = a + b - 1.5
            x = x.sum()
            x.item()
            x = x / (a + b)
            return x

        self._common(fn, 2, 5)

    def test_resume_paths_join(self):
        def fn(x, c1, c2, c3):
            x = x + 1
            if c1:
                x = x + 2
            x = x + 3
            if c2:
                x = x + 4
            x = x + 5
            if c3:
                x = x + 6
            return x + 7

        v1 = torch.randn(10)
        t = torch.Tensor([True])
        f = torch.Tensor([False])
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        for a in (t, f):
            for b in (t, f):
                for c in (t, f):
                    opt_fn(v1, a, b, c)

        # Check that we don't create 2^n graphs: naive path specialization
        # would trace up to 2**3 = 8 end-to-end variants here, but resume
        # functions let the paths re-join, so we expect the entry frame plus
        # two resume frames per branch (1 + 2 * 3 = 7).
        self.assertEqual(cnt.frame_count, 7)
        self.assertEqual(cnt.op_count, 10)

    def test_resume_with_no_grad1(self):
        def fn(a, b):
            x = a + b
            with torch.no_grad():
                x = x + 1
                x.sum().tolist()  # graph break
                x = x + 2
            x = x + 3
            return x

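        # Outside no_grad, the graphs also contain _set_grad_enabled toggles
        # for entering/leaving the context; when the whole call already runs
        # under no_grad those toggles vanish, dropping op_count from 9 to 5.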
        self._common(fn, 2, 9)
        torch._dynamo.reset()
        with torch.no_grad():
            self._common(fn, 2, 5)

    def test_resume_with_no_grad2(self):
        def fn(a, b):
            x = a + b
            with torch.no_grad():
                x = x + 1
                x.sum().tolist()  # graph break
                x = x + 2
                x.sum().tolist()  # graph break
                x = x + 3
            x = x + 4
            return x

        self._common(fn, 3, 13)

    def test_resume_with_no_grad3(self):
        def fn(a, b):
            x = a + b
            with torch.no_grad():
                with torch.no_grad():
                    x = x + 1
                    with torch.enable_grad():
                        x.sum().tolist()  # graph break
                        x = x[0] + 2
                    x = x + 3
            x = x + 4
            return x

        self._common(fn, 2, 11)

    def test_resume_tuple_iterator(self):
        def fn(a, b):
            x = a + b
            it = iter(tuple(range(10)))
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            x = unsupported(x, x)
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            return x

        self._common(fn, 2, 8)

    def test_tuple_iterator_return(self):
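        # Two graph breaks split the frame in three; at each break Dynamo must
        # reconstruct the partially-consumed tuple iterator so the resumed
        # code and the returned iterator see the right remaining items.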
        def fn(x):
            it = iter(tuple(range(10)))
            x = x + next(it)
            x = x + next(it)
            x = unsupported(x, x)
            x = x + next(it)
            x = x + next(it)
            x = unsupported(x, x)
            x = x + next(it)
            x = x + next(it)
            return x, it

        v1 = torch.randn(10)
        v2, it2 = fn(v1)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        v3, it3 = opt_fn(v1)
        v4, it4 = opt_fn(v1)
        self.assertEqual(v2.tolist(), v3.tolist())
        self.assertEqual(v2.tolist(), v4.tolist())
        self.assertEqual(list(it2), list(it3))
        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 6)

    def test_tuple_iterator_mutate(self):
        def fn(x, it):
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            return x

        v1 = torch.randn(10)
        it1 = iter(tuple(range(10)))
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        self.assertEqual(opt_fn(v1, it1).tolist(), (v1 + 1 + 2 + 3).tolist())
        self.assertEqual(list(it1), [4, 5, 6, 7, 8, 9])

    def test_enumerate_not_break_graph(self):
        def fn(a, b):
            for i, x in enumerate(a.shape):
                b = b + x
            for i, x in enumerate(b.shape, 8):
                b = b + x * i
            return b

        self._common(fn, 1, ifdynstaticdefault(2, 3))


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()