xref: /aosp_15_r20/external/pytorch/test/inductor/test_cudagraph_trees.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: inductor"]
2import contextlib
3import functools
4import gc
5import importlib
6import sys
7import unittest
8import warnings
9from unittest import mock
10
11import torch
12import torch._dynamo.config as dynamo_config
13import torch.nn as nn
14from torch._dynamo.utils import counters
15from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
16from torch._inductor import config
17from torch._inductor.codecache import FxGraphCache
18from torch._inductor.compile_fx import compile_fx_inner
19from torch._inductor.cudagraph_trees import cudagraphify_impl as tree_cudagraphify_impl
20from torch._inductor.cudagraph_utils import FunctionID
21from torch._inductor.test_case import TestCase as InductorTestCase
22from torch.fx.experimental.proxy_tensor import make_fx
23from torch.testing import FileCheck
24from torch.testing._internal.common_cuda import TEST_MULTIGPU
25from torch.testing._internal.common_utils import (
26    instantiate_parametrized_tests,
27    IS_CI,
28    IS_LINUX,
29    IS_WINDOWS,
30    parametrize,
31    skipIfRocm,
32    TEST_CUDA_GRAPH,
33    TEST_WITH_ASAN,
34)
35from torch.utils._python_dispatch import TorchDispatchMode
36
37
38if IS_WINDOWS and IS_CI:
39    sys.stderr.write(
40        "Windows CI does not have necessary dependencies for test_torchinductor yet\n"
41    )
42    if __name__ == "__main__":
43        sys.exit(0)
44    raise unittest.SkipTest("requires sympy/functorch/filelock")
45
46importlib.import_module("functorch")
47importlib.import_module("filelock")
48
49from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
50
51
52aten = torch.ops.aten
53requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
54requires_multigpu = functools.partial(
55    unittest.skipIf, not TEST_MULTIGPU, "requires multiple cuda devices"
56)
57from io import StringIO
58
59
60def get_compile_fn(backend):
61    if backend == "cudagraphs":
62        return functools.partial(torch.compile, backend="cudagraphs")
63    else:
64        return functools.partial(torch.compile, mode="reduce-overhead")
65
66
67class capture_stderr(list):
68    """
69    Replace sys.stderr with a temporary StringIO
70    """
71
72    def __enter__(self):
73        self.sys_stderr = sys.stderr
74        self.stringio = StringIO()
75        sys.stderr = self.stringio
76        return self
77
78    def __exit__(self, *args):
79        self.append(str(self.stringio.getvalue()))
80        del self.stringio
81        sys.stderr = self.sys_stderr
82
83
84def cdata(t):
85    return t.untyped_storage()._cdata
86
87
88class TestCase(InductorTestCase):
89    @classmethod
90    def setUpClass(cls):
91        super().setUpClass()
92        cls._stack = contextlib.ExitStack()
93        cls._stack.enter_context(
94            config.patch(
95                {
96                    "debug": True,
97                    "cpp.min_chunk_size": 1,
98                    "triton.autotune_pointwise": False,  # too slow
99                    "implicit_fallbacks": False,
100                }
101            )
102        )
103
104    @classmethod
105    def tearDownClass(cls):
106        cls._stack.close()
107        super().tearDownClass()
108
109    def setUp(self):
110        torch._dynamo.reset()
111        super().setUp()
112
113    def tearDown(self):
114        super().tearDown()
115        torch._dynamo.reset()
116
117
118if HAS_CUDA and not TEST_WITH_ASAN:
119
120    def get_all_cudagraph_segments():
121        segments = torch.cuda.memory_snapshot()
122        return [segment for segment in segments if segment["segment_pool_id"] != (0, 0)]
123
124    def all_live_blocks():
125        blocks_addrs = []
126        for segment in get_all_cudagraph_segments():
127            addr = segment["address"]
128            for block in segment["blocks"]:
129                if block["state"] == "active_allocated":
130                    blocks_addrs.append(addr)
131                addr += block["size"]
132
133        return blocks_addrs
134
135    def all_live_block_count():
136        return len(all_live_blocks())
137
138    class CudaGraphTreeTests(TestCase):
139        def setUp(self):
140            super().setUp()
141            self.graph_stack = contextlib.ExitStack()
142            self.graph_stack.enter_context(
143                config.patch(
144                    {
145                        "triton.cudagraphs": True,
146                        "triton.cudagraph_trees": True,
147                        "triton.fast_path_cudagraph_asserts": True,  # too slow
148                        "triton.slow_path_cudagraph_asserts": True,
149                    }
150                )
151            )
152            self.graph_stack.enter_context(
153                dynamo_config.patch(automatic_dynamic_shapes=True)
154            )
155            self.device_idx = torch.rand([0], device="cuda").device.index
156            warnings.filterwarnings("ignore")
157
158        def tearDown(self):
159            super().tearDown()
160            torch._dynamo.reset()
161            gc.collect()
162            torch.cuda.empty_cache()
163            self.graph_stack.close()
164
165            self.assertIsNone(self.get_manager())
166            self.assertEqual(all_live_block_count(), 0)
167            self.assertEqual(len(get_all_cudagraph_segments()), 0)
168            warnings.resetwarnings()
169
170        def get_manager(self, device_index=None):
171            return torch._inductor.cudagraph_trees.get_container(
172                self.device_idx if not device_index else device_index
173            ).tree_manager
174
175        def get_roots(self):
176            return self.get_manager().get_roots()
177
178        def curr_node(self):
179            return self.get_manager().current_node
180
181        def get_root_children(self):
182            return [root.num_descendants() for root in self.get_roots()]
183
184        def cudagraphify_impl(
185            self, *args, is_inference=True, is_backward=False, **kwargs
186        ):
187            return tree_cudagraphify_impl(
188                *args,
189                **kwargs,
190                device_index=self.device_idx,
191                is_inference=is_inference,
192                is_backward=is_backward,
193            )
194
195        @staticmethod
196        def run_twc(fn, *args, **kwargs):
197            fn(*args, **kwargs)
198            return fn(*args, **kwargs)
199
200        def num_checkpoints(self):
201            return self.get_manager().debug_checkpointing_counter
202
203        def test_run_simple(self):
204            def foo(x):
205                return x * x * x
206
207            foo_opt = torch.compile(foo)
208            ones = torch.ones([4, 4], device="cuda")
209            zeros = torch.zeros([5, 5], device="cuda")
210            self.run_twc(foo_opt, ones)
211            self.run_twc(foo_opt, zeros)
212            self.assertEqual(self.get_root_children(), [0, 0])
213
214        def check_rng(self):
215            @torch.compile(mode="reduce-overhead")
216            def foo():
217                return torch.rand([20])
218
219            torch.manual_seed(0)
220
221            out = foo()
222            out2 = foo()
223            out3 = foo()
224
225            torch.manual_seed(0)
226
227            self.assertEqual(out, foo())
228            self.assertEqual(out2, foo())
229            self.assertEqual(out3, foo())
230
231        @torch._inductor.config.patch("fallback_random", True)
232        def test_rng_trees(self):
233            self.check_rng()
234
235        @torch._inductor.config.patch("triton.cudagraph_trees", False)
236        @torch._inductor.config.patch("fallback_random", True)
237        def test_rng_non_trees(self):
238            self.check_rng()
239
240        def test_mutation_reinplaced(self):
241            import torch.nn as nn
242
243            class Model(nn.Module):
244                def __init__(self) -> None:
245                    super().__init__()
246
247                def forward(self, input, other, out):
248                    input = torch.logical_xor(input=input, other=other, out=out)
249                    return input
250
251            x = torch.rand([1, 2, 1, 4, 9, 7], dtype=torch.float32).cuda()
252            y = torch.rand([1, 2, 1, 4, 9, 7], dtype=torch.float32).cuda()
253            z = torch.rand([1, 2, 1, 4, 9, 7], dtype=torch.float16).cuda()
254
255            model = Model().cuda()
256            eag = model(x, y, z)
257            with capture_stderr() as captured_output:
258                opt = torch.compile(model.forward, mode="reduce-overhead")(x, y, z)
259
260            FileCheck().check(
261                "skipping cudagraphs due to mutated inputs (1 instances). Found from"
262            ).check("torch.logical_xor").run(captured_output[0])
263            self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
264
265        @requires_multigpu()
266        @parametrize("backend", ("inductor", "cudagraphs"))
267        def test_multiple_devices_msg(self, backend):
268            def foo(x, y):
269                return (x + 1, y + 2)
270
271            foo = get_compile_fn(backend)(foo)
272            with capture_stderr() as captured_output:
273                foo(torch.ones([10], device="cuda"), torch.ones([20]))
274
275            FileCheck().check(
276                "skipping cudagraphs due to cpu device (arg1_1). Found from"
277            ).check("y + 2").run(captured_output[0])
278            self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
279
280            with capture_stderr() as captured_output:
281                foo(
282                    torch.ones([10], device="cuda:0"), torch.ones([10], device="cuda:1")
283                )
284
285            FileCheck().check("skipping cudagraphs due to multiple devices").run(
286                captured_output[0]
287            )
288            self.assertEqual(counters["inductor"]["cudagraph_skips"], 2)
289
290        @torch._inductor.config.patch("triton.cudagraph_skip_dynamic_graphs", True)
291        def test_skip_symbolic(self):
292            @torch.compile(dynamic=True)
293            def foo(x, y):
294                return x + y
295
296            with capture_stderr() as captured_output:
297                foo(torch.rand([10], device="cuda"), torch.rand([10], device="cuda"))
298
299            FileCheck().check(
300                "skipping cudagraphs due to graph with symbolic shapes inputs"
301            ).check("x + y").run(captured_output[0])
302            self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
303
304        @parametrize("backend", ("inductor", "cudagraphs"))
305        @torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
306        @torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
307        @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
308        def test_mutation_on_inp(self, backend):
309            def foo(x):
310                x.add_(2)
311                return x
312
313            foo = get_compile_fn(backend)(foo)
314
315            def inp():
316                return torch.ones([10], device="cuda")
317
318            with capture_stderr() as captured_output:
319                foo(inp())
320
321            FileCheck().check(
322                "skipping cudagraphs due to mutated inputs (1 instances). Found from"
323            ).check(".add_(2)").run(captured_output[0])
324            self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
325
326            # mutation on inp doesnt hit cudagraphs
327            self.assertEqual(len(self.get_manager().roots), 0)
328
329            # mutation on parameters/buffers hits cudagraphs
330            class Mod(torch.nn.Module):
331                def __init__(self) -> None:
332                    super().__init__()
333                    self.buf = torch.ones([10], device="cuda")
334
335                def forward(self, x):
336                    self.buf.add_(x)
337                    return self.buf + x
338
339            def foo(mod, x):
340                return mod(x)
341
342            foo = get_compile_fn(backend)(foo)
343            mod = Mod()
344            mod2 = Mod()
345
346            for _ in range(3):
347                self.assertEqual(foo(mod, inp()), mod2(inp()))
348                self.assertEqual(mod.buf, mod2.buf)
349
350            self.assertIsNotNone(self.get_manager())
351
352        @parametrize("backend", ("inductor", "cudagraphs"))
353        @torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
354        @torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", False)
355        @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
356        def test_mutation_cudagraph_managed_tensors_config(self, backend):
357            def foo(x):
358                return x + 1
359
360            def mut(x):
361                x.add_(2)
362                return x
363
364            def non_mut(x):
365                return x.add(2)
366
367            mut = get_compile_fn(backend)(mut)
368            foo = get_compile_fn(backend)(foo)
369
370            with capture_stderr() as captured_output:
371                for i in range(3):
372                    torch.compiler.cudagraph_mark_step_begin()
373                    inp = torch.rand([4], device="cuda")
374
375                    tmp = foo(inp)
376                    mut_out = mut(tmp)
377                    self.assertEqual(mut_out, non_mut(foo(inp)))
378            FileCheck().check_count(
379                "skipping cudagraphs due to mutated inputs (1 instances). Found from",
380                1,
381                exactly=True,
382            ).run(captured_output[0])
383
384        @parametrize("backend", ("inductor", "cudagraphs"))
385        @torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
386        @torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
387        @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
388        def test_mutation_cudagraph_managed_tensors(self, backend):
389            def foo(x):
390                return x + 1
391
392            def mut(x):
393                x.add_(2)
394                return x
395
396            def non_mut(x):
397                return x.add(2)
398
399            mut = get_compile_fn(backend)(mut)
400            foo = get_compile_fn(backend)(foo)
401
402            with capture_stderr() as captured_output:
403                for i in range(3):
404                    torch.compiler.cudagraph_mark_step_begin()
405                    inp = torch.rand([4], device="cuda")
406
407                    tmp = foo(inp)
408                    mut_out = mut(tmp)
409                    self.assertEqual(mut_out, non_mut(foo(inp)))
410            FileCheck().check_count(
411                "skipping cudagraphs due to mutated inputs (1 instances). Found from",
412                0,
413                exactly=True,
414            ).run(captured_output[0])
415            self.assertTrue("cudagraph_skips" not in counters["inductor"])
416
417            torch.compiler.cudagraph_mark_step_begin()
418            inp = torch.rand([4], device="cuda")
419            tmp = foo(inp)
420            mut_inp = tmp.clone()
421            # in this case, what previously a mutated cudagraph managed tensor is no longer,
422            # now its an input from eager we should fallback to inductor without cudagraphs
423            with capture_stderr() as captured_output:
424                mut(mut_inp)
425            FileCheck().check(
426                "skipping cudagraphs due to mutated inputs (1 instances). Found from"
427            ).check("x.add_(2)").run(captured_output[0])
428            self.assertEqual(mut_inp, non_mut(foo(inp)))
429            self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
430
431        @parametrize("backend", ("inductor", "cudagraphs"))
432        @torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
433        @torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
434        @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
435        def test_mutation_cudagraph_managed_tensor_warn(self, backend):
436            def foo(x):
437                return x.add_(1)
438
439            def fee(y, z):
440                return z.add(3)
441
442            def inp():
443                return torch.rand([4], device="cuda")
444
445            foo = get_compile_fn(backend)(foo)
446            fee = get_compile_fn(backend)(fee)
447
448            with capture_stderr() as captured_output:
449                for _ in range(3):
450                    torch.compiler.cudagraph_mark_step_begin()
451                    fee(inp(), foo(inp()))
452            FileCheck().check_count(
453                "skipping cudagraphs due to mutated inputs (1 instances). Found from",
454                1,
455                exactly=True,
456            ).run(captured_output[0])
457            self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
458
459        @parametrize("backend", ("inductor", "cudagraphs"))
460        @torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
461        @torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
462        @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
463        def test_mutation_cudagraph_managed_tensor_warn_only_once(self, backend):
464            def foo(x):
465                return x + 1
466
467            def mut(x):
468                x.add_(2)
469                return x
470
471            def inp():
472                return torch.rand([4], device="cuda")
473
474            mut = get_compile_fn(backend)(mut)
475            foo = get_compile_fn(backend)(foo)
476
477            with capture_stderr() as captured_output:
478                # Should warn for current_node=None
479                mut(inp())
480
481                for i in range(3):
482                    torch.compiler.cudagraph_mark_step_begin()
483                    tmp = foo(inp())
484                    mut(tmp)  # should not warn
485
486                mut_inp = tmp.clone()
487                mut(mut_inp)  # should not warn since mut has warned
488
489            FileCheck().check_count(
490                "skipping cudagraphs due to mutated inputs (1 instances). Found from",
491                1,
492                exactly=True,
493            ).run(captured_output[0])
494            self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
495
496        def test_function_compiled_multiple_times(self):
497            def foo(x):
498                y = foo2(x)
499                y2 = foo2(y)
500                return y + y2
501
502            def foo2(x):
503                torch._dynamo.graph_break()
504                return x * x * x
505
506            foo_opt = torch.compile(foo)
507            ones = torch.ones([4, 4], device="cuda")
508            foo(ones)
509            foo_opt(ones)
510            foo_opt(ones)
511            self.assertEqual(foo_opt(ones), foo(ones))
512            # paths
513            children = self.get_root_children()
514            # one root with two children
515            self.assertEqual(children, [2])
516
517        def test_end_recording_early(self):
518            def foo(x):
519                y = x * x * x
520                torch._dynamo.graph_break()
521                z = x + y
522                return z
523
524            @torch.compile
525            def foo2(x):
526                return x + 4
527
528            foo_opt = torch.compile(foo)
529
530            for _ in range(3):
531                out = foo_opt(torch.ones([4, 4], device="cuda"))
532                del out
533
534                # when I tried inducing separate recordings via graph break,
535                # the frame kept interferring by keeping outputs alive
536                # this isnt great by simulates the logic.
537                from torch._dynamo.mutation_guard import GenerationTracker
538
539                GenerationTracker.generation -= 1
540
541                out = foo2(torch.ones([4, 4], device="cuda"))
542                del out
543
544            foo_opt(torch.ones([4, 4], device="cuda"))
545
546            # Two separate traces - one has a child, one doesnt
547            self.assertEqual(self.get_root_children(), [1, 0])
548
549        def test_execution_into_recording(self):
550            def foo(x):
551                y = x + x
552
553                if y.sum() > 0:
554                    return y + 10
555                else:
556                    return y - 10
557
558            foo_opt = torch.compile(foo)
559            inp = torch.zeros([4, 4], dtype=torch.float, device="cuda")
560            self.assertEqual(foo_opt(inp), foo(inp))
561            self.assertEqual(foo_opt(inp), foo(inp))
562
563            inp.add_(1)
564            out_eager = foo(inp)
565            out_warmup = foo_opt(inp)
566            self.assertEqual(out_warmup, out_eager)
567            # warmup should be have storage deallocator hooked on
568            self.assertEqual(all_live_block_count(), 1)
569
570            out_live = foo_opt(inp)
571            self.assertEqual(out_live, out_eager)
572
573            # should be in recording mode, with storage deallocator hooked on
574            self.assertEqual(all_live_block_count(), 1)
575            # warmup should have been freed
576            del out_warmup
577            # should be in recording mode, with storage deallocator hooked on
578            self.assertEqual(all_live_block_count(), 1)
579
580            del out_live
581            self.assertEqual(all_live_block_count(), 0)
582
583            out = foo_opt(inp)
584            self.assertEqual(foo(inp), out)
585
586            # should be in execution mode
587            self.assertEqual(all_live_block_count(), 0)
588
589        def test_forward_with_skipped_cudagraphed_backward(self):
590            @torch.compile(mode="reduce-overhead")
591            def foo(x):
592                return x * x * x
593
594            for _ in range(3):
595                inp = torch.rand([20, 20], device="cuda", requires_grad=True)
596                out = foo(inp)
597
598                def complex_memory_overlap_new(t):
599                    return True
600
601                try:
602                    prev = torch._inductor.compile_fx.complex_memory_overlap
603                    torch._inductor.compile_fx.complex_memory_overlap = (
604                        complex_memory_overlap_new
605                    )
606                    back_inp = torch.empty_strided([20, 20], [0, 1], device="cuda")
607                    out.backward(back_inp)
608                finally:
609                    torch._inductor.compile_fx.complex_memory_overlap = prev
610
611            # we should not have cudagraph'd the backwards
612            new_id = self.get_manager().new_graph_id().id
613            self.assertEqual(new_id, 1)
614
615            self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
616
617        @torch._functorch.config.patch("enable_autograd_cache", True)
618        @torch._inductor.config.patch("fx_graph_cache", True)
619        @torch._inductor.config.patch("fx_graph_remote_cache", False)
620        def test_cache_hit_forward_miss_backward(self):
621            # Test that we don't cache cudagraphs, skipping cudagraphs on backward on a cache miss
622
623            @torch.compile(mode="reduce-overhead")
624            def foo(x):
625                return x * x * x
626
627            def complex_memory_overlap_new(t):
628                return True
629
630            # Run forwards, fx graph should cache miss
631            for _ in range(3):
632                torch._dynamo.reset()
633                counters.clear()
634                FxGraphCache.clear()
635                AOTAutogradCache.clear()
636
637                with mock.patch(
638                    "torch._inductor.compile_fx.complex_memory_overlap",
639                    new=complex_memory_overlap_new,
640                ):
641                    inp = torch.rand([20, 20], device="cuda", requires_grad=True)
642                    out = foo(inp)
643                    self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
644
645                    # Reset dynamo and related caches except for FXGraphCache
646                    torch._dynamo.reset()
647                    # Forwards should be a cache hit now, we still skip cudagraphs
648                    inp = torch.rand([20, 20], device="cuda", requires_grad=True)
649                    out = foo(inp)
650                    self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
651                    self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
652
653                    # Run backward without complex memory overlap being set
654
655                # Run the backward without complex memory overlap reason
656                # cache should miss, but cudagraphs should not run
657                # because forward skipped it
658                back_inp = torch.empty_strided([20, 20], [0, 1], device="cuda")
659                out.backward(back_inp)
660                self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
661
662            # Run it one more time, this time AOTAutogradCache will hit
663            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
664            self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
665
666            torch._dynamo.reset()
667            inp = torch.rand([20, 20], device="cuda", requires_grad=True)
668            out = foo(inp)
669            back_inp = torch.empty_strided([20, 20], [0, 1], device="cuda")
670            out.backward(back_inp)
671
672            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
673
674            # we should not have cudagraph'd anything
675            assert self.get_manager() is None
676
677        @torch._functorch.config.patch("enable_autograd_cache", True)
678        @torch._inductor.config.patch("fx_graph_cache", True)
679        @torch._inductor.config.patch("fx_graph_remote_cache", False)
680        def test_backward_gets_cached_cudagraphs(self):
681            # We pass cpu tensors to foo and save that into the cache
682            # On a subsequent run in a new process, cudagraphs should be
683            # disabled properly on both forward and backwards runs.
684
685            @torch.compile(mode="reduce-overhead")
686            def foo(x):
687                return x * x * x
688
689            torch._dynamo.reset()
690            counters.clear()
691            FxGraphCache.clear()
692            AOTAutogradCache.clear()
693
694            # Use cpu device to disable cudagraphs during compilation
695            inp = torch.rand([20, 20], device="cpu", requires_grad=True)
696            out = foo(inp)
697            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
698
699            back_inp = torch.empty_strided([20, 20], [0, 1], device="cpu")
700            out.backward(back_inp)
701            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
702
703            # Run again on new process
704            torch._dynamo.reset()
705
706            # Forward and backward should also disable cudagraphs without compilation
707            inp = torch.rand([20, 20], device="cpu", requires_grad=True)
708            out = foo(inp)
709            # AOTAutogradCache will load the forward and the backward from cache immediately, so fx_graph_cache_hit will equal 2
710            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
711            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
712            torch._dynamo.reset()
713
714            back_inp = torch.empty_strided([20, 20], [0, 1], device="cpu")
715            out.backward(back_inp)
716
717            # we should not have cudagraph'd anything
718            assert self.get_manager() is None
719
720        @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
721        @torch._functorch.config.patch("enable_autograd_cache", True)
722        @torch._inductor.config.patch("fx_graph_cache", True)
723        @torch._inductor.config.patch("fx_graph_remote_cache", False)
724        def test_cached_forward_backward(self):
725            counters.clear()
726            AOTAutogradCache.clear()
727            FxGraphCache.clear()
728
729            @torch.compile
730            def foo(x):
731                torch.manual_seed(0)
732                y = x * 2
733                return torch.sin(y) * torch.nn.functional.dropout(x, p=0.4)
734
735            inp = torch.rand([4, 4], requires_grad=True, device="cuda")
736            inp2 = inp.detach().clone().requires_grad_(True)
737            out = foo(inp)
738
739            out.sum().backward()
740
741            self.assertEqual(self.get_root_children(), [1])
742
743            # the three saved tensors should die in the backward
744            # we kept alive the output
745            self.assertEqual(self.curr_node().expected_dead_indices_before_graph, [])
746            self.assertEqual(
747                self.curr_node().expected_dead_indices_after_graph,
748                [(0, 1), (0, 2)],
749            )
750            self.assertFalse(self.get_manager().new_graph_id().id == 0)
751            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
752
753            # Reset dynamo and rerun. We should see a cache hit now
754            torch._dynamo.reset()
755
756            out2 = foo(inp2)
757            out2.sum().backward()
758            self.assertEqual(out, out2)
759            self.assertEqual(inp.grad, inp2.grad)
760
761            self.assertEqual(self.get_root_children(), [1])
762            self.assertFalse(self.get_manager().new_graph_id().id == 0)
763            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
764
765        @parametrize("backend", ("inductor", "cudagraphs"))
766        def test_forward_backward_not_called(self, backend):
767            def foo(x, y):
768                x_out = x * x * x
769                torch._dynamo.graph_break()
770                y_out = y * y * y
771                return x_out, y_out
772
773            foo = get_compile_fn(backend)(foo)
774
775            for _ in range(3):
776                inps = [
777                    torch.rand([20, 20], requires_grad=True, device="cuda")
778                    for _ in range(2)
779                ]
780                x_out, y_out = foo(inps[0], inps[1])
781                x_out.sum().backward()
782
783            self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
784
785            # we should not have cudagraph'd the y backward
786            new_id = self.get_manager().new_graph_id().id
787            self.assertEqual(new_id, 3)
788
789        def _test_unaligned_static_input_impl(self, expected_clones):
790            def fn(x, y):
791                return (x + y,)
792
793            def get_aligned_inputs():
794                return [torch.rand([5, 5], device="cuda") for _ in range(2)]
795
796            mod = make_fx(fn)(*get_aligned_inputs())
797
798            mode = torch._subclasses.FakeTensorMode()
799
800            with mode:
801                inps = [torch.rand([6, 5], device="cuda")[1:] for _ in range(2)]
802
803            compiled_f = compile_fx_inner(
804                mod, inps, static_input_idxs=[0], cudagraphs=True
805            )
806
807            def get_unaligned_inputs():
808                return [torch.rand([6, 5], device="cuda")[1:] for _ in range(2)]
809
810            class CloneCounterMode(TorchDispatchMode):
811                def __init__(self) -> None:
812                    self.count = 0
813
814                def __torch_dispatch__(self, func, types, args=(), kwargs=None):
815                    kwargs = {} if kwargs is None else kwargs
816                    self.count += func is torch.ops.aten.clone.default
817                    return func(*args, **kwargs)
818
819            for _ in range(3):
820                with CloneCounterMode() as m:
821                    compiled_f(get_unaligned_inputs())
822                    self.assertEqual(m.count, expected_clones)
823
824                    compiled_f(get_aligned_inputs())
825                    self.assertEqual(m.count, expected_clones)
826
827        def test_unaligned_static_input_trees(self):
828            self._test_unaligned_static_input_impl(expected_clones=0)
829
830        @torch._inductor.config.patch("triton.cudagraph_trees", False)
831        def test_unaligned_static_input_non_trees(self):
832            self._test_unaligned_static_input_impl(expected_clones=0)
833
834        @torch._inductor.config.patch("triton.cudagraphs", False)
835        def test_unaligned_static_input_no_cudagraphs(self):
836            self._test_unaligned_static_input_impl(expected_clones=0)
837
838        def test_sparsity(self):
839            def foo(view_6, buf31):
840                return aten._sparse_coo_tensor_with_dims_and_tensors(
841                    1,
842                    1,
843                    [1000000, 64],
844                    view_6,
845                    buf31,
846                    dtype=torch.float32,
847                    layout=torch.sparse_coo,
848                    device="cuda",
849                    pin_memory=None,
850                )
851
852            foo_opt = torch.compile(foo)
853
854            view_6 = torch.zeros([1, 102397], dtype=torch.int64, device="cuda")
855            buf31 = torch.rand([102397, 64], device="cuda")
856
857            for _ in range(3):
858                self.assertEqual(foo_opt(view_6, buf31), foo(view_6, buf31))
859
860        def test_accumulate_multiple_recordings(self):
861            def foo(x):
862                y = x + x + x
863                torch._dynamo.graph_break()
864                if y.sum() <= 0:
865                    return y
866                else:
867                    return y * 10
868
869            foo_opt = torch.compile(foo)
870
871            # two separate compilations & recordings
872            out1 = self.run_twc(foo_opt, torch.zeros([5], device="cuda"))
873
874            # out1 gets manually freed
875            out2 = self.run_twc(foo_opt, torch.zeros([6], device="cuda"))
876
877            self.assertEqual(all_live_block_count(), 1)
878
879            out3 = self.run_twc(foo_opt, torch.ones([5], device="cuda"))
880
881            self.assertEqual(out3, foo(torch.ones([5], device="cuda")))
882
883            self.assertEqual(all_live_block_count(), 1)
884            del out1, out2
885            self.assertEqual(all_live_block_count(), 1)
886
887            del out3
888            gc.collect()
889            self.assertEqual(all_live_block_count(), 0)
890
891        @torch._inductor.config.patch("freezing", True)
892        def test_constant_output(self):
893            class Mod(torch.nn.Module):
894                def __init__(self) -> None:
895                    super().__init__()
896                    self.param = torch.nn.Parameter(
897                        torch.tensor([float(i) for i in range(10)], device="cuda")
898                    )
899
900                def forward(self, inp):
901                    return self.param, self.param[0:2], inp + 2
902
903            inp = torch.tensor([2], device="cuda")
904            m = Mod()
905            with torch.no_grad():
906                out_eager = m(inp)
907
908                m_comp = torch.compile(m)
909                for _ in range(3):
910                    self.assertEqual(out_eager, m_comp(inp))
911
912        def test_live_outputs_multiple_graphs(self):
913            def foo(x):
914                x = x + x + x
915                y = x + 1
916                torch._dynamo.graph_break()
917                z = x * x
918                if z.sum() > 0:
919                    return y + 1
920                else:
921                    return y
922
923            foo_opt = torch.compile(foo)
924
925            self.run_twc(foo_opt, torch.zeros([5], device="cuda"))
926            self.assertEqual(self.num_checkpoints(), 0)
927            out = self.run_twc(foo_opt, torch.ones([5], device="cuda"))
928
929            self.assertEqual(all_live_block_count(), 1)
930
931            del out
932            self.assertEqual(all_live_block_count(), 0)
933
934            # we need to checkpoint from function to warmup y + 1,
935            # and then again to record it
936            self.assertEqual(self.num_checkpoints(), 2)
937
938        def test_expanded_inputs(self):
939            x = torch.rand(1, 512, device="cuda").expand(4, 512)
940
941            def foo(x):
942                return x + 4 + torch.ones([4, 512], device="cuda")
943
944            foo_opt = torch.compile()(foo)
945
946            for _ in range(3):
947                self.assertEqual(foo_opt(x), foo(x))
948
949            self.assertFalse(self.get_manager().new_graph_id().id == 0)
950
951        @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
952        def test_tensor_dies_between_checkpoint(self):
953            def foo(args):
954                x = args[0]
955                args.clear()
956                return x + 1, x + 2
957
958            inp = torch.rand([4], device="cuda")
959            inp_list = [inp]
960            foo_cg = self.cudagraphify_impl(foo, inp_list, ())
961            foo_cg(inp_list)
962            foo_cg([inp])
963
964            out1, out2 = foo_cg([inp])
965            inp = [out1]
966
967            del out1, out2
968
969            def foo2(args):
970                x = args[0]
971                args.clear()
972                return [x * x * x]
973
974            self.assertEqual(self.num_checkpoints(), 0)
975            foo2_cg = self.cudagraphify_impl(foo2, inp, ())
976
977            x = foo2_cg(inp)[0]
978
979            self.assertEqual(self.num_checkpoints(), 1)
980            # out2 dies between the previous recording and the new one,
981            # need to be manually deallocated after the checkpoint
982
983            self.assertEqual(all_live_block_count(), 1)
984            del x
985            self.assertEqual(all_live_block_count(), 0)
986
987        def test_aliased_storage_single_weakref(self):
988            @torch.compile(mode="reduce-overhead")
989            def foo(x):
990                x = x * 20
991                x_alias = x[0]
992                y = x * 10
993                y_alias = y[0]
994                torch._dynamo.graph_break()
995                ind = torch.tensor(4, device="cuda")
996                x_alias2 = x[ind:]
997                y_alias2 = y[ind:]
998                return x, x_alias, x_alias2, y_alias, y_alias2
999
1000            for _ in range(4):
1001                outs = foo(torch.rand([20, 20], device="cuda"))
1002
1003                ptr_to_ref = {
1004                    out.untyped_storage().data_ptr(): out.untyped_storage()._cdata
1005                    for out in outs
1006                }
1007
1008                self.assertEqual(len(ptr_to_ref), 2)
1009                for out in outs:
1010                    self.assertEqual(
1011                        ptr_to_ref[out.untyped_storage().data_ptr()],
1012                        out.untyped_storage()._cdata,
1013                    )
1014                del outs
1015                del out
1016
1017            node = self.get_manager().current_node
1018            self.assertEqual(len(list(node.path_live_weakrefs())), 0)
1019            self.assertFalse(self.get_manager().new_graph_id().id == 0)
1020
1021        def test_aliasing_static_ref(self):
1022            class Mod(torch.nn.Linear):
1023                def forward(self, x):
1024                    return self.weight.T @ x, self.weight.T, self.weight[0:4]
1025
1026            m = Mod(10, 10).cuda()
1027
1028            @torch.compile(mode="reduce-overhead")
1029            def foo(mod, x):
1030                return mod(x)
1031
1032            @torch.compile(mode="reduce-overhead")
1033            def foo2(x):
1034                return x[2:]
1035
1036            x = torch.rand([10, 10], device="cuda", requires_grad=True)
1037            param_c = cdata(m.weight)
1038            for _ in range(3):
1039                out1, alias_1, alias_2 = foo(m, x)
1040                self.assertEqual(len({param_c, cdata(alias_1), cdata(alias_2)}), 1)
1041
1042                out2 = foo2(out1)
1043                out2.sum().backward()
1044                self.assertEqual(cdata(out1), cdata(out2))
1045
1046            node = self.curr_node()
1047            first_node = next(node._path_from_root)
1048            self.assertFalse(first_node.unaliased_in_all_paths[0])
1049            self.assertTrue(first_node.cached_tensor_outputs[0] is None)
1050
1051        @skipIfRocm
1052        def test_checkpointing_resets_persistent_refs(self):
1053            @torch.compile(mode="reduce-overhead")
1054            def foo(x):
1055                return x @ x
1056
1057            def inp():
1058                return torch.rand([20, 20], device="cuda", requires_grad=False)
1059
1060            for _ in range(3):
1061                foo(inp())
1062
1063            self.assertEqual(self.num_checkpoints(), 0)
1064
1065            out = foo(inp())
1066            out_id = id(out)
1067            del out
1068            self.assertEqual(id(foo(inp())), out_id)
1069
1070            @torch.compile(mode="reduce-overhead")
1071            def foo2(x):
1072                return x[0], x @ x
1073
1074            for i in range(2):
1075                out = foo(inp())
1076
1077                from torch._dynamo.mutation_guard import GenerationTracker
1078
1079                GenerationTracker.generation -= 1
1080
1081                out_alias, out2 = foo2(out)
1082                del out_alias
1083
1084                self.assertEqual(all_live_block_count(), 2)
1085                del out
1086                self.assertEqual(all_live_block_count(), 1)
1087                del out2
1088                self.assertEqual(all_live_block_count(), 0)
1089
1090                self.assertEqual(self.num_checkpoints(), i + 1)
1091
1092            new_out = foo(inp())
1093            curr_node = self.curr_node()
1094            self.assertFalse(curr_node.unaliased_in_all_paths[0])
1095            self.assertFalse(out_id == id(new_out))
1096
1097        def test_aliased_static_parameter(self):
1098            inp = torch.rand([20, 20], device="cuda")
1099
1100            def foo(args):
1101                x = args[0]
1102                args.clear()
1103                return (x[0],)
1104
1105            foo_cg = self.cudagraphify_impl(foo, [inp], (0,))
1106
1107            for _ in range(3):
1108                out = foo_cg([inp])[0]
1109                self.assertEqual(cdata(inp), cdata(out))
1110
1111            node = self.curr_node()
1112            self.assertEqual(node.cached_tensor_outputs, [None])
1113            self.assertEqual(node.unaliased_in_all_paths, [False])
1114
1115        def test_warmup_stream_sync(self):
1116            def foo(args):
1117                x = args[0]
1118                args.clear()
1119                x_orig = x
1120                for _ in range(100):
1121                    x = x @ x
1122                return (x,)
1123
1124            inp = torch.rand([4096, 4096], device="cuda")
1125            ref = foo([inp])[0]
1126            torch.cuda.synchronize()
1127
1128            user_stream = torch.cuda.Stream()
1129            with torch.cuda.stream(user_stream):
1130                foo_cg = self.cudagraphify_impl(foo, [inp], (0,))
1131                out = foo_cg([inp])[0]
1132                y = out + 1
1133                self.assertEqual(y, ref + 1)
1134
1135        def test_unaligned_static_parameter(self):
1136            def gen_inp():
1137                inp = torch.ones([20], device="cuda")
1138                return [inp[1:]]
1139
1140            def foo(args):
1141                x = args[0]
1142                args.clear()
1143                return (x + x,)
1144
1145            foo_cg = self.cudagraphify_impl(foo, gen_inp(), (0,))
1146
1147            for _ in range(3):
1148                out = foo_cg(gen_inp())
1149                self.assertEqual(out, foo(gen_inp()))
1150                del out
1151
1152            node = self.curr_node()
1153            self.assertEqual(node.static_input_data_ptrs, [None])
1154
1155        def test_amp_cache_disabled(self):
1156            @torch.compile()
1157            def foo(x):
1158                return x + x
1159
1160            for _ in range(3):
1161                out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))
1162
1163            # amp cache for cudagraph outputs should be disabled
1164            t2 = torch.rand([4, 4], device="cuda")
1165
1166            with torch.cuda.amp.autocast():
1167                run_once = out @ t2
1168
1169                out.detach().zero_()
1170
1171                run_twice = out @ t2
1172
1173                self.assertNotEqual(run_once, run_twice)
1174
1175        def test_remove_hooks_on_cached_tensors(self):
1176            @torch.compile()
1177            def foo(x):
1178                return x * x
1179
1180            inp = torch.rand([4], device="cuda", requires_grad=True)
1181
1182            for _ in range(5):
1183                out = foo(inp)
1184                self.assertIsNone(out._backward_hooks)
1185                out.register_hook(lambda: None)
1186
1187            # today, torch.compile never outputs a leaf tensor which is the only
1188            # tensor that can register _post_accumulate_grad_hooks
1189            # add this as a preventative test
1190
1191            @torch.compile()
1192            def foo(x):
1193                return torch.rand([4], device="cuda", requires_grad=True)
1194
1195            for _ in range(5):
1196                out = foo(inp)
1197                self.assertIsNone(out._post_accumulate_grad_hooks)
1198                out.register_post_accumulate_grad_hook(lambda: None)
1199
1200        def test_multiple_insert_removal_caching(self):
1201            torch._C._set_cached_tensors_enabled(True)
1202            try:
1203                x = torch.rand([4], device="cuda")
1204
1205                torch._C._add_cached_tensor(x)
1206                self.assertTrue(torch._C._is_cached_tensor(x))
1207
1208                torch._C._add_cached_tensor(x)
1209                torch._C._remove_cached_tensor(x)
1210
1211                self.assertFalse(torch._C._is_cached_tensor(x))
1212            finally:
1213                torch._C._set_cached_tensors_enabled(False)
1214
1215        def test_accumulate_grad(self):
1216            # cudagraph trees shouldnt interfere with accumulation logic
1217
1218            def compute_grad(grad_output, create_graph):
1219                x = torch.randn(5, 5, requires_grad=True, device="cuda")
1220
1221                @torch.compile()
1222                def foo(x):
1223                    return x + 2
1224
1225                y = foo(x)
1226                y.backward(grad_output, retain_graph=True)
1227                x_grad = x.grad
1228                x_grad_clone = x.grad.clone()
1229                y.backward(grad_output, create_graph=create_graph)
1230                return x_grad, x_grad_clone
1231
1232            for _ in range(3):
1233                grad_output = torch.ones(5, 5, device="cuda")
1234
1235                # Accumulate in-place when create_graph is False
1236                x_grad, x_grad_clone = compute_grad(grad_output, create_graph=False)
1237                self.assertEqual(x_grad, x_grad_clone * 2)
1238
1239                # Accumulate out-of-place when create_graph is False
1240                x_grad, x_grad_clone = compute_grad(grad_output, create_graph=True)
1241                self.assertEqual(x_grad, x_grad_clone)
1242
1243        def test_frozen_fn(self):
1244            @torch.compile()
1245            def foo(x):
1246                return x @ x
1247
1248            for _ in range(3):
1249                out = foo(torch.rand([10, 10], device="cuda"))
1250
1251            self.assertTrue(self.get_manager().new_graph_id().id == 1)
1252            frozen = torch._dynamo.run(foo)
1253
1254            for _ in range(3):
1255                out = frozen(torch.rand([10, 10], device="cuda"))
1256
1257            # didnt do additional recordings
1258            self.assertTrue(self.get_manager().new_graph_id().id == 2)
1259
1260        def test_empty_cpu_tensor(self):
1261            def foo(x):
1262                return x @ x, torch.tensor([])
1263
1264            foo_opt = torch.compile(foo)
1265            x = torch.rand([4], device="cuda")
1266
1267            for _ in range(3):
1268                out_opt = foo_opt(x)
1269                self.assertEqual(foo(x), out_opt)
1270
1271            self.assertTrue(self.get_manager().new_graph_id().id == 1)
1272
1273        def test_output_alias(self):
1274            inp = torch.rand([20, 20], device="cuda")
1275
1276            def foo(args):
1277                x = args[0]
1278                args.clear()
1279                out = x + x
1280                return (x, x[0])
1281
1282            foo_cg = self.cudagraphify_impl(foo, [inp], ())
1283
1284            for _ in range(3):
1285                out_1, out_2 = foo_cg([inp])
1286                self.assertEqual(cdata(out_1), cdata(out_2))
1287                del out_1, out_2
1288                self.assertEqual(len(list(self.curr_node().path_live_weakrefs())), 0)
1289
1290            self.assertEqual(self.curr_node().cached_tensor_outputs, [None, None])
1291
1292        def test_empty_storage(self):
1293            @torch.compile(mode="reduce-overhead")
1294            def foo(x):
1295                return (
1296                    (x + x + x),
1297                    torch.zeros([0], device="cuda"),
1298                    torch.zeros([100], device="cuda")[0:0],
1299                )
1300
1301            inp = torch.rand([4], device="cuda")
1302            for _ in range(3):
1303                out = foo(inp)
1304                node = self.curr_node()
1305                self.assertEqual(len(list(node.path_live_weakrefs())), 1)
1306
1307            @torch.compile(mode="reduce-overhead")
1308            def foo(x):
1309                return (x + x + x), torch.rand([4], device="cuda") + 10
1310
1311            inp = torch.rand([0], device="cuda")
1312            for _ in range(3):
1313                out = foo(inp)
1314                node = self.curr_node()
1315                self.assertEqual(len(list(node.path_live_weakrefs())), 1)
1316
1317        @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
1318        def test_aliased_output_checkpoint(self):
1319            def foo(args):
1320                x = args[0]
1321                args.clear()
1322                y = x + 2
1323                return x + 1, y, y[0]
1324
1325            inp = torch.rand([4, 4], device="cuda")
1326            foo_cg = self.cudagraphify_impl(foo, [inp], ())
1327            foo_cg([inp])
1328            foo_cg([inp])
1329
1330            out1, out2, out3 = foo_cg([inp])
1331            inp = [out1]
1332
1333            del out1, out2, out3
1334
1335            def foo2(args):
1336                x = args[0]
1337                args.clear()
1338                return [x * x * x]
1339
1340            self.assertEqual(self.num_checkpoints(), 0)
1341            foo2_cg = self.cudagraphify_impl(foo2, inp, ())
1342
1343            x = foo2_cg(inp)[0]
1344
1345            self.assertEqual(self.num_checkpoints(), 1)
1346            # out2 and out3 dies between the previous recording and the new one,
1347            # need to be manually deallocated after the checkpoint
1348
1349            self.assertEqual(all_live_block_count(), 1)
1350            del x
1351            self.assertEqual(all_live_block_count(), 0)
1352
1353        @skipIfRocm
1354        @unittest.skipIf(not IS_LINUX, "cpp contexts are linux only")
1355        @torch._inductor.config.patch("triton.cudagraph_trees_history_recording", True)
1356        def test_workspace_allocation_error(self):
1357            torch._C._cuda_clearCublasWorkspaces()
1358
1359            prev = torch._inductor.cudagraph_trees.clear_cublas_manager
1360
1361            try:
1362                torch._inductor.cudagraph_trees.clear_cublas_manager = (
1363                    contextlib.nullcontext
1364                )
1365
1366                @torch.compile()
1367                def foo(x, y):
1368                    return x @ x
1369
1370                inps = [torch.rand([400, 400], device="cuda") for _ in range(2)]
1371
1372                thrown = False
1373                try:
1374                    foo(*inps)
1375                except Exception as e:
1376                    thrown = True
1377                    self.assertTrue(
1378                        "at::cuda::blas::gemm<float>" in str(e)
1379                        or "at::cuda::blas::gemm_internal_cublas<float>" in str(e)
1380                    )
1381                    self.assertTrue(
1382                        "getCurrentCUDABlasHandle" in str(e)
1383                        or "getNewWorkspace" in str(e)
1384                    )
1385
1386                self.assertTrue(thrown)
1387
1388            finally:
1389                torch._C._cuda_clearCublasWorkspaces()
1390                torch._inductor.cudagraph_trees.clear_cublas_manager = prev
1391                torch._inductor.cudagraph_trees.get_container(
1392                    self.device_idx
1393                ).tree_manager = None
1394
1395        def test_peristed_output_livenes(self):
1396            @torch.compile
1397            def foo(x):
1398                return x + x
1399
1400            for _ in range(3):
1401                foo(torch.rand([2, 2], device="cuda"))
1402
1403            node = self.get_manager().current_node
1404            self.assertEqual(len(list(node.path_live_weakrefs())), 0)
1405
1406            out = foo(torch.rand([2, 2], device="cuda"))
1407            self.assertTrue(out is node.cached_tensor_outputs[0])
1408            self.assertEqual(len(list(node.path_live_weakrefs())), 1)
1409
1410            out_ref = out[0:]
1411            del out
1412            self.assertEqual(len(list(node.path_live_weakrefs())), 1)
1413
1414            del out_ref
1415            self.assertEqual(len(list(node.path_live_weakrefs())), 0)
1416
1417        @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
1418        def test_tensor_no_longer_in_pool(self):
1419            def foo(args):
1420                x = args[0]
1421                args.clear()
1422                return x + 1, x + 2
1423
1424            inp = torch.rand([4], device="cuda")
1425            inp_list = [inp]
1426            foo_cg = self.cudagraphify_impl(foo, inp_list, ())
1427            x1, x2 = foo_cg(inp_list)
1428
1429            def foo2(args):
1430                x = args[0]
1431                args.clear()
1432                return [x * x * x]
1433
1434            inp_list = [x1]
1435            foo2_cg = self.cudagraphify_impl(foo2, inp_list, ())
1436            foo2_cg(inp_list)
1437
1438            del x1, x2
1439            # TODO make configurable
1440
1441            x1, x2 = foo_cg([inp])
1442            self.assertEqual(self.num_checkpoints(), 0)
1443
1444            # input location has changed, should force recompile and checkpointing
1445            foo2_cg([torch.zeros_like(x1)])
1446
1447            self.assertEqual(self.num_checkpoints(), 1)
1448            self.assertEqual(self.get_root_children(), [2])
1449
1450        @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
1451        def test_checkpoint_shared_output_storage_deallocation(self):
1452            def foo(args):
1453                x = args[0]
1454                args.clear()
1455                x_tmp = x + 1
1456                return x[0], x[1]
1457
1458            inp = torch.rand([2, 2], device="cuda")
1459            inp_list = [inp]
1460            foo_cg = self.cudagraphify_impl(foo, inp_list, ())
1461            foo_cg(inp_list)
1462            foo_cg([inp])
1463
1464            x1, x2 = foo_cg([inp])
1465            inp = [x1]
1466
1467            def foo2(args):
1468                x = args[0]
1469                args.clear()
1470                y = x * x
1471                return y[0], y[1]
1472
1473            foo2_cg = self.cudagraphify_impl(foo2, inp, ())
1474            foo2_cg(inp)
1475
1476            self.assertEqual(self.num_checkpoints(), 1)
1477            self.assertEqual(
1478                x1.untyped_storage().data_ptr(), x2.untyped_storage().data_ptr()
1479            )
1480            self.assertEqual(all_live_block_count(), 1)
1481            del x1
1482            self.assertEqual(all_live_block_count(), 1)
1483            del x2
1484            self.assertEqual(all_live_block_count(), 0)
1485
1486        @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
1487        def test_cleanup(self):
1488            def test_closure():
1489                @torch.compile
1490                def foo(x):
1491                    return x + 1 + 2, x * 10
1492
1493                foo(torch.rand([4], device="cuda"))
1494                return foo(torch.rand([4], device="cuda"))
1495
1496            out1, out2 = test_closure()
1497            torch._dynamo.reset()
1498
1499            # TODO - deallocate on tensor deallocation
1500            # self.assertTrue(self.get_manager() is not None)
1501            # del out1
1502            # self.assertTrue(self.get_manager() is not None)
1503            # del out2
1504            self.assertTrue(self.get_manager() is None)
1505
1506        @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
1507        def test_forward_backward(self):
1508            @torch.compile
1509            def foo(x):
1510                y = x * 2
1511                return torch.sin(y) * torch.nn.functional.dropout(x, p=0.4)
1512
1513            inp = torch.rand([4, 4], requires_grad=True, device="cuda")
1514            out = foo(inp)
1515            out.sum().backward()
1516
1517            self.assertEqual(self.get_root_children(), [1])
1518
1519            # the three saved tensors should die in the backward
1520            # we kept alive the output
1521            self.assertEqual(self.curr_node().expected_dead_indices_before_graph, [])
1522            self.assertEqual(
1523                self.curr_node().expected_dead_indices_after_graph,
1524                [(0, 1), (0, 2)],
1525            )
1526            self.assertFalse(self.get_manager().new_graph_id().id == 0)
1527
1528        def test_separate_recordings(self):
1529            def foo_unopt(x, y):
1530                return (x + 1) @ y
1531
1532            foo = torch.compile(foo_unopt)
1533
1534            foo_unopt(
1535                torch.ones([20, 20], device="cuda"), torch.ones([20, 20], device="cuda")
1536            )
1537
1538            inps = [
1539                torch.ones([20, 20], device="cuda", requires_grad=False)
1540                for _ in range(2)
1541            ]
1542
1543            out = foo(*inps)
1544            torch.cuda.synchronize()
1545            foo(*inps)
1546            torch.cuda.synchronize()
1547            foo(*inps)
1548            torch.cuda.synchronize()
1549
1550            foo_unopt(
1551                torch.ones([20, 20], device="cuda"), torch.ones([20, 20], device="cuda")
1552            )
1553
1554            inps2 = [
1555                torch.rand([40, 40], device="cuda", requires_grad=False)
1556                for _ in range(2)
1557            ]
1558
1559            foo(*inps2)
1560            foo(*inps2)
1561            foo(*inps2)
1562
1563            # two separate roots
1564            self.assertEqual(self.get_root_children(), [0, 0])
1565
1566        def test_alias_of_parameter(self):
1567            class AliasMod(nn.Module):
1568                def __init__(self) -> None:
1569                    super().__init__()
1570                    self.param = torch.nn.Parameter(torch.rand([20, 20], device="cuda"))
1571
1572                def forward(self, x):
1573                    return self.param[0], self.param, self.param + x
1574
1575            @torch.compile(mode="reduce-overhead")
1576            def foo(mod, inp):
1577                return mod(inp)
1578
1579            inp = torch.rand([20, 20], device="cuda")
1580            mod = AliasMod()
1581
1582            storage_ref = torch.multiprocessing.reductions.StorageWeakRef(
1583                mod.param.untyped_storage()
1584            )
1585
1586            for _ in range(3):
1587                outs = foo(mod, inp)
1588
1589            self.assertEqual(mod(inp), outs)
1590
1591            self.assertFalse(storage_ref.expired())
1592
1593            node = self.get_manager().current_node
1594            self.assertEqual(len(list(node.path_live_weakrefs())), 1)
1595
1596        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
1597        @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
1598        def test_unstable_ptr(self):
1599            import torch
1600
1601            @torch.compile(mode="reduce-overhead")
1602            def foo(m, inp):
1603                return m(inp)
1604
1605            def f():
1606                l = []
1607                m = torch.nn.Linear(20, 20).cuda()
1608                for _ in range(4):
1609                    inp = torch.rand([20, 20], device="cuda")
1610                    foo(m, inp)
1611                    m.weight.data = torch.rand([20, 20], device="cuda")
1612
1613            self.assertRaises(RuntimeError, f)
1614
1615        @requires_multigpu()
1616        def test_manager_per_device(self):
1617            def test():
1618                def foo(args):
1619                    x = args[0]
1620                    args.clear()
1621                    return (x + 3,)
1622
1623                inp = torch.rand([20, 20], device="cuda:1")
1624
1625                inp_list = [inp]
1626                foo_cg = tree_cudagraphify_impl(
1627                    foo,
1628                    inp_list,
1629                    (),
1630                    device_index=1,
1631                    is_backward=False,
1632                    is_inference=True,
1633                )
1634                for _ in range(3):
1635                    self.assertEqual(foo_cg([inp]), foo([inp]))
1636
1637                self.assertTrue(self.get_manager(device_index=0) is None)
1638                self.assertFalse(self.get_manager(device_index=1) is None)
1639
1640            test()
1641            self.assertTrue(self.get_manager(device_index=1) is None)
1642
1643        def test_error_on_dealloc_use(self):
1644            @torch.compile()
1645            def foo(x):
1646                return x * x * x
1647
1648            inp = torch.rand([4], device="cuda")
1649            out = foo(inp)
1650            out2 = foo(inp)
1651
1652            with self.assertRaisesRegex(Exception, "overwritten by a subsequent run."):
1653                out + out
1654
1655            foo(inp)
1656
1657            with self.assertRaisesRegex(Exception, "overwritten by a subsequent run."):
1658                out2 + out2
1659
1660        @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn")
1661        def test_conv_benchmark(self):
1662            with torch.backends.cudnn.flags(
1663                enabled=True, benchmark=True, deterministic=False
1664            ):
1665                m = torch.nn.Conv2d(5, 6, [3, 3]).cuda()
1666                inp = torch.randn([2, 5, 16, 16]).cuda()
1667
1668                @torch.compile()
1669                def foo(m, inp):
1670                    return m(inp)
1671
1672                foo(m, inp)
1673
1674        def test_single_stream_use(self):
1675            @torch.compile()
1676            def foo(x):
1677                return (x * x * x).relu()
1678
1679            inp = torch.rand([4], device="cuda", requires_grad=True)
1680            streams = set()
1681            streams_init = {seg["stream"] for seg in get_all_cudagraph_segments()}
1682            for _ in range(4):
1683                foo(inp).sum().backward()
1684
1685            streams = {
1686                seg["stream"] for seg in get_all_cudagraph_segments()
1687            } - streams_init
1688            self.assertEqual(len(streams), 1)
1689            self.assertFalse(self.get_manager().new_graph_id().id == 0)
1690
1691        @torch._dynamo.config.patch("assume_static_by_default", False)
1692        def test_dynamic_backward(self):
1693            def foo(x):
1694                x = torch.cat([x, x])
1695                return torch.addmm(x, x, x).relu(), x.size(0)
1696
1697            opt_foo = torch.compile(mode="reduce-overhead")(foo)
1698
1699            def run_test(foo, inp):
1700                r, s = foo(inp)
1701                r.sum().backward()
1702                g = inp.grad.clone()
1703                inp.grad = None
1704                r = r.clone()
1705                return r, s, g
1706
1707            def run_big_test(inp):
1708                r0, s0, g0 = run_test(foo, inp)
1709                r1, s1, g1 = run_test(opt_foo, inp)
1710                r2, s2, g2 = run_test(opt_foo, inp)
1711                self.assertEqual(r0, r1)
1712                self.assertEqual(r0, r2)
1713                self.assertEqual(s0, s1)
1714                self.assertEqual(s0, s2)
1715                self.assertEqual(g0, g1)
1716                self.assertEqual(g0, g2)
1717
1718            inp = torch.randn(2, 4, device="cuda", requires_grad=True)
1719            run_big_test(inp)
1720
1721            inp = torch.randn(3, 6, device="cuda", requires_grad=True)
1722            run_big_test(inp)
1723
1724        def test_dynamic_warmup(self):
1725            COUNTER = 0
1726
1727            def f(inps):
1728                i, x = inps
1729                inps.clear()
1730                nonlocal COUNTER
1731                COUNTER += 1
1732                return x * 2
1733
1734            x = torch.randn(2, device="cuda")
1735            inp_list = [2, x]
1736            foo_cg = self.cudagraphify_impl(f, inp_list, ())
1737            foo_cg(inp_list)  # warmup
1738            foo_cg([2, x])  # record
1739            foo_cg([2, x])  # replay
1740            self.assertEqual(COUNTER, 2)
1741
1742            # Switching the size will require a warmup again
1743            x = torch.randn(3, device="cuda")
1744            inp_list = [3, x]
1745            foo_cg(inp_list)  # warmup
1746            foo_cg([3, x])  # record
1747            foo_cg([3, x])  # replay
1748            self.assertEqual(COUNTER, 4)
1749
1750        def test_forward_generation(self):
1751            def foo(x):
1752                return x * x * x
1753
1754            def foo2(x):
1755                return x * 12
1756
1757            foo_opt = torch.compile(foo)
1758            foo2_opt = torch.compile(foo2)
1759            ones = torch.ones([4, 4], device="cuda", requires_grad=True)
1760
1761            out = foo_opt(ones)
1762            out2 = foo2_opt(out)
1763
1764            self.assertEqual(all_live_block_count(), 2)
1765
1766            self.assertTrue(self.get_manager().running_forwards_with_pending_backwards)
1767
1768            out2.sum().backward()
1769            self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
1770
1771            del out
1772            del out2
1773
1774            foo2_opt(foo_opt(ones)).sum().backward()
1775
1776            out = foo_opt(ones.detach())
1777            self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
1778            self.assertFalse(self.get_manager().new_graph_id().id == 0)
1779
1780        def test_warn_on_pending_backward(self):
1781            @torch.compile
1782            def foo(x):
1783                return x * x * x
1784
1785            out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))
1786            out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))
1787
1788            warnings.resetwarnings()
1789            with warnings.catch_warnings(record=True) as w:
1790                out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))
1791
1792            FileCheck().check(
1793                "Unable to hit fast path of CUDAGraphs because of pending"
1794            ).run(str(w[0]))
1795            self.assertTrue(self.get_manager().new_graph_id().id == 0)
1796
1797        def test_mark_step(self):
1798            @torch.compile
1799            def foo(x):
1800                return x * x * x
1801
1802            torch.compiler.cudagraph_mark_step_begin()
1803            out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))
1804
1805            torch.compiler.cudagraph_mark_step_begin()
1806            out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))
1807            self.assertFalse(self.get_manager().new_graph_id().id == 0)
1808
1809        @torch._dynamo.config.patch("capture_scalar_outputs", True)
1810        def test_incompatible_cudagraph_ops_item(self):
1811            @torch.compile(mode="reduce-overhead")
1812            def foo(x):
1813                return x.item()
1814
1815            # NB: This doesn't work with float, because float unbacked codegen
1816            # is currently broken.  But testing the float case here is also
1817            # awkward, because we plan to Tensor-ify the float compute, and as
1818            # a result we'd actually expect this to work with cuda graphs!
1819            with capture_stderr() as captured_output:
1820                self.assertEqual(foo(torch.tensor(3, device="cuda")), 3)
1821                self.assertEqual(foo(torch.tensor(6, device="cuda")), 6)
1822
1823            # NOTE: this test is named after incompatible ops, but is not skipping due to incompatible ops.
1824            # This should get fixed.
1825            FileCheck().check(
1826                "skipping cudagraphs due to cpu device (_local_scalar_dense)"
1827            ).run(captured_output[0])
1828            self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
1829
1830        @torch._dynamo.config.patch("compiled_autograd", True)
1831        def test_compiled_autograd_static_input_params(self):
1832            @torch.compile(mode="reduce-overhead")
1833            def bwd(loss):
1834                loss.backward()
1835
1836            model = torch.nn.Linear(10, 10, bias=False, device="cuda")
1837            x = torch.randn(10, 10, device="cuda")
1838            for i in range(5):
1839                out = model(x)
1840                bwd(out.sum())
1841                model.weight.grad = None
1842
1843            # i=0, 0 copies (warmup)
1844            # i=1, 2 copies (record, 1/3 inputs marked as static)
1845            # i>1, 0 copies (run)
1846            self.assertEqual(
1847                counters["inductor"]["cudagraph_recorded_non_static_inputs"], 2
1848            )
1849
1850        @torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
1851        def test_incompatible_cudagraph_ops_nonzero(self):
1852            @torch.compile(mode="reduce-overhead")
1853            def foo(x):
1854                return x.nonzero()
1855
1856            with capture_stderr() as captured_output:
1857                self.assertEqual(
1858                    foo(torch.tensor([1, 0, 2], device="cuda")),
1859                    torch.tensor([[0], [2]]),
1860                )
1861                self.assertEqual(
1862                    foo(torch.tensor([1, 0, 0], device="cuda")), torch.tensor([[0]])
1863                )
1864
1865            FileCheck().check("skipping cudagraphs due to ['incompatible ops']").run(
1866                captured_output[0]
1867            )
1868            self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
1869
1870        @torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
1871        def test_incompatible_cudagraph_ops_nonzero_graph_breaks(self):
1872            @torch.compile(mode="reduce-overhead")
1873            def foo(x):
1874                y = x.nonzero()  # skip
1875                torch._dynamo.graph_break()
1876                return y.nonzero()  # skip 2 times (due to recompile)
1877
1878            foo(torch.tensor([1, 0, 2], device="cuda"))
1879            foo(torch.tensor([1, 0, 0], device="cuda"))
1880
1881            self.assertEqual(counters["inductor"]["cudagraph_skips"], 3)
1882
1883        @torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
1884        def test_incompatible_cudagraph_ops_nonzero_backend(self):
1885            @torch.compile(backend="cudagraphs")
1886            def foo(x):
1887                return x.nonzero()
1888
1889            with capture_stderr() as captured_output:
1890                self.assertEqual(
1891                    foo(torch.tensor([1, 0, 2], device="cuda")),
1892                    torch.tensor([[0], [2]]),
1893                )
1894                self.assertEqual(
1895                    foo(torch.tensor([1, 0, 0], device="cuda")), torch.tensor([[0]])
1896                )
1897
1898            FileCheck().check(
1899                "skipping cudagraphs due to incompatible op (nonzero)"
1900            ).run(captured_output[0])
1901            self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
1902
1903        def test_storage_access_error(self):
1904            x = torch.rand([4], device="cuda")
1905            torch._C._set_storage_access_error_msg(x, "custom error msg")
1906
1907            with self.assertRaisesRegex(Exception, "custom error msg"):
1908                device = x.untyped_storage()
1909
1910        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
1911        @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
1912        def test_static_inputs_address_mutation_log(self):
1913            class Goo(torch.nn.Module):
1914                def __init__(self) -> None:
1915                    super().__init__()
1916                    self.linear = torch.nn.Linear(2, 2, device="cuda")
1917
1918                def forward(self, x) -> torch.Tensor:
1919                    return self.linear(x)
1920
1921            class Foo(torch.nn.Module):
1922                def __init__(self) -> None:
1923                    super().__init__()
1924                    self.static_tensor = torch.zeros((2, 2), device="cuda")
1925                    self.goo = Goo()
1926
1927                def forward(self, x) -> torch.Tensor:
1928                    self.static_tensor.add_(torch.ones((2, 2), device="cuda"))
1929                    return self.static_tensor + x + self.goo(x)
1930
1931            foo = Foo()
1932            foo = torch.compile(foo, mode="reduce-overhead")
1933            inp = torch.rand((2, 2), device="cuda")
1934
1935            for _ in range(3):
1936                foo(inp)
1937
1938            # mutates static input tensors' addresses
1939            foo.static_tensor = torch.ones((2, 2), device="cuda")
1940            foo.goo.linear.bias = torch.nn.Parameter(torch.ones((2,), device="cuda"))
1941
1942            with self.assertRaisesRegex(
1943                Exception,
1944                r"static input data pointer changed.\n"
1945                r"input name: primals_2. data pointer changed from .* to .*. input stack trace:(?s).*"
1946                r"input name: primals_3. data pointer changed from .* to .*. input stack trace:.*,"
1947                r" in forward\n.* self.static_tensor.add\_\(torch.ones\(\(2, 2\), device=\"cuda\"\)\).*\n",
1948            ):
1949                self.curr_node().run(
1950                    [foo.goo.linear.weight, foo.goo.linear.bias, foo.static_tensor, inp]
1951                )
1952
1953        def _run_iter(self, param, fn):
1954            fwd_output = fn(torch.ones(2, 2), param)
1955            fwd_output.sum().backward()
1956            grad_output = param.grad.clone().detach()
1957            param.grad = None
1958            return fwd_output, grad_output
1959
1960        def _assert_equal_multi_loop(self, param, fn_eager, fn_compiled):
1961            exp_output, exp_grad = self._run_iter(param, fn_eager)
1962            for _ in range(5):
1963                compiled_output, compiled_grad = self._run_iter(param, fn_compiled)
1964                self.assertEqual(exp_output, compiled_output)
1965                self.assertEqual(exp_grad, compiled_grad)
1966
1967        def run_static_input_param_test(self, fn_eager, num_graphs):
1968            with torch.device("cuda"):
1969                fn_compiled = torch.compile(fn_eager, mode="reduce-overhead")
1970
1971                p1 = torch.nn.Parameter(torch.rand([2, 2]))
1972                self._assert_equal_multi_loop(p1, fn_eager, fn_compiled)
1973
1974                p2 = torch.nn.Parameter(torch.rand([2, 2]))
1975                self._assert_equal_multi_loop(p2, fn_eager, fn_compiled)
1976
1977                # Run p1 again to ensure we reuse the previous recording
1978                self._assert_equal_multi_loop(p1, fn_eager, fn_compiled)
1979
1980                self.assertEqual(self.get_manager().new_graph_id().id, num_graphs)
1981
1982        def _module_test(self, mod, name="weight", param_wrapping=True):
1983            with torch.device("cuda"):
1984
1985                def fn(x, mod):
1986                    return mod(x)
1987
1988                fn_compiled = torch.compile(fn, mode="reduce-overhead", fullgraph=True)
1989
1990                def run_test_iter(mod, fn):
1991                    fwd_output = fn(torch.ones(2, 2), mod)
1992                    fwd_output.sum().backward()
1993                    grad_output = mod.weight.grad.clone().detach()
1994                    mod.zero_grad()
1995                    return fwd_output, grad_output
1996
1997                def run_test():
1998                    exp_output, exp_grad = run_test_iter(mod, fn)
1999                    for _ in range(5):
2000                        compiled_output, compiled_grad = run_test_iter(mod, fn_compiled)
2001                        self.assertEqual(exp_output, compiled_output)
2002                        self.assertEqual(exp_grad, compiled_grad)
2003
2004                run_test()
2005                old_attr = getattr(mod, name)
2006                modified_attr = torch.rand_like(old_attr)
2007                if param_wrapping:
2008                    modified_attr = torch.nn.Parameter(modified_attr)
2009                setattr(mod, name, modified_attr)
2010                run_test()
2011                # Run original version to verify we reuse the other recording
2012                setattr(mod, name, old_attr)
2013                run_test()
2014
2015                # Fwd + bwd graphs for each version of the function => 4 graphs
2016                self.assertEqual(self.get_manager().new_graph_id().id, 4)
2017
2018        @torch._dynamo.config.patch("error_on_recompile", True)
2019        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
2020        def test_multi_dispatch_single_compile_param_inputs(self):
2021            # Verify that we can record multiple cudagraphs for a single
2022            # compiled function with param inputs
2023            def fn(x, y):
2024                return x * y
2025
2026            # Fwd + bwd graphs for each version of the function => 4 graphs
2027            self.run_static_input_param_test(fn, 4)
2028
2029        @torch._dynamo.config.patch("error_on_recompile", True)
2030        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
2031        def test_multi_dispatch_single_compile_builtin_module(self):
2032            # Verify that we don't recompile when changing the param of a builtin module
2033            # and that we record another cudagraph
2034            # Note: Linear is a builtin module so we enable that config setting above
2035            self._module_test(torch.nn.Linear(2, 3, device="cuda"))
2036
2037        @torch._dynamo.config.patch("error_on_recompile", True)
2038        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
2039        def test_multi_dispatch_single_compile_builtin_module_buffers(self):
2040            # Verify that we don't recompile when changing the buffer of a builtin module
2041            # and that we record another cudagraph
2042            self._module_test(
2043                torch.nn.BatchNorm1d(2, device="cuda"),
2044                name="running_mean",
2045                param_wrapping=False,
2046            )
2047
2048        @torch._inductor.config.patch("triton.cudagraphs", True)
2049        @torch._dynamo.config.patch("error_on_recompile", True)
2050        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
2051        def test_multi_dispatch_custom_module(self):
2052            # Test that we can correctly dispatch multiple graphs
2053            # if params of a custom module change
2054            class TestModule(torch.nn.Module):
2055                def __init__(self, param) -> None:
2056                    super().__init__()
2057                    self.weight = param
2058
2059                def forward(self, x):
2060                    return self.weight * x
2061
2062            self._module_test(
2063                TestModule(torch.nn.Parameter(torch.rand([2, 2], device="cuda")))
2064            )
2065
2066        @torch._dynamo.config.patch("error_on_recompile", True)
2067        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
2068        def test_multi_dispatch_custom_module_buffer(self):
2069            # Test that we can correctly dispatch multiple graphs
2070            # if buffers of a custom module change
2071            class TestModule(torch.nn.Module):
2072                def __init__(self, param, buf) -> None:
2073                    super().__init__()
2074                    self.weight = param
2075                    self.buf = torch.nn.Buffer(buf)
2076
2077                def forward(self, x):
2078                    return x * self.weight + self.buf
2079
2080            self._module_test(
2081                TestModule(
2082                    torch.nn.Parameter(torch.rand([2, 2], device="cuda")),
2083                    torch.rand([2, 2], device="cuda"),
2084                ),
2085                name="buf",
2086                param_wrapping=False,
2087            )
2088
2089        @torch._inductor.config.patch("triton.cudagraphs", True)
2090        @torch._dynamo.config.patch("error_on_recompile", True)
2091        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
2092        def test_multi_dispatch_child_node(self):
2093            # Test that we can correctly dispatch multiple graphs if a child node
2094            # in the tree has stable input pointers change
2095            def fn(x, p):
2096                # Graph 1
2097                y = x * x
2098                torch._dynamo.graph_break()
2099                # Graph 2
2100                return y * p
2101
2102            # We have 5 graphs here
2103            #            Graph 1
2104            #       /                \
2105            # Graph 2 w/ p1     Graph 2 w/ p2
2106            # and then two backward graphs
2107            self.run_static_input_param_test(fn, 5)
2108
2109        @torch._dynamo.config.patch("error_on_recompile", True)
2110        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
2111        def test_multi_dispatch_parent_node(self):
2112            def fn(x, p):
2113                # Graph 1
2114                y = x * p
2115                torch._dynamo.graph_break()
2116                # Graph 2
2117                return y + x
2118
2119            # We have 6 graphs here
2120            #    Graph 1 w/ p1    Graph 1 w/ p2
2121            #          |                |
2122            #     Graph 2 (v1)     Graph 2 (v2)
2123            # There are two versions of graph 2 because
2124            # we re-record due to different memory state after running the
2125            # two versions of Graph 1
2126            # and then two backward graphs
2127            self.run_static_input_param_test(fn, 6)
2128
2129        @torch._dynamo.config.patch("error_on_recompile", True)
2130        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
2131        @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
2132        @torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 0)
2133        def test_fallback_to_eager_if_recompiling_too_many_times(self):
2134            class Foo(torch.nn.Module):
2135                def __init__(self) -> None:
2136                    super().__init__()
2137                    self.param = torch.nn.Parameter(torch.rand([2, 2], device="cuda"))
2138
2139                def forward(self, x):
2140                    return x * self.param
2141
2142            with capture_stderr() as captured_output:
2143                # We have 3 graphs here
2144                #             None
2145                #       /                           \
2146                # (fwd w/ p1, Graph 0)            (bwd w/p2, Graph2)
2147                # (bwd w/ p1, Graph 1)
2148                # All other graphs are skipped because we hit the max recording limit
2149                # (=0 for each node and function pair)
2150                fn_compiled = torch.compile(Foo(), mode="reduce-overhead")
2151                for _ in range(3):
2152                    fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
2153
2154                # Change static tensor address
2155                fn_compiled.param.data = torch.rand([2, 2], device="cuda")
2156                fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
2157                self.assertEqual(self.get_manager().new_graph_id().id, 3)
2158
2159            FileCheck().check(
2160                "skipping cudagraph due to function 0 exceeding max re-recording limit (=0) "
2161                "on cudagraph node None due to static input data pointer changed."
2162            ).run(captured_output[0])
2163            self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
2164
2165        @torch._dynamo.config.patch("error_on_recompile", True)
2166        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
2167        @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
2168        @torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 0)
2169        def test_fallback_to_eager_if_recompiling_too_many_times_warn_only_once(self):
2170            class Foo(torch.nn.Module):
2171                def __init__(self) -> None:
2172                    super().__init__()
2173                    self.param = torch.nn.Parameter(torch.rand([2, 2], device="cuda"))
2174
2175                def forward(self, x):
2176                    return x * self.param
2177
2178            with capture_stderr() as captured_output:
2179                with torch.device("cuda"):
2180                    # We have 3 graphs here
2181                    #             None
2182                    #       /                           \
2183                    # (fwd w/ p1, Graph 0)            (bwd w/p2, Graph2)
2184                    # (bwd w/ p1, Graph 1)
2185                    # All other graphs are skipped because we hit the max recording limit
2186                    # (=0 for each node and function pair)
2187                    fn_compiled = torch.compile(Foo(), mode="reduce-overhead")
2188                    for _ in range(3):
2189                        fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
2190
2191                    for _ in range(5):
2192                        # Change static tensor address
2193                        fn_compiled.param.data = torch.rand([2, 2], device="cuda")
2194                        fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
2195
2196            FileCheck().check_count(
2197                "skipping cudagraph due to function 0 exceeding max re-recording limit (=0) "
2198                "on cudagraph node None due to static input data pointer changed.",
2199                1,
2200                exactly=True,
2201            ).check_count(
2202                "skipping cudagraph due to function 1 exceeding max re-recording limit (=0) "
2203                "on cudagraph node None due to static input data pointer changed.",
2204                1,
2205                exactly=True,
2206            ).run(
2207                captured_output[0]
2208            )
2209            self.assertEqual(counters["inductor"]["cudagraph_skips"], 2)
2210
2211        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
2212        @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
2213        @torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 0)
2214        def test_fallback_to_eager_if_recompiling_too_many_times_due_to_cudagraph_managed_tensor(
2215            self,
2216        ):
2217            # By setting triton.cudagraph_support_input_mutation=True, we force re-record
2218            # if cudagraph managed tensor addresses changed.
2219            @torch.compile(mode="reduce-overhead")
2220            def foo(x):
2221                return x + 1
2222
2223            @torch.compile(mode="reduce-overhead")
2224            def goo(x):
2225                return x * 2
2226
2227            for _ in range(3):
2228                torch.compiler.cudagraph_mark_step_begin()
2229                inp = torch.rand((2, 3), device="cuda")
2230                y = foo(inp)
2231                z = goo(y)
2232
2233            with capture_stderr() as captured_output:
2234                torch.compiler.cudagraph_mark_step_begin()
2235                x = torch.rand(2, 3, device="cuda")
2236                y = foo(x)
2237                y_clone = y.clone()
2238                z = goo(y_clone)
2239
2240            # eager function should run successfully
2241            for _ in range(5):
2242                torch.compiler.cudagraph_mark_step_begin()
2243                x = torch.rand(2, 3, device="cuda")
2244                y = foo(x)
2245                y_clone = y.clone()
2246                z = goo(y_clone)
2247
2248            FileCheck().check_count(
2249                "skipping cudagraph due to function 1 exceeding max re-recording limit (=0) "
2250                "on cudagraph node 0 due to cudagraph managed tensor data pointer changed",
2251                1,
2252                exactly=True,
2253            ).run(captured_output[0])
2254            self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
2255
2256        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
2257        @torch._dynamo.config.patch("error_on_recompile", True)
2258        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
2259        @torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 1)
2260        def test_not_fallback_to_eager_if_have_not_recompiling_too_many_times(self):
2261            def fn(x, y):
2262                return x * y
2263
2264            # We have 4 graphs here
2265            #             None
2266            #       /                           \
2267            # (fwd w/ p1, Graph 0)            (fwd w/p2, Graph2)
2268            # (bwd w/ p1, Graph 1)            (bwd w/p2, Graph3)
2269            self.run_static_input_param_test(fn, 4)
2270            self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
2271
2272        def test_tensor_constant_mutation(self):
2273            class Foo(torch.nn.Module):
2274                def __init__(self) -> None:
2275                    super().__init__()
2276                    self.tensor_constant = torch.ones((2, 3), device="cuda")
2277
2278                def forward(self, x: torch.Tensor) -> torch.Tensor:
2279                    self.tensor_constant += 1
2280                    return x + self.tensor_constant
2281
2282            foo = Foo()
2283            foo = torch.compile(foo, mode="reduce-overhead")
2284            inp = torch.rand((2, 3), device="cuda")
2285            for _ in range(3):
2286                foo(inp)
2287
2288        @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
2289        def test_rerecord_if_static_input_address_changed(self):
2290            # By setting triton.cudagraph_support_input_mutation=True, we force re-record
2291            # if static tensor addresses changed.
2292            class Goo(torch.nn.Module):
2293                def __init__(self) -> None:
2294                    super().__init__()
2295                    self.linear = torch.nn.Linear(2, 2, device="cuda")
2296
2297                def forward(self, x) -> torch.Tensor:
2298                    return self.linear(x)
2299
2300            class Foo(torch.nn.Module):
2301                def __init__(self) -> None:
2302                    super().__init__()
2303                    self.register_buffer(
2304                        "static_tensor", torch.zeros((2, 2), device="cuda")
2305                    )
2306                    self.goo = Goo()
2307
2308                def forward(self, x) -> torch.Tensor:
2309                    self.static_tensor.add_(torch.ones((2, 2), device="cuda"))
2310                    return self.static_tensor + x + self.goo(x)
2311
2312            foo = Foo()
2313            foo = torch.compile(foo, mode="reduce-overhead")
2314            inp = torch.rand((2, 2), device="cuda")
2315
2316            for _ in range(3):
2317                foo(inp)
2318
2319            # mutates static input tensors' addresses
2320            foo.static_tensor = torch.ones((2, 2), device="cuda")
2321            foo.goo.linear.bias = torch.nn.Parameter(torch.ones((2,), device="cuda"))
2322
2323            if torch._dynamo.config.inline_inbuilt_nn_modules:
2324                for _ in range(3):
2325                    foo(inp)
2326            else:
2327                # Run with specific function id to avoid dynamo recompiling
2328                self.get_manager().run(
2329                    [
2330                        foo.goo.linear.weight,
2331                        foo.goo.linear.bias,
2332                        foo.static_tensor,
2333                        inp,
2334                    ],
2335                    FunctionID(0),
2336                )
2337
2338            self.assertEqual(self.get_manager().new_graph_id().id, 2)
2339
2340        @torch._inductor.config.patch("triton.cudagraph_dynamic_shape_warn_limit", 1)
2341        def test_skip_if_dynamic_shape_limit_reached1(self):
2342            class Mod(torch.nn.Module):
2343                def __init__(self) -> None:
2344                    super().__init__()
2345                    self.linear = torch.nn.Linear(3, 3, device="cuda")
2346
2347                def forward(self, x: torch.Tensor) -> torch.Tensor:
2348                    return self.linear(x)
2349
2350            def iter(batch_size: int, mod: torch.nn.Module):
2351                x = torch.rand((batch_size, 3), device="cuda")
2352                for _ in range(3):
2353                    mod(x)
2354
2355            mod = torch.compile(Mod(), mode="reduce-overhead")
2356
2357            with capture_stderr() as captured_output:
2358                for batch_size in range(10, 40, 10):
2359                    iter(batch_size, mod)
2360
2361            FileCheck().check(
2362                "CUDAGraph supports dynamic shapes by recording a new graph for each "
2363                "distinct input size. Recording too many CUDAGraphs may lead to "
2364                "extra overhead. We have observed 2 distinct sizes. "
2365                "Please consider the following options for better performance: "
2366                "a) padding inputs to a few fixed number of shapes; or b) set "
2367                "torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. "
2368                "Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None "
2369                "to silence this warning."
2370            ).run("\n".join(captured_output))
2371
2372        @torch._inductor.config.patch("triton.cudagraph_dynamic_shape_warn_limit", 1)
2373        def test_skip_if_dynamic_shape_limit_reached2(self):
2374            class Mod(torch.nn.Module):
2375                def __init__(self) -> None:
2376                    super().__init__()
2377                    self.attn = torch.nn.MultiheadAttention(
2378                        embed_dim=3, num_heads=3, device="cuda"
2379                    )
2380
2381                def forward(
2382                    self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
2383                ) -> torch.Tensor:
2384                    return self.attn(q, k, v)
2385
2386            mod = torch.compile(Mod(), mode="reduce-overhead")
2387
2388            def iter(batch_size: int, length: int):
2389                q = torch.rand((batch_size, length, 3), device="cuda")
2390                k = torch.rand((batch_size, length, 3), device="cuda")
2391                v = torch.rand((batch_size, length, 3), device="cuda")
2392                for _ in range(3):
2393                    mod(q, k, v)
2394
2395            with capture_stderr() as captured_output:
2396                for batch_size in range(10, 40, 10):
2397                    for length in range(10, 30, 10):
2398                        iter(batch_size, length)
2399
2400            print(captured_output)
2401            FileCheck().check(
2402                "CUDAGraph supports dynamic shapes by recording a new graph for each "
2403                "distinct input size. Recording too many CUDAGraphs may lead to "
2404                "extra overhead. We have observed 2 distinct sizes. "
2405                "Please consider the following options for better performance: "
2406                "a) padding inputs to a few fixed number of shapes; or b) set "
2407                "torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. "
2408                "Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None "
2409                "to silence this warning."
2410            ).run(captured_output[0])
2411
2412        @torch._inductor.config.patch("triton.cudagraph_dynamic_shape_warn_limit", 1)
2413        def test_warn_once_if_dynamic_shape_limit_reached(self):
2414            class Mod(torch.nn.Module):
2415                def __init__(self) -> None:
2416                    super().__init__()
2417                    self.linear = torch.nn.Linear(3, 3, device="cuda")
2418
2419                def forward(self, x: torch.Tensor) -> torch.Tensor:
2420                    return self.linear(x)
2421
2422            def iter(batch_size: int, mod: torch.nn.Module):
2423                x = torch.rand((batch_size, 3), device="cuda")
2424                for _ in range(3):
2425                    mod(x)
2426
2427            mod = torch.compile(Mod(), mode="reduce-overhead")
2428
2429            with capture_stderr() as captured_output:
2430                for batch_size in range(10, 200, 10):
2431                    iter(batch_size, mod)
2432
2433            print(captured_output)
2434
2435            FileCheck().check_count(
2436                "CUDAGraph supports dynamic shapes by recording a new graph for each "
2437                "distinct input size. Recording too many CUDAGraphs may lead to "
2438                "extra overhead. We have observed 2 distinct sizes. "
2439                "Please consider the following options for better performance: "
2440                "a) padding inputs to a few fixed number of shapes; or b) set "
2441                "torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. "
2442                "Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None "
2443                "to silence this warning.",
2444                1,
2445                exactly=True,
2446            ).run("\n".join(captured_output))
2447
2448        @torch._inductor.config.patch("cpp_wrapper", 1)
2449        def test_cpp_wrapper(self):
2450            def f(x):
2451                return torch.sin(x)
2452
2453            compiled = torch.compile(f, mode="reduce-overhead")
2454            example_input = torch.randn(10, device="cuda")
2455            compiled_result = self.run_twc(compiled, example_input)
2456            eager_result = f(example_input)
2457            self.assertEqual(compiled_result, eager_result)
2458
2459    instantiate_parametrized_tests(CudaGraphTreeTests)
2460
2461if __name__ == "__main__":
2462    from torch._inductor.test_case import run_tests
2463
2464    if not TEST_CUDA_GRAPH:
2465        if __name__ == "__main__":
2466            sys.exit(0)
2467        raise unittest.SkipTest("cuda graph test is skipped")
2468
2469    if HAS_CPU or HAS_CUDA:
2470        run_tests(needs="filelock")
2471