1# Owner(s): ["module: inductor"]
2import contextlib
3import dataclasses
4import functools
5import io
6import itertools
7import logging
8import os
9import re
10import subprocess
11import sys
12import unittest
13from importlib.machinery import SourceFileLoader
14from pathlib import Path
15from unittest import mock
16
17import torch
18import torch.nn as nn
19import torch.nn.functional as F
20from torch import _inductor as inductor
21from torch._dynamo import compiled_autograd, config
22from torch._dynamo.backends.debugging import aot_eager
23from torch._dynamo.utils import counters
24from torch._inductor import config as inductor_config
25from torch._inductor.test_case import run_tests, TestCase
26from torch.testing._internal.common_utils import skipIfWindows
27from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
28from torch.testing._internal.logging_utils import logs_to_string
29
30
31# Note: these tests are not run on Windows due to inductor_utils.HAS_CPU
32
33
34def make_compiler_fn(fullgraph=True, dynamic=True, backend="inductor"):
35    assert backend in ["inductor", "aot_eager"]
36
37    def _compiler_fn(gm):
38        """Same as torch.compile(), but counts the number of compiles"""
39
40        def _inner_compiler(gm_, example_inputs_):
41            counters["compiled_autograd"]["compiles"] += 1
42            if backend == "inductor":
43                return inductor.compile(gm_, example_inputs_)
44            elif backend == "aot_eager":
45                return aot_eager(gm_, example_inputs_)
46
47        return torch.compile(
48            gm, backend=_inner_compiler, fullgraph=fullgraph, dynamic=dynamic
49        )
50
51    return _compiler_fn
52
53
54compiler_fn = make_compiler_fn()
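# Usage sketch (the pattern exercised throughout this file): enable compiled autograd
# with the counting wrapper above so each backward pass is captured and compiled, e.g.:
#   with compiled_autograd.enable(compiler_fn):
#       loss.backward()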
55
56
57# TODO(jansel): hooks defined as lambdas cause recompiles in dynamo; we should fix that
58def hook1(grad):
59    return grad * 2
60
61
62def hook2(grads):
63    return (grads[0] + 1,)
64
65
66def hook3(gI, gO):
67    return (torch.sin(gI[0]) + gO[0],)
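
# Registration sketch (mirrors the hook tests below): hook1 is a Tensor hook,
# hook2 a grad_fn pre-hook, and hook3 a grad_fn hook, e.g.:
#   model[0].weight.register_hook(hook1)
#   result.grad_fn.register_prehook(hook2)
#   result.grad_fn.register_hook(hook3)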
68
69
70class TestCompiledAutograd(TestCase):
71    def setUp(self) -> None:
72        super().setUp()
73        torch._logging.set_logs(compiled_autograd_verbose=False)
74        config.compiled_autograd = False
75        compiled_autograd.reset()
76
77    def tearDown(self) -> None:
78        super().tearDown()
79        torch._logging.set_logs(compiled_autograd_verbose=False)
80        config.compiled_autograd = False
81        compiled_autograd.reset()
82
83    def check_output_and_recompiles(
84        self, fn, count=1, compiler_fn=compiler_fn, compile_fn=False
85    ):
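        # `count` may be an int (expected captures == expected compiles) or a
        # [captures, compiles] pair when the two differ (e.g. with graph breaks).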
86        if isinstance(count, list):
87            captures, compiles = count
88        else:
89            captures, compiles = count, count
90        with torch.autograd.set_multithreading_enabled(False):
91            torch._dynamo.reset()
92            counters["compiled_autograd"].clear()
93            torch.manual_seed(123)
94            expected = list(fn())
95            torch.manual_seed(123)
96            with compiled_autograd.enable(compiler_fn):
97                opt_fn = torch.compile(fn) if compile_fn else fn
98                actual = list(opt_fn())
99            self.assertEqual(expected, actual)
100            self.assertEqual(counters["compiled_autograd"]["captures"], captures)
101            self.assertEqual(counters["compiled_autograd"]["compiles"], compiles)
102
103    def run_as_subprocess(self, script) -> bytes:
104        try:
105            return subprocess.check_output(
106                [sys.executable, "-c", script],
107                stderr=subprocess.STDOUT,
108                # On Windows, opening the subprocess with the default CWD makes `import torch`
109                # fail, so just set CWD to this script's directory
110                cwd=os.path.dirname(os.path.realpath(__file__)),
111            )
112        except subprocess.CalledProcessError as e:
113            self.fail(f"Subprocess exited with return code: {e.returncode}")
114
115    def test_dynamo_flaky_segfault(self):
116        script = """
117import torch
118
119def main():
120    def compiler_fn(gm):
121        return torch.compile(gm, backend="eager")
122
123    def inner():
124        x = torch.randn(1000, 3000)
125        w = torch.randn(1000, 3000, requires_grad=True)
126        def model(i):
127            return torch.nn.functional.linear(i, w)
128        out = model(x)
129        loss = out.sum()
130        with torch._dynamo.compiled_autograd.enable(compiler_fn):
131            loss.backward()
132        assert w.grad is not None
133
134    inner()
135    torch._dynamo.reset()
136    inner()
137
138main()
139        """
140        # Run it three times to catch bad dynamo state resets
141        for _ in range(3):
142            self.run_as_subprocess(script)
143
144    def test_basic(self):
145        def fn():
146            model = torch.nn.Sequential(
147                torch.nn.Linear(4, 4),
148                torch.nn.ReLU(),
149                torch.nn.Linear(4, 4),
150                torch.nn.ReLU(),
151            )
152            x = torch.randn([2, 4])
153            result = model(x).sum()
154            result.backward()
155            yield model[0].weight.grad
156            yield model[0].bias.grad
157            yield model[2].weight.grad
158            yield model[2].bias.grad
159
160        self.check_output_and_recompiles(fn)
161
162    def test_cache_hit(self):
163        def fn():
164            for _ in range(3):
165                model = torch.nn.Sequential(
166                    torch.nn.Linear(4, 4),
167                    torch.nn.ReLU(),
168                    torch.nn.Linear(4, 4),
169                    torch.nn.ReLU(),
170                )
171                x = torch.randn([2, 4])
172                result = model(x).sum()
173                result.backward()
174                yield model[0].weight.grad
175                yield model[0].bias.grad
176                yield model[2].weight.grad
177                yield model[2].bias.grad
178
179        self.check_output_and_recompiles(fn)
180
181    def test_graph_break_custom_op(self):
182        @torch.library.custom_op("mylib::sin", mutates_args={})
183        def sin(x: torch.Tensor) -> torch.Tensor:
184            return x.sin()
185
186        def setup_context(ctx, inputs, output):
187            (x,) = inputs
188            ctx.save_for_backward(x)
189
190        def backward(ctx, grad):
191            (x,) = ctx.saved_tensors
192            return grad * x.cos()
193
194        sin.register_autograd(backward, setup_context=setup_context)
195
196        x = torch.randn(3, requires_grad=True)
197        y = sin(x.clone()).sum()
198        with compiled_autograd.enable(compiler_fn):
199            y.backward()
200
201    def test_tensor_grad_hook1(self):
202        def fn():
203            for _ in range(3):
204                model = torch.nn.Sequential(
205                    torch.nn.Linear(4, 4),
206                    torch.nn.ReLU(),
207                )
208                x = torch.randn([2, 4])
209
210                model[0].weight.register_hook(hook1)
211
212                result = model(x).sum()
213                result.backward()
214                yield model[0].weight.grad
215                yield model[0].bias.grad
216
217        self.check_output_and_recompiles(fn)
218
219    def test_tensor_grad_hook2(self):
220        def fn():
221            for _ in range(3):
222                model = torch.nn.Sequential(
223                    torch.nn.Linear(4, 4),
224                    torch.nn.ReLU(),
225                )
226                x = torch.randn([1, 4])
227
228                result = model(x).sum()
229                result.grad_fn.register_prehook(hook2)
230                result.backward()
231                yield model[0].weight.grad
232                yield model[0].bias.grad
233
234        self.check_output_and_recompiles(fn)
235
236    def test_tensor_grad_hook3(self):
237        def fn():
238            for _ in range(3):
239                model = torch.nn.Sequential(
240                    torch.nn.Linear(4, 4),
241                    torch.nn.ReLU(),
242                )
243                x = torch.randn([1, 4])
244
245                result = model(x).sum()
246                result.grad_fn.register_hook(hook3)
247                result.backward()
248                yield model[0].weight.grad
249                yield model[0].bias.grad
250
251        self.check_output_and_recompiles(fn)
252
253    def test_torch_compile(self):
254        def fn():
255            model = torch.nn.Sequential(
256                torch.nn.Linear(4, 4),
257                torch.nn.Sigmoid(),
258            )
259            opt_model = torch.compile(model, fullgraph=True)
260
261            for _ in range(3):
262                x = torch.randn([1, 4])
263
264                result = opt_model(x).sum()
265                result.backward()
266                yield model[0].weight.grad
267                yield model[0].bias.grad
268                model.zero_grad()
269
270        self.check_output_and_recompiles(fn)
271
272    def test_torch_compile_api_inductor(self):
273        def fn():
274            torch.manual_seed(123)
275            model = torch.nn.Sequential(
276                torch.nn.Linear(4, 4),
277                torch.nn.Sigmoid(),
278            )
279
280            res = []
281            for _ in range(3):
282                x = torch.randn([1, 4])
283
284                result = model(x).sum()
285                result.backward()
286                res.append(model[0].weight.grad)
287                res.append(model[0].bias.grad)
288                model.zero_grad()
289            return res
290
291        expected = fn()
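        # Patching config.compiled_autograd=True makes the torch.compile'd fn below also
        # capture and compile the backward pass with compiled autograd.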
292        with config.patch(compiled_autograd=True):
293            compiled_fn = torch.compile(fn)
294        actual = compiled_fn()
295        self.assertEqual(expected, actual)
296        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
297
298    def test_torch_compile_api_aot_eager(self):
299        def fn():
300            torch.manual_seed(123)
301            model = torch.nn.Sequential(
302                torch.nn.Linear(4, 4),
303                torch.nn.Sigmoid(),
304            )
305
306            res = []
307            for _ in range(3):
308                x = torch.randn([1, 4])
309
310                result = model(x).sum()
311                result.backward()
312                res.append(model[0].weight.grad)
313                res.append(model[0].bias.grad)
314                model.zero_grad()
315            return res
316
317        expected = fn()
318        with config.patch(compiled_autograd=True):
319            compiled_fn = torch.compile(fn, backend="aot_eager")
320        actual = compiled_fn()
321        self.assertEqual(expected, actual)
322        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
323
324    def test_torch_compile_api_eager(self):
325        def fn():
326            torch.manual_seed(123)
327            model = torch.nn.Sequential(
328                torch.nn.Linear(4, 4),
329                torch.nn.Sigmoid(),
330            )
331
332            res = []
333            for _ in range(3):
334                x = torch.randn([1, 4])
335
336                result = model(x).sum()
337                result.backward()
338                res.append(model[0].weight.grad)
339                res.append(model[0].bias.grad)
340                model.zero_grad()
341            return res
342
343        expected = fn()
344        with config.patch(compiled_autograd=True):
345            compiled_fn = torch.compile(fn, backend="eager")
346        actual = compiled_fn()
347        self.assertEqual(expected, actual)
348        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
349
350    def test_multiple_torch_compile(self):
351        model = torch.nn.Sequential(
352            torch.nn.Linear(4, 4),
353            torch.nn.Sigmoid(),
354        )
355        x = torch.randn([1, 4])
356
357        def fn():
358            result = model(x).sum()
359            result.backward()
360
361        model2 = torch.nn.Linear(4, 4)
362        x2 = torch.randn([1, 4])
363
364        def fn2():
365            result = model2(x2).sum()
366            result.backward()
367
368        no_ca1 = torch.compile(fn)
369        no_ca1()
370        self.assertEqual(counters["compiled_autograd"]["captures"], 0)
371        counters.clear()
372
373        with config.patch(compiled_autograd=True):
374            with_ca = torch.compile(fn2)
375            with_ca()
376            self.assertEqual(counters["compiled_autograd"]["captures"], 1)
377            counters.clear()
378
379        no_ca2 = torch.compile(fn)
380        no_ca2()
381        self.assertEqual(counters["compiled_autograd"]["captures"], 0)
382
383    def test_torch_compile_graph_break(self):
384        model = torch.nn.Sequential(
385            torch.nn.Linear(4, 4),
386            torch.nn.Sigmoid(),
387        )
388        x = torch.randn([1, 4])
389
390        @torch._dynamo.disable()
391        def fn():
392            result = model(x).sum()
393            result.backward()
394
395        with config.patch(compiled_autograd=True):
396            opt_fn = torch.compile(fn)
397            opt_fn()
398
399        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
400
401    def test_torch_compile_graph_break2(self):
402        model = torch.nn.Sequential(
403            torch.nn.Linear(4, 4),
404            torch.nn.Sigmoid(),
405        )
406        x = torch.randn([1, 4])
407
408        @torch._dynamo.disable()
409        def inner_fn(loss):
410            loss.backward()
411
412        def fn():
413            result = model(x).sum()
414            inner_fn(result)
415
416        with config.patch(compiled_autograd=True):
417            opt_fn = torch.compile(fn)
418            opt_fn()
419
420        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
421
422    def test_torch_compile_only_backward_call(self):
423        model = torch.nn.Sequential(
424            torch.nn.Linear(4, 4),
425            torch.nn.Sigmoid(),
426        )
427        x = torch.randn([1, 4])
428
429        result = model(x).sum()
430        with config.patch(compiled_autograd=True):
431            opt_bwd = torch.compile(lambda: result.backward())
432            opt_bwd()
433
434        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
435
436    def test_dynamo_boxed(self):
437        def get_placeholders(gm_):
438            placeholders = []
439            for node in gm_.graph.nodes:
440                if node.op == "placeholder":
441                    placeholders.append(node)
442            return placeholders
443
444        def eager_with_check(gm, is_bwd):
445            def inner_compiler(gm_, example_inputs_):
446                placeholders = get_placeholders(gm_)
447                if is_bwd:
448                    # should be boxed inputs
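                    # (compiled autograd hands the backward graph its inputs as a single
                    # boxed list, hence exactly one placeholder)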
449                    assert len(placeholders) == 1
450                else:
451                    assert len(placeholders) > 1
452
453                return gm_
454
455            return torch.compile(gm, backend=inner_compiler)
456
457        fwd_compiler_fn = functools.partial(eager_with_check, is_bwd=False)
458        bwd_compiler_fn = functools.partial(eager_with_check, is_bwd=True)
459
460        def fn(inputs):
461            args_0, args_1, args_2 = inputs
462            out = torch.mm(args_0, args_1)
463            out = torch.mm(out, args_2)
464            loss = out.sum()
465            with compiled_autograd.enable(bwd_compiler_fn):
466                loss.backward()
467            yield args_0.grad
468            yield args_1.grad
469            yield args_2.grad
470
471        inputs = [
472            torch.randn([1, 2], requires_grad=True),
473            torch.randn([2, 3], requires_grad=True),
474            torch.randn([3, 4], requires_grad=True),
475        ]
476
477        compiled_fn = eager_with_check(fn, is_bwd=False)
478        grads = list(compiled_fn(inputs))
479        self.assertEqual(len(grads), 3)
480        self.assertNotEqual(grads[0], None)
481        self.assertNotEqual(grads[1], None)
482        self.assertNotEqual(grads[2], None)
483
484    def test_inputs_aliasing_bytecode_attr_mutations(self):
485        # Freeze compiled autograd graph
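        # Drive AutogradCompilerInstance by hand: begin_capture returns proxies for the
        # inputs, accumulate_grad_ traces a grad update onto `param`, and end_capture
        # returns the runtime wrapper plus the compiled graph.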
486        compiler = torch._dynamo.compiled_autograd.AutogradCompilerInstance(compiler_fn)
487        param = torch.ones(100)
488        activ = torch.ones(100) * 2
489        inputs = [param, activ]
490        proxies, _, _ = compiler.begin_capture(inputs=inputs, sizes=[], scalars=[])
491        param_proxy, activ_proxy = proxies
492        buf = activ_proxy * 2
493        torch.ops.inductor.accumulate_grad_.default(param_proxy, buf)
494        runtime_wrapper, compiled_fn = compiler.end_capture(buf)
495
496        def bytecode_hook(code, out_code):
497            import dis
498            import sys
499
500            if sys.version_info < (3, 11):
501                call_op = "CALL_FUNCTION"
502            else:
503                call_op = "CALL"
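                # (CPython 3.11 replaced the CALL_FUNCTION opcode with CALL)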
504
505            insts = list(dis.get_instructions(out_code))
506            call_graph_idx = next(
507                i for i, inst in enumerate(insts) if inst.opname == call_op
508            )
509            # pre-graph should alias: inputs_ref_0 = inputs[0]
510            matches = [
511                inst
512                for inst in insts[:call_graph_idx]
513                if inst.opname == "STORE_FAST" and inst.argval == "inputs_ref_0"
514            ]
515            self.assertTrue(len(matches) == 1)
516            # post-graph should access inputs_ref_0 instead of inputs
517            matches = [
518                inst for inst in insts[call_graph_idx:] if inst.argval == "inputs"
519            ]
520            self.assertTrue(len(matches) == 0)
521            matches = [
522                inst
523                for inst in insts[call_graph_idx:]
524                if inst.opname == "LOAD_FAST" and inst.argval == "inputs_ref_0"
525            ]
526            self.assertTrue(len(matches) == 1)
527
528        torch._dynamo.reset()
529        handle = torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook)
530        try:
531            runtime_wrapper(
532                compiled_fn=compiled_fn,
533                inputs=[param, activ],
534                sizes=(),
535                scalars=(),
536                hooks=(),
537            )
538        finally:
539            handle.remove()
540
541    def test_inputs_aliasing_bytecode_stack_restore(self):
542        logging.getLogger().setLevel(logging.WARNING)
543        from torch.testing._internal.logging_tensor import LoggingTensor
544
545        # Create a graph that allows inputs stealing
546        def forward(inputs):
547            add = inputs[0] + 1
548            add_1 = add + inputs[1]  # handled in suffix for tensor subclass
549            out = add_1.cpu()
550            return (out,)
551
552        gm = torch.fx.symbolic_trace(forward)
553        torch._dynamo.utils.set_locals_to_steal(gm, ["inputs"])
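        # Marking "inputs" as stealable lets the generated bytecode alias its elements
        # into locals and clear the list, which the len(inputs) == 0 check below relies on.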
554        compiled_fn = torch.compile(gm)
555
556        inputs = [
557            torch.ones(1000000, dtype=torch.float32),
558            LoggingTensor(torch.ones(1)),
559        ]
560
561        def bytecode_hook(code, out_code):
562            import dis
563            import sys
564
565            if sys.version_info < (3, 11):
566                call_op = "CALL_FUNCTION"
567            else:
568                call_op = "CALL"
569
570            insts = list(dis.get_instructions(out_code))
571            call_graph_idx = next(
572                i for i, inst in enumerate(insts) if inst.opname == call_op
573            )
574            # pre-graph should alias: inputs_ref_0 = inputs[0]
575            matches = [
576                inst
577                for inst in insts[:call_graph_idx]
578                if inst.opname == "STORE_FAST" and inst.argval == "inputs_ref_0"
579            ]
580            self.assertTrue(len(matches) == 1)
581            # post-graph should access inputs_ref_0 instead of inputs
582            matches = [
583                inst for inst in insts[call_graph_idx:] if inst.argval == "inputs"
584            ]
585            self.assertTrue(len(matches) == 0)
586            matches = [
587                inst
588                for inst in insts[call_graph_idx:]
589                if inst.opname == "LOAD_FAST" and inst.argval == "inputs_ref_0"
590            ]
591            self.assertTrue(len(matches) == 1)
592
593        torch._dynamo.reset()
594        handle = torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook)
595        try:
596            out = compiled_fn(inputs)
597            self.assertTrue(len(inputs) == 0)
598        finally:
599            handle.remove()
600
601    def test_implicit_add(self):
602        def fn():
603            y = torch.randn(1, 4, requires_grad=True)
604
605            def model(x):
606                # y is used multiple times, so its gradients get accumulated (implicit add)
607                return torch.sigmoid(x * y + torch.sin(y) + torch.cos(y))
608
609            for _ in range(3):
610                x = torch.randn([1, 4])
611
612                result = model(x).sum()
613                result.backward()
614                yield result
615                yield y.grad
616                y.grad = None
617
618        self.check_output_and_recompiles(fn)
619
620    def test_output_nodes_all_leaves(self):
621        def fn():
622            y = torch.randn(1, 4, requires_grad=True)
623            z = torch.randn(1, 4, requires_grad=True)
624
625            def model(x):
626                return torch.sigmoid(x * z + torch.sin(y) + torch.cos(y))
627
628            for _ in range(3):
629                x = torch.randn([1, 4])
630
631                result = model(x).sum()
632                gy, gz = torch.autograd.grad(result, inputs=[y, z])
633                assert y.grad is None
634                assert z.grad is None
635                yield gy
636                yield gz
637
638        self.check_output_and_recompiles(fn)
639
640    def test_output_nodes_some_leaves(self):
641        def fn():
642            class UnreachableBwd(torch.autograd.Function):
643                @staticmethod
644                def forward(ctx, x):
645                    return x
646
647                @staticmethod
648                def backward(ctx, gO):
649                    raise RuntimeError
650
651            y = torch.randn(1, 4, requires_grad=True)
652            z = torch.randn(1, 4, requires_grad=True)
653
654            def model(x):
655                return torch.sigmoid(UnreachableBwd.apply(y) * z)
656
657            for _ in range(3):
658                x = torch.randn([1, 4])
659
660                result = model(x).sum()
661                gz = torch.autograd.grad(result, inputs=[z])
662                assert y.grad is None
663                assert z.grad is None
664                yield gz
665
666        self.check_output_and_recompiles(fn)
667
668    def test_no_output_nodes_all_leaves(self):
669        def fn():
670            y = torch.randn(1, 4, requires_grad=True)
671            z = torch.randn(1, 4, requires_grad=True)
672
673            def model(x):
674                return torch.sigmoid(x * z + torch.sin(y) + torch.cos(y))
675
676            for _ in range(3):
677                x = torch.randn([1, 4])
678                result = model(x).sum()
679                out = result.backward()
680                assert out is None
681                assert y.grad is not None
682                assert z.grad is not None
683                yield y.grad
684                yield z.grad
685                y.grad = None
686                z.grad = None
687
688        self.check_output_and_recompiles(fn)
689
690    def test_no_output_nodes_some_leaves(self):
691        def fn():
692            class UnreachableBwd(torch.autograd.Function):
693                @staticmethod
694                def forward(ctx, x):
695                    return x
696
697                @staticmethod
698                def backward(ctx, gO):
699                    raise RuntimeError
700
701            y = torch.randn(1, 4, requires_grad=True)
702            z = torch.randn(1, 4, requires_grad=True)
703            a = torch.randn(1, 4, requires_grad=True)
704
705            def model(x):
706                return torch.sigmoid(x * y * z * UnreachableBwd.apply(a))
707
708            for _ in range(3):
709                x = torch.randn([1, 4])
710                result = model(x).sum()
711                out = result.backward(inputs=[y, z])
712                assert out is None
713                assert y.grad is not None
714                assert z.grad is not None
715                assert a.grad is None
716                yield y.grad
717                yield z.grad
718                y.grad = None
719                z.grad = None
720
721        self.check_output_and_recompiles(fn)
722
723    def test_no_output_nodes_different_leaves_will_recompile(self):
724        def fn():
725            def fwd(x, y, z):
726                out = x * y  # MulBackward0
727                out2 = out * z  # MulBackward0
728                return out2.sum()  # SumBackward0
729
730            x = torch.randn(5, requires_grad=True)
731            y = torch.randn(5, requires_grad=True)
732            z = torch.randn(5, requires_grad=True)
733            loss = fwd(x, y, z)
734            torch.compile(lambda: torch.autograd.backward(loss, inputs=[x]))()
735            yield x.grad
736            x.grad = None
737
738            loss = fwd(x, y, z)
739            torch.compile(lambda: torch.autograd.backward(loss, inputs=[y]))()
740            yield y.grad
741
742        # Guarded by TensorArg id; the mismatch is on the last MulBackward0
743        self.check_output_and_recompiles(fn, 2)
744
745    def test_dynamic_shapes(self):
746        def fn():
747            model = torch.nn.Sequential(
748                torch.nn.Linear(4, 4),
749                torch.nn.ReLU(),
750                torch.nn.Linear(4, 4),
751                torch.nn.ReLU(),
752            )
753            opt_model = torch.compile(model, dynamic=True)
754
755            for b in range(10, 100, 10):
756                x = torch.randn([b, 4])
757                result = opt_model(x).sum()
758                result.backward()
759                yield model[0].weight.grad
760                yield model[0].bias.grad
761                yield model[2].weight.grad
762                yield model[2].bias.grad
763                model.zero_grad()
764
765        # TODO(jansel): we should be able to get this count to 1
766        self.check_output_and_recompiles(fn, count=2)
767
768    def test_accumulate_without_zero(self):
769        def fn():
770            model = torch.nn.Sequential(
771                torch.nn.Linear(4, 4),
772                torch.nn.ReLU(),
773                torch.nn.Linear(4, 4),
774                torch.nn.ReLU(),
775            )
776            opt_model = torch.compile(model, dynamic=True)
777
778            for _ in range(10):
779                x = torch.randn([10, 4])
780                result = opt_model(x).sum()
781                result.backward()
782                yield model[0].weight.grad.clone()
783                yield model[0].bias.grad.clone()
784                yield model[2].weight.grad.clone()
785                yield model[2].bias.grad.clone()
786
787        self.check_output_and_recompiles(fn, count=2)
788
789    def test_inplace_grad_update(self):
790        def fn():
791            model = torch.nn.Sequential(
792                torch.nn.Linear(4, 4),
793                torch.nn.ReLU(),
794            )
795            opt_model = torch.compile(model, dynamic=True)
796
797            for _ in range(10):
798                w_grad = torch.rand_like(model[0].weight)
799                b_grad = torch.rand_like(model[0].bias)
800                model[0].weight.grad = w_grad
801                model[0].bias.grad = b_grad
802
803                x = torch.randn([10, 4])
804                result = opt_model(x).sum()
805                result.backward()
806                assert model[0].weight.grad is w_grad
807                assert model[0].bias.grad is b_grad
808                yield w_grad.clone()
809                yield b_grad.clone()
810
811        self.check_output_and_recompiles(fn, count=1)
812
813    @unittest.skipIf(not HAS_CUDA, "requires cuda")
814    def test_issue106555(self):
815        DEVICE = torch.device("cuda:0")
816        NUM_FEATURES = 256
817
818        def bias_sigmoid_mul(x1, x2, bias):
819            x2 = torch.sigmoid(x2 + bias)
820            y = x1 * x2
821            return y
822
823        bias_sigmoid_mul_jit = torch.compile(bias_sigmoid_mul)
824
825        class ModuleWithJit(nn.Module):
826            def __init__(self) -> None:
827                super().__init__()
828                self.linear_1 = nn.Linear(NUM_FEATURES, NUM_FEATURES, bias=True)
829                self.linear_2 = nn.Linear(NUM_FEATURES, NUM_FEATURES, bias=False)
830                self.linear_2_bias = nn.Parameter(torch.zeros(NUM_FEATURES))
831
832            def forward(self, input_tensor):
833                x1 = self.linear_1(input_tensor)
834                x2 = self.linear_2(input_tensor)
835                output = bias_sigmoid_mul_jit(x1, x2, self.linear_2_bias)
836                return output
837
838        class Model(nn.Module):
839            def __init__(self) -> None:
840                super().__init__()
841                self.module_with_jit_1 = ModuleWithJit()
842                self.module_with_jit_2 = ModuleWithJit()
843
844            def forward(self, x, gradient_checkpointing: bool):
845                if gradient_checkpointing:
846                    y = torch.utils.checkpoint.checkpoint(
847                        self._forward, x, use_reentrant=True
848                    )
849                else:
850                    y = self._forward(x)
851                return y
852
853            def _forward(self, x):
854                x = x + self.module_with_jit_1(x)
855                x = x + self.module_with_jit_2(x.transpose(-2, -3)).transpose(-2, -3)
856                return x
857
858        torch.cuda.set_device(device=DEVICE)
859        torch.manual_seed(1234567890)
860        model = Model()
861        model.train()
862        model.to(device=DEVICE)
863        model_parameters = list(model.parameters())
864
865        torch.manual_seed(1234567890)
866        input_tensor = torch.randn(1, 128, 256, NUM_FEATURES).to(device=DEVICE)
867        input_tensor.requires_grad = True
868        target_tensor = torch.randn(1, 128, 256, NUM_FEATURES).to(
869            dtype=input_tensor.dtype, device=DEVICE
870        )
871
872        for iteration in range(10):
873            for param in model_parameters:
874                param.grad = None
875            output_tensor = model(
876                x=input_tensor.clone(),
877                gradient_checkpointing=True,
878            )
879            loss = torch.mean(torch.abs(target_tensor - output_tensor))
880            loss.backward()
881
882    def test_keep_graph_simple(self):
883        x = torch.tensor([2.0], requires_grad=True)
884        y = x**2
885
886        # First backward pass; keep the computation graph
887        y.backward(retain_graph=True)
888        self.assertEqual(x.grad, torch.Tensor([4]))  # dy/dx at x=2 is 4
889
890        # Note: this will run under both the eager and compiled regimes.
891        def fn():
892            # Reset the gradients
893            x.grad = torch.tensor([0.0])
894            # Second and third backward passes; keep the computation graph
895            y.backward(retain_graph=True)
896            self.assertEqual(x.grad, torch.Tensor([4]))  # dy/dx at x=2 is 4
897            return x.grad
898
899        self.check_output_and_recompiles(fn, count=1)
900
901    def test_keep_graph_usage_after_compiled(self):
902        x = torch.tensor([2.0], requires_grad=True)
903        y = x**2
904
905        # First backward pass; keep the computation graph
906        def eager_check():
907            y.backward(retain_graph=True)
908            self.assertEqual(x.grad, torch.Tensor([4]))  # dy/dx at x=2 is 4
909            x.grad = torch.tensor([0.0])
910
911        eager_check()
912
913        for i in range(0, 5):
914            with compiled_autograd.enable(compiler_fn):
915                eager_check()
916
917            eager_check()
918
919    def test_custom_fn_saved_tensors(self):
920        def fn():
921            class MySin(torch.autograd.Function):
922                @staticmethod
923                def forward(ctx, x):
924                    ctx.save_for_backward(x)
925                    return torch.sin(x)
926
927                @staticmethod
928                def backward(ctx, gO):
929                    (x,) = ctx.saved_tensors
930                    return gO * torch.cos(x)
931
932            for i in [10, 100, 10, 15, 20, 25]:
933                x = torch.arange(0.0, i, requires_grad=True)
934                out = MySin.apply(x)
935                loss = out.sum()
936                loss.backward()
937                yield x.grad
938
939        self.check_output_and_recompiles(fn, count=2)
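        # The expected count of 2 here (and in the similar tests below) likely comes from
        # one capture specialized to the first size plus one dynamic-shape recapture once
        # the size changes; later sizes then hit the cache.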
940
941    def test_custom_fn_saved_multiple_tensors(self):
942        def fn():
943            class MyFn(torch.autograd.Function):
944                @staticmethod
945                def forward(ctx, x, y):
946                    ctx.save_for_backward(x, y)
947                    return torch.sin(x), torch.sin(y)
948
949                @staticmethod
950                def backward(ctx, gO_x, gO_y):
951                    (x, y) = ctx.saved_tensors
952                    return gO_x * torch.cos(x), gO_y * torch.cos(y)
953
954            for i in [10, 100, 10, 15, 20, 25]:
955                x = torch.arange(0.0, i, requires_grad=True)
956                y = torch.arange(0.0, i, requires_grad=True)
957                out1, out2 = MyFn.apply(x, y)
958                loss = (out1 * out2).sum()
959                loss.backward()
960                yield x.grad
961
962        self.check_output_and_recompiles(fn, count=2)
963
964    def test_custom_fn_saved_multiple_tensors_dedup(self):
965        def fn():
966            class MyFn(torch.autograd.Function):
967                @staticmethod
968                def forward(ctx, x):
969                    ctx.save_for_backward(x, x)
970                    return torch.sin(x)
971
972                @staticmethod
973                def backward(ctx, gO):
974                    (x1, x2) = ctx.saved_tensors
975                    return gO * torch.cos(x1) * torch.cos(x2)
976
977            for i in [10, 100, 10, 15, 20, 25]:
978                x = torch.arange(0.0, i, requires_grad=True)
979                out = MyFn.apply(x)
980                loss = out.sum()
981                loss.backward()
982                yield x.grad
983
984        self.check_output_and_recompiles(fn, count=2)
985
986    def test_custom_fn_saved_shape_tensor(self):
987        def fn():
988            class MyFn(torch.autograd.Function):
989                @staticmethod
990                def forward(ctx, x):
991                    ctx.save_for_backward(x)
992                    return x
993
994                @staticmethod
995                def backward(ctx, gO):
996                    (x,) = ctx.saved_tensors
997                    return gO * x.shape[0]
998
999            for i in [10, 100, 10, 15, 20, 25]:
1000                x = torch.arange(0.0, i, requires_grad=True)
1001                out = MyFn.apply(x)
1002                loss = out.sum()
1003                loss.backward()
1004                yield x.grad
1005
1006        self.check_output_and_recompiles(fn, count=2)
1007
1008    def test_custom_fn_saved_attr(self):
1009        def fn():
1010            class MyFn(torch.autograd.Function):
1011                @staticmethod
1012                def forward(ctx, x):
1013                    ctx.shape = x.shape
1014                    return x
1015
1016                @staticmethod
1017                def backward(ctx, gO):
1018                    x_shape = ctx.shape[0]
1019                    return gO * x_shape
1020
1021            for i in [10, 100, 10, 15, 20, 25]:
1022                x = torch.arange(0.0, i, requires_grad=True)
1023                out = MyFn.apply(x)
1024                loss = out.sum()
1025                loss.backward()
1026                yield x.grad
1027
1028        self.check_output_and_recompiles(
1029            fn, count=2, compiler_fn=make_compiler_fn(fullgraph=False)
1030        )
1031
1032    def test_custom_fn_multiple_grads(self):
1033        def fn():
1034            class MyFn(torch.autograd.Function):
1035                @staticmethod
1036                def forward(ctx, x, y):
1037                    return x + y, y
1038
1039                @staticmethod
1040                def backward(ctx, gO_1, gO_2):
1041                    return gO_1, gO_2
1042
1043            for i in [10, 100, 10, 15, 20, 25]:
1044                x = torch.arange(0.0, i, requires_grad=True)
1045                y = torch.arange(0.0, i, requires_grad=True)
1046                out1, out2 = MyFn.apply(x, y)
1047                loss = (out1 + out2).sum()
1048                loss.backward()
1049                yield x.grad
1050                yield y.grad
1051
1052        self.check_output_and_recompiles(fn, count=2)
1053
1054    def test_custom_fn_non_variable_input(self):
1055        def fn():
1056            class MyFn(torch.autograd.Function):
1057                @staticmethod
1058                def forward(ctx, x, y, z):
1059                    return x * 2, y * 3, z * 4
1060
1061                @staticmethod
1062                def backward(ctx, gO_1, gO_2, gO_3):
1063                    return gO_1, gO_2, gO_3
1064
1065            for i in [10, 100, 10, 15, 20, 25]:
1066                x = torch.arange(0.0, i, requires_grad=True)
1067                y = 1
1068                z = torch.arange(0.0, i, requires_grad=True)
1069                out1, out2, out3 = MyFn.apply(x, y, z)
1070                loss = (out1 + out2 + out3).sum()
1071                loss.backward()
1072                yield x
1073                yield y
1074                yield z
1075
1076        self.check_output_and_recompiles(fn, count=2)
1077
1078    @unittest.skipIf(not HAS_CUDA, "requires cuda")
1079    def test_logging_tensor_flaky(self) -> None:
1080        # When a triton-involving test runs first and then test_inputs_aliasing_bytecode_stack_restore runs,
1081        # it fails with:
1082        #   - pytest: `TypeError: unsupported operand type(s) for +: 'Tensor' and 'LoggingTensor'`
1083        #   - python: `TypeError: not all arguments converted during string formatting`
1084
1085        # 1. some triton-involving test
1086        def fn():
1087            def _fn(x):
1088                return x
1089
1090            x = torch.arange(
1091                1, 10, requires_grad=True, dtype=torch.float16, device="cuda"
1092            )
1093            out = _fn(x)
1094            loss = out.sum()
1095            loss.backward()
1096
1097        with compiled_autograd.enable(compiler_fn):
1098            fn()
1099
1100        logging.getLogger().setLevel(
1101            logging.WARNING
1102        )  # triton setup overwrote it to INFO
1103        # 2. test_inputs_aliasing_bytecode_stack_restore
1104        from torch.testing._internal.logging_tensor import LoggingTensor
1105
1106        def forward(inputs):
1107            add = inputs[0] + 1
1108            add_1 = add + inputs[1]
1109            out = add_1.cpu()
1110            return (out,)
1111
1112        gm = torch.fx.symbolic_trace(forward)
1113        print(gm.print_readable())
1114        torch._dynamo.utils.set_locals_to_steal(gm, ["inputs"])
1115        compiled_fn = torch.compile(gm)
1116
1117        inputs = [
1118            torch.ones(1000000, dtype=torch.float32),
1119            LoggingTensor(torch.ones(1)),
1120        ]
1121
1122        compiled_fn(inputs)
1123
1124    @unittest.skipIf(not HAS_CUDA, "requires cuda")
1125    def test_custom_fn_output_metadata(self):
1126        def my_compiler_fn(gm):
1127            for node in gm.graph.nodes:
1128                if isinstance(node.target, torch._ops.OpOverload):
1129                    assert (
1130                        node.target._name != "aten::_to_copy"
1131                    ), "there should be no implicit copies (e.g. dtype casting)"
1132
1133            def inner_compiler(gm_, example_inputs_):
1134                counters["compiled_autograd"]["compiles"] += 1
1135                return inductor.compile(gm_, example_inputs_)
1136
1137            return torch.compile(
1138                gm, backend=inner_compiler, fullgraph=True, dynamic=True
1139            )
1140
1141        def fn():
1142            class MyFn(torch.autograd.Function):
1143                @staticmethod
1144                def forward(ctx, x):
1145                    return x
1146
1147                @staticmethod
1148                def backward(ctx, gO):
1149                    return gO
1150
1151            x = torch.arange(
1152                1, 10, requires_grad=True, dtype=torch.float16, device="cuda"
1153            )
1154            x_view = x.view(3, 3)
1155            out = MyFn.apply(x_view)
1156            loss = out.sum()
1157            loss.backward()
1158            yield x.dtype
1159            yield x.device
1160            yield x.grad
1161
1162        self.check_output_and_recompiles(fn, count=1)
1163
1164    def test_custom_fn_with_same_graph(self):
1165        def fn():
1166            class MyFn1(torch.autograd.Function):
1167                @staticmethod
1168                def forward(ctx, x):
1169                    return x
1170
1171                @staticmethod
1172                def backward(ctx, gO):
1173                    return gO
1174
1175            # Same as MyFn1, but with a different autograd function id;
1176            # it should not reuse MyFn1's graph.
1177            class MyFn2(torch.autograd.Function):
1178                @staticmethod
1179                def forward(ctx, x):
1180                    return x
1181
1182                @staticmethod
1183                def backward(ctx, gO):
1184                    return gO
1185
1186            for myfn in [MyFn1, MyFn2, MyFn1, MyFn2]:
1187                x = torch.arange(0.0, 10, requires_grad=True)
1188                out = myfn.apply(x)
1189                loss = out.sum()
1190                loss.backward()
1191                yield x.grad
1192
1193        self.check_output_and_recompiles(
1194            fn, count=2
1195        )  # should compile once for MyFn1 and once for MyFn2
1196
1197    def test_custom_fn_dynamically_defined_class(self):
1198        def fn():
1199            def create_class(multiplier: int):
1200                class DynamicFn(torch.autograd.Function):
1201                    @staticmethod
1202                    def forward(ctx, x):
1203                        return x * multiplier
1204
1205                    @staticmethod
1206                    def backward(ctx, gO):
1207                        return gO * multiplier
1208
1209                return DynamicFn
1210
1211            for multiplier in [10, 20, 30]:
1212                x = torch.arange(0.0, 10, requires_grad=True)
1213                out = create_class(multiplier).apply(x)
1214                loss = out.sum()
1215                loss.backward()
1216                yield x.grad
1217
1218        self.check_output_and_recompiles(fn, count=3)
1219
1220    def test_custom_fn_bw_graph_break(self):
1221        def fn():
1222            class MySin(torch.autograd.Function):
1223                @staticmethod
1224                def forward(ctx, x):
1225                    ctx.save_for_backward(x)
1226                    return torch.sin(x)
1227
1228                @staticmethod
1229                def backward(ctx, gO):
1230                    print("graph break")
1231                    (x,) = ctx.saved_tensors
1232                    print("graph break")
1233                    return gO * torch.cos(x)
1234
1235            for i in [10, 100, 10, 15, 20, 25]:
1236                x = torch.arange(0.0, i, requires_grad=True)
1237                out = MySin.apply(x)
1238                loss = out.sum()
1239                loss.backward()
1240                yield x.grad
1241
1242        self.check_output_and_recompiles(
1243            fn, count=[2, 6], compiler_fn=make_compiler_fn(fullgraph=False)
1244        )
1245
1246    def test_custom_fn_compiled_fw_graph_break(self):
1247        def fn():
1248            class MySin(torch.autograd.Function):
1249                @staticmethod
1250                def forward(ctx, x):
1251                    print("graph break")
1252                    ctx.save_for_backward(x)
1253                    return torch.sin(x)
1254
1255                @staticmethod
1256                def backward(ctx, gO):
1257                    (x,) = ctx.saved_tensors
1258                    return gO * torch.cos(x)
1259
1260            opt_model = torch.compile(MySin.apply)
1261            for i in [10, 100, 10, 15, 20, 25]:
1262                x = torch.arange(0.0, i, requires_grad=True)
1263                out = opt_model(x)
1264                loss = out.sum()
1265                loss.backward()
1266                yield x.grad
1267
1268        self.check_output_and_recompiles(
1269            fn, count=2, compiler_fn=make_compiler_fn(fullgraph=False)
1270        )
1271        self.assertEqual(counters["stats"]["unique_graphs"], 5)  # 3 fw, 2 bw
1272
1273    def test_custom_fn_compiled_fw_bw_graph_break(self):
1274        def fn():
1275            class MySin(torch.autograd.Function):
1276                @staticmethod
1277                def forward(ctx, x):
1278                    print("graph break")
1279                    ctx.save_for_backward(x)
1280                    return torch.sin(x)
1281
1282                @staticmethod
1283                def backward(ctx, gO):
1284                    print("graph break")
1285                    (x,) = ctx.saved_tensors
1286                    return gO * torch.cos(x)
1287
1288            opt_model = torch.compile(MySin.apply)
1289            for i in [10, 100, 10, 15, 20, 25]:
1290                x = torch.arange(0.0, i, requires_grad=True)
1291                out = opt_model(x)
1292                loss = out.sum()
1293                loss.backward()
1294                yield x.grad
1295
1296        self.check_output_and_recompiles(
1297            fn, count=[2, 6], compiler_fn=make_compiler_fn(fullgraph=False)
1298        )
1299        self.assertEqual(counters["stats"]["unique_graphs"], 9)  # 3 fw, 6 bw
1300
1301    def test_mismatch_fake_tensor_mode(self, dynamic_shape=False):
1302        """
1303        Repro of the failure seen when training nanogpt with both compiled autograd
1304        and _LazyGraphModule. See https://github.com/pytorch/pytorch/pull/118981
1305        for more context.
1306        """
1307        B = 8
1308        x = torch.rand(B, 16)
1309        y = torch.rand(B, 16, requires_grad=True)
1310
1311        if dynamic_shape:
1312            torch._dynamo.mark_dynamic(x, 0)
1313            torch._dynamo.mark_dynamic(y, 0)
1314
1315        def f():
1316            y.grad = None
1317            out = x + y
1318
1319            # make sure the backward call does not trigger any error when
1320            # compiling the backward graph
1321            out.sum().backward()
1322            return out, y.grad
1323
1324        self.check_output_and_recompiles(f, compile_fn=True)
1325
1326    def test_mismatch_fake_tensor_mode_dynamic_shape(self):
1327        self.test_mismatch_fake_tensor_mode(dynamic_shape=True)
1328
1329    def test_accumulate_grad_accuracy(self):
1330        def fn():
1331            model = torch.nn.Sequential(
1332                torch.nn.Linear(2, 1, bias=False),
1333                torch.nn.Linear(1, 2, bias=False),
1334            )
1335            x = torch.randn(2, 2)
1336
1337            out = model(x)
1338            loss = out.sum()
1339            torch.manual_seed(0)
1340            loss.backward()
1341
1342            yield model[0].weight.grad
1343            yield model[1].weight.grad
1344
1345        self.check_output_and_recompiles(fn, 1)
1346
1347    def test_trace_run_with_rng_state(self):
1348        def sdpa(xq, xk):
1349            return F.scaled_dot_product_attention(xq, xk, xk, is_causal=True)
1350
1351        def g(xq_1, xk_1, xq_2, xk_2):
1352            # xq: (bs, n_local_heads, seqlen, head_dim)
1353            # xk: (bs, n_local_heads, cache_len + seqlen, head_dim)
1354            y1 = sdpa(xq_1, xk_1)
1355            y2 = torch.utils.checkpoint.checkpoint(
1356                sdpa, xq_2, xk_2, use_reentrant=False
1357            )
1358            y = torch.mul(y1, y2)
1359            z = torch.matmul(y, y)
1360            return z
1361
1362        def f():
1363            bs = 1
1364            n_local_heads = 1
1365            seqlen = 2
1366            head_dim = 2
1367            cache_len = 2
1368            xq_list = [
1369                torch.ones(
1370                    (bs, n_local_heads, seqlen, head_dim),
1371                    requires_grad=True,
1372                    device="cpu",
1373                )
1374                for _ in range(2)
1375            ]
1376            xk_list = [
1377                torch.ones(
1378                    (bs, n_local_heads, cache_len + seqlen, head_dim),
1379                    requires_grad=True,
1380                    device="cpu",
1381                )
1382                for _ in range(2)
1383            ]
1384            out = torch.compile(g, fullgraph=True)(
1385                xq_list[0], xk_list[0], xq_list[1], xk_list[1]
1386            )
1387            out.sum().backward()
1388            return out, *[x.grad for x in xq_list + xk_list]
1389
1390        """
1391        Walkthrough of what happens with `run_with_rng_state`:
1392        1. `run_with_rng_state` only shows up in the backward graph (this op is inserted by the partitioner).
1393        2. The Dynamo graph captured by Compiled Autograd looks like:
1394        ```
1395        ===== __compiled_fn_3 =====
1396        torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
1397            def forward(self, L_inputs_ : list):
1398                ...
1399                run_with_rng_state = torch.ops.higher_order.run_with_rng_state(
1400                    getitem_8,
1401                    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
1402                    getitem_3, getitem_4, getitem_4, 0.0, True,
1403                )
1404                ...
1405        ```
1406        3. We want to preserve this `run_with_rng_state` op when going through AOTAutograd. We do it by having special handling
1407        in `run_with_rng_state` op's py_functionalize_impl.
1408        """
1409
1410        def _run_with_rng_state_op_check(inductor_post_grad_graph):
1411            # Checks that `run_with_rng_state` op exists in Compiled Autograd's Inductor post-grad graph.
1412            op_set = {node.target for node in inductor_post_grad_graph.nodes}
1413            if torch.ops.higher_order.run_and_save_rng_state not in op_set:
1414                # This is the backward graph, so check for the `run_with_rng_state` op
1415                self.assertTrue(torch.ops.higher_order.run_with_rng_state in op_set)
1416
1417        with torch._inductor.config.patch(
1418            post_grad_custom_post_pass=_run_with_rng_state_op_check
1419        ):
1420            compiler_fn = make_compiler_fn(fullgraph=True)
1421
1422            def make_compiler_fn_with_op_check():
1423                def _compiler_fn(gm):
1424                    # Checks that `run_with_rng_state` op exists in Compiled Autograd's Dynamo graph.
1425                    self.assertTrue(
1426                        any(
1427                            node.target is torch.ops.higher_order.run_with_rng_state
1428                            for node in gm.graph.nodes
1429                        )
1430                    )
1431                    return compiler_fn(gm)
1432
1433                return _compiler_fn
1434
1435            compiler_fn_with_op_check = make_compiler_fn_with_op_check()
1436            self.check_output_and_recompiles(
1437                f, compiler_fn=compiler_fn_with_op_check, compile_fn=False
1438            )
1439
1440    def test_trace_auto_functionalized(self):
1441        torch.library.define(
1442            "testlib::foo",
1443            "(Tensor(a!) x) -> (Tensor)",
1444            tags=torch.Tag.pt2_compliant_tag,
1445        )
1446        torch.library.define(
1447            "testlib::foo_mutated",
1448            "(Tensor(a!) x) -> (Tensor)",
1449            tags=torch.Tag.pt2_compliant_tag,
1450        )
1451
1452        @torch.library.impl("testlib::foo", "cpu")
1453        def foo(x):
1454            x.add_(5)
1455            return x
1456
1457        @torch.library.impl("testlib::foo", "Meta")
1458        def foo_meta(x):
1459            return x
1460
1461        @torch.library.impl("testlib::foo_mutated", "CompositeImplicitAutograd")
1462        def foo_mutated(x):
1463            return torch.ops.testlib.foo(x)
1464
1465        def _get_custom_policy(must_recompute_list=None):
1466            def _custom_policy(ctx, func, *args, **kwargs):
1467                if must_recompute_list is not None and func in must_recompute_list:
1468                    return torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE
1469                else:
1470                    return torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE
1471
1472            return _custom_policy
1473
1474        def context_fn():
1475            must_recompute_list = [
1476                torch.ops.higher_order.auto_functionalized,
1477            ]
1478            return torch.utils.checkpoint.create_selective_checkpoint_contexts(
1479                _get_custom_policy(
1480                    must_recompute_list=must_recompute_list,
1481                ),
1482            )
1483
1484        def g(x):
1485            x = torch.matmul(x, x)
1486            torch.ops.testlib.foo_mutated(x)
1487            return torch.matmul(x, x)
1488
1489        def g_cp(x):
1490            return torch.utils.checkpoint.checkpoint(
1491                g, x, use_reentrant=False, context_fn=context_fn
1492            )
1493
1494        def f():
1495            inps = (torch.randn(4, 4, requires_grad=True),)
1496            output = torch.compile(g_cp, backend="aot_eager", fullgraph=True)(*inps)
1497            output.sum().backward()
1498            return output, inps[0].grad
1499
1500        """
1501        Walkthrough of what happens with `auto_functionalized`:
1502        1. `auto_functionalized` op is inserted into the graph during AOTAutograd functionalization.
1503           We force the op to be recomputed (by using SAC), so it appears in the backward graph.
1504        2. The AOT backward graph looks like:
1505        ```
1506        ===== Backward graph 0 =====
1507        def forward(self, primals_1: "f32[4, 4][4, 1]cpu", tangents_1: "f32[4, 4][4, 1]cpu"):
1508            ...
1509            X = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = mm)
1510            ...
1511            return (add_1,)
1512        ```
1513        3. The Compiled Autograd graph looks like:
1514        ```
1515        ===== Compiled autograd graph =====
1516        def forward(self, inputs, sizes, scalars, hooks):
1517            ...
1518            X = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = aot0_mm)
1519            ...
1520            return []
1521        ```
1522        4. The Dynamo graph captured by Compiled Autograd looks like:
1523        ```
1524        ===== __compiled_fn_3 =====
1525        def forward(self, L_inputs_ : list):
1526            ...
1527            X = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = aot0_mm)
1528            ...
1529            return (new_grad,)
1530        ```
1531        5. The Compiled Autograd's AOT "forward-only" graph looks like:
1532        ```
1533        ===== Forward graph 1 =====
1534        def forward(self, arg0_1: "f32[][]cpu", arg1_1: "f32[4, 4][4, 1]cpu"):
1535            ...
1536            X = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = mm)
1537            ...
1538            return (clone_1,)
1539        ```
1540        6. The `auto_functionalized` op should then be lowered using the normal lowering path in Inductor.
1541        """
1542
1543        compiler_fn = make_compiler_fn(fullgraph=True, backend="aot_eager")
1544
1545        def make_compiler_fn_with_op_check():
1546            def _compiler_fn(gm):
1547                # Checks that `auto_functionalized` op exists in Compiled Autograd's Dynamo graph.
1548                self.assertTrue(
1549                    any(
1550                        node.target is torch.ops.higher_order.auto_functionalized
1551                        for node in gm.graph.nodes
1552                    ),
1553                    f"`torch.ops.higher_order.auto_functionalized` op not found in {gm.graph}",
1554                )
1555                return compiler_fn(gm)
1556
1557            return _compiler_fn
1558
1559        compiler_fn_with_op_check = make_compiler_fn_with_op_check()
1560        self.check_output_and_recompiles(
1561            f, compiler_fn=compiler_fn_with_op_check, compile_fn=False
1562        )
1563
1564    def test_non_traceable_autograd_cpp_node(self):
1565        cpp_source = """
1566struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
1567  static constexpr bool is_traceable = false;
1568
1569  static torch::Tensor forward(
1570      torch::autograd::AutogradContext* ctx,
1571      const torch::Tensor& x) {
1572    return x;
1573  }
1574
1575  static torch::autograd::variable_list backward(
1576      torch::autograd::AutogradContext *ctx,
1577      torch::autograd::variable_list grad_output) {
1578    return grad_output;
1579  }
1580};
1581
1582torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) {
1583  return CustomOpAutogradFunction::apply(x);
1584}
1585
1586TORCH_LIBRARY(test_non_traceable_autograd_cpp_node, m) {
1587    m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
1588}
1589        """
1590
1591        module = torch.utils.cpp_extension.load_inline(
1592            name="test_non_traceable_autograd_cpp_node",
1593            cpp_sources=cpp_source,
1594            functions="custom_op_backed_by_autograd_fn",
1595            verbose=True,
1596        )
1597
1598        def fn():
1599            x = torch.ones(10, 10, requires_grad=True)
1600            out = torch.ops.test_non_traceable_autograd_cpp_node.custom_op_backed_by_autograd_fn(
1601                x
1602            )
1603            loss = out.sum()
1604            loss.backward()
1605
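            # With is_traceable=false, compiled autograd cannot capture this C++ node and is
            # expected to raise, pointing users at the support doc below.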
1606        with self.assertRaisesRegex(
1607            RuntimeError,
1608            "https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY/",
1609        ), compiled_autograd.enable(compiler_fn):
1610            fn()
1611
1612    @unittest.skip("Flaky, cache from test ordering affects test. #135369")
1613    def test_autograd_cpp_node(self):
1614        cpp_source = """
1615struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
1616  static constexpr bool is_traceable = true;
1617
1618  static torch::Tensor forward(
1619      torch::autograd::AutogradContext* ctx,
1620      const torch::Tensor& x) {
1621    return x;
1622  }
1623
1624  static torch::autograd::variable_list backward(
1625      torch::autograd::AutogradContext *ctx,
1626      torch::autograd::variable_list grad_output) {
1627    return grad_output;
1628  }
1629};
1630
1631torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) {
1632  return CustomOpAutogradFunction::apply(x);
1633}
1634
1635TORCH_LIBRARY(test_autograd_cpp_node, m) {
1636    m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
1637}
1638        """
1639
1640        module = torch.utils.cpp_extension.load_inline(
1641            name="test_autograd_cpp_node",
1642            cpp_sources=cpp_source,
1643            functions="custom_op_backed_by_autograd_fn",
1644            verbose=True,
1645        )
1646
1647        def fn():
1648            for i in [10, 100, 10, 20, 10]:
1649                x = torch.ones(i, i, requires_grad=True)
1650                out = torch.ops.test_autograd_cpp_node.custom_op_backed_by_autograd_fn(
1651                    x
1652                )
1653                loss = out.sum()
1654                loss.backward()
1655                yield x.grad
1656
1657        # compiles for 10 (static) and 100 (dynamic)
1658        self.check_output_and_recompiles(fn, 2)
1659
1660    def test_autograd_cpp_node_id(self):
1661        cpp_source = """
1662struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
1663  static constexpr bool is_traceable = true;
1664
1665  static torch::Tensor forward(
1666      torch::autograd::AutogradContext* ctx,
1667      const torch::Tensor& x) {
1668    return x;
1669  }
1670
1671  static torch::autograd::variable_list backward(
1672      torch::autograd::AutogradContext *ctx,
1673      torch::autograd::variable_list grad_output) {
1674    return grad_output;
1675  }
1676};
1677
1678struct CustomOpAutogradFunction2 : public torch::autograd::Function<CustomOpAutogradFunction2> {
1679  static constexpr bool is_traceable = true;
1680
1681  static torch::Tensor forward(
1682      torch::autograd::AutogradContext* ctx,
1683      const torch::Tensor& x) {
1684    return x;
1685  }
1686
1687  static torch::autograd::variable_list backward(
1688      torch::autograd::AutogradContext *ctx,
1689      torch::autograd::variable_list grad_output) {
1690    return grad_output;
1691  }
1692};
1693
1694torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) {
1695  return CustomOpAutogradFunction::apply(x);
1696}
1697
1698torch::Tensor custom_op_backed_by_autograd_fn2(torch::Tensor x) {
1699  return CustomOpAutogradFunction2::apply(x);
1700}
1701
1702TORCH_LIBRARY(test_autograd_cpp_node_id, m) {
1703    m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
1704    m.def("custom_op_backed_by_autograd_fn2", custom_op_backed_by_autograd_fn2);
1705}
1706        """
1707
1708        module = torch.utils.cpp_extension.load_inline(
1709            name="test_autograd_cpp_node_id",
1710            cpp_sources=cpp_source,
1711            functions="custom_op_backed_by_autograd_fn",
1712            verbose=True,
1713        )
1714
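            # Caching is keyed on the C++ node's type: repeated calls through the same
            # autograd::Function should reuse the first capture, while the second Function
            # type exercised below should add exactly one more.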
1715        def same_autograd_fn():
1716            def fn():
1717                x = torch.ones(10, 10, requires_grad=True)
1718                out = (
1719                    torch.ops.test_autograd_cpp_node_id.custom_op_backed_by_autograd_fn(
1720                        x
1721                    )
1722                )
1723                loss = out.sum()
1724                loss.backward()
1725                yield x.grad
1726
1727            yield from fn()  # compile
1728            yield from fn()  # reuse
1729            yield from fn()  # reuse
1730            yield from fn()  # reuse
1731
1732        self.check_output_and_recompiles(same_autograd_fn, 1)
1733
1734        def different_autograd_fn():
1735            def fn(op):
1736                x = torch.ones(10, 10, requires_grad=True)
1737                out = op(x)
1738                loss = out.sum()
1739                loss.backward()
1740                yield x.grad
1741
1742            op1 = torch.ops.test_autograd_cpp_node_id.custom_op_backed_by_autograd_fn
1743            op2 = torch.ops.test_autograd_cpp_node_id.custom_op_backed_by_autograd_fn2
1744            yield from fn(op1)  # compile
1745            yield from fn(op2)  # compile
1746            yield from fn(op1)  # reuse
1747            yield from fn(op2)  # reuse
1748
1749        self.check_output_and_recompiles(different_autograd_fn, 2)
1750
1751    def test_autograd_cpp_node_saved(self):
1752        cpp_source = """
1753struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
1754  static constexpr bool is_traceable = true;
1755
1756  static torch::Tensor forward(
1757      torch::autograd::AutogradContext* ctx,
1758      const torch::Tensor& x,
1759      const torch::Tensor& y,
1760      const torch::Tensor& fixed) {
1761    ctx->save_for_backward({x, y});
1762    ctx->saved_data["fixed_tensor"] = fixed;
1763    ctx->saved_data["bool"] = true;
1764    ctx->saved_data["int"] = 1;
1765    c10::List<std::string> list({"string"});
1766    ctx->saved_data["list"] = std::move(list);
1767    c10::Dict<std::string, double> dict;
1768    dict.insert("string", 1.0);
1769    ctx->saved_data["dict"] = std::move(dict);
1770    return x;
1771  }
1772
1773  static torch::autograd::variable_list backward(
1774      torch::autograd::AutogradContext *ctx,
1775      torch::autograd::variable_list grad_output) {
1776    const auto& saved_variables = ctx->get_saved_variables();
1777    assert(saved_variables.size() == 2);
1778    torch::Tensor x = saved_variables[0];
1779    torch::Tensor y = saved_variables[1];
1780    torch::Tensor fixed = ctx->saved_data["fixed_tensor"].toTensor();
1781    assert(ctx->saved_data["bool"].isBool());
1782    c10::SymInt i = ctx->saved_data["int"].toSymInt();
1783    c10::List<c10::IValue> list = ctx->saved_data["list"].toList();
1784    assert(list.size() == 1);
1785    assert(list.get(0).toStringRef() == "string");
1786    c10::Dict<c10::IValue, c10::IValue> dict = ctx->saved_data["dict"].toGenericDict();
1787    assert(dict.size() == 1);
1788    assert(dict.at("string") == 1.0);
1789
1790    torch::autograd::variable_list grad_inputs(3);
1791    grad_inputs[0] = x + y + torch::sum(fixed) + i;
1792    return grad_inputs;
1793  }
1794};
1795
1796torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x, const torch::Tensor& y, const torch::Tensor& fixed) {
1797  return CustomOpAutogradFunction::apply(x, y, fixed);
1798}
1799
1800TORCH_LIBRARY(test_autograd_cpp_node_saved, m) {
1801    m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
1802}
1803        """
1804
1805        module = torch.utils.cpp_extension.load_inline(
1806            name="test_autograd_cpp_node_saved",
1807            cpp_sources=cpp_source,
1808            functions="custom_op_backed_by_autograd_fn",
1809            verbose=True,
1810        )
1811
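            # Exercises the supported saved_data types (tensor, bool, int, string list, dict);
            # the 10/100 shape change still drives the usual static + dynamic pair of compiles.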
1812        def fn():
1813            fixed = torch.ones(2, 2)
1814            for i in [10, 100, 10, 20, 10]:
1815                x = torch.ones(i, i, requires_grad=True)
1816                y = torch.randn(i, i)
1817                out = torch.ops.test_autograd_cpp_node_saved.custom_op_backed_by_autograd_fn(
1818                    x, y, fixed
1819                )
1820                loss = out.sum()
1821                loss.backward()
1822                yield x.grad
1823
1824        self.check_output_and_recompiles(fn, 2)
1825
1826    def test_autograd_cpp_node_saved_dynamic(self):
1827        cpp_source = """
1828struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
1829  static constexpr bool is_traceable = true;
1830
1831  static torch::Tensor forward(
1832      torch::autograd::AutogradContext* ctx,
1833      const torch::Tensor& x) {
1834    ctx->save_for_backward({x});
1835    ctx->saved_data["dynamic"] = x.view(-1);
1836    return x;
1837  }
1838
1839  static torch::autograd::variable_list backward(
1840      torch::autograd::AutogradContext *ctx,
1841      torch::autograd::variable_list grad_output) {
1842    const auto& saved_variables = ctx->get_saved_variables();
1843    assert(saved_variables.size() == 1);
1844    torch::Tensor x = saved_variables[0];
1845    torch::Tensor z = ctx->saved_data["dynamic"].toTensor();
1846
1847    torch::autograd::variable_list grad_inputs(1);
1848    grad_inputs[0] = x + torch::sum(z);
1849    return grad_inputs;
1850  }
1851};
1852
1853torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x) {
1854  return CustomOpAutogradFunction::apply(x);
1855}
1856
1857TORCH_LIBRARY(test_autograd_cpp_node_saved_dynamic, m) {
1858    m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
1859}
1860        """
1861
1862        module = torch.utils.cpp_extension.load_inline(
1863            name="test_autograd_cpp_node_saved_dynamic",
1864            cpp_sources=cpp_source,
1865            functions="custom_op_backed_by_autograd_fn",
1866            verbose=True,
1867        )
1868
1869        def fn():
1870            for i in [10, 100, 10, 20, 10]:
1871                x = torch.ones(i, i, requires_grad=True)
1872                out = torch.ops.test_autograd_cpp_node_saved_dynamic.custom_op_backed_by_autograd_fn(
1873                    x
1874                )
1875                loss = out.sum()
1876                loss.backward()
1877                yield x.grad
1878
1879        # compiles for 10 (static) and 100 (dynamic)
1880        self.check_output_and_recompiles(fn, 2)
1881
1882    def test_autograd_cpp_node_saved_int(self):
1883        cpp_source = """
1884struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
1885  static constexpr bool is_traceable = true;
1886
1887  static torch::Tensor forward(
1888      torch::autograd::AutogradContext* ctx,
1889      const torch::Tensor& x,
1890      int64_t y) {
1891    ctx->save_for_backward({x});
1892    ctx->saved_data["int"] = y;
1893    ctx->saved_data["symint"] = c10::SymInt(y);
1894    return x;
1895  }
1896
1897  static torch::autograd::variable_list backward(
1898      torch::autograd::AutogradContext *ctx,
1899      torch::autograd::variable_list grad_output) {
1900    const auto& saved_variables = ctx->get_saved_variables();
1901    assert(saved_variables.size() == 1);
1902    torch::Tensor x = saved_variables[0];
1903    c10::SymInt y = ctx->saved_data["int"].toSymInt();
1904    c10::SymInt ys = ctx->saved_data["symint"].toSymInt();
1905
1906    torch::autograd::variable_list grad_inputs(2);
1907    grad_inputs[0] = x + y + ys;
1908    return grad_inputs;
1909  }
1910};
1911
1912torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x, int64_t y) {
1913  return CustomOpAutogradFunction::apply(x, y);
1914}
1915
1916TORCH_LIBRARY(test_autograd_cpp_node_saved_int, m) {
1917    m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
1918}
1919        """
1920
1921        module = torch.utils.cpp_extension.load_inline(
1922            name="test_autograd_cpp_node_saved_int",
1923            cpp_sources=cpp_source,
1924            functions="custom_op_backed_by_autograd_fn",
1925            verbose=True,
1926        )
1927
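            # Saved ints are surfaced as SymInts in the backward, so varying y between calls
            # should not force a recompile (a single capture is expected below).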
1928        def fn():
1929            for y in [1, 2, 3, 1]:
1930                x = torch.ones(10, 10, requires_grad=True)
1931                out = torch.ops.test_autograd_cpp_node_saved_int.custom_op_backed_by_autograd_fn(
1932                    x, y
1933                )
1934                loss = out.sum()
1935                loss.backward()
1936                yield x.grad
1937
1938        self.check_output_and_recompiles(fn, 1)
1939
1940    def test_autograd_cpp_node_saved_float(self):
1941        cpp_source = """
1942struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
1943  static constexpr bool is_traceable = true;
1944
1945  static torch::Tensor forward(
1946      torch::autograd::AutogradContext* ctx,
1947      const torch::Tensor& x,
1948      double z) {
1949    ctx->save_for_backward({x});
1950    ctx->saved_data["float"] = z;
1951    ctx->saved_data["symfloat"] = c10::SymFloat(z);
1952    return x;
1953  }
1954
1955  static torch::autograd::variable_list backward(
1956      torch::autograd::AutogradContext *ctx,
1957      torch::autograd::variable_list grad_output) {
1958    const auto& saved_variables = ctx->get_saved_variables();
1959    assert(saved_variables.size() == 1);
1960    torch::Tensor x = saved_variables[0];
1961    c10::SymFloat z = ctx->saved_data["float"].toSymFloat();
1962    c10::SymFloat zs = ctx->saved_data["symfloat"].toSymFloat();
1963
1964    torch::autograd::variable_list grad_inputs(2);
1965    grad_inputs[0] = x + z + zs;
1966    return grad_inputs;
1967  }
1968};
1969
1970torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x, double z) {
1971  return CustomOpAutogradFunction::apply(x, z);
1972}
1973
1974TORCH_LIBRARY(test_autograd_cpp_node_saved_float, m) {
1975    m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
1976}
1977        """
1978
1979        module = torch.utils.cpp_extension.load_inline(
1980            name="test_autograd_cpp_node_saved_float",
1981            cpp_sources=cpp_source,
1982            functions="custom_op_backed_by_autograd_fn",
1983            verbose=True,
1984        )
1985
1986        def fn():
1987            for z in [1.1, 2.2, 3.3, 1.1]:
1988                x = torch.ones(10, 10, requires_grad=True)
1989                out = torch.ops.test_autograd_cpp_node_saved_float.custom_op_backed_by_autograd_fn(
1990                    x, z
1991                )
1992                loss = out.sum()
1993                loss.backward()
1994                yield x.grad
1995
1996        # compiled autograd and dynamo both support symfloat, but the backend does not, so each distinct float value forces a backend recompile
1997        self.check_output_and_recompiles(fn, [1, 3])
1998
1999    def test_autograd_cpp_node_data_dependent(self):
2000        cpp_source = """
2001struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
2002  static constexpr bool is_traceable = true;
2003  static int iteration;
2004
2005  static torch::autograd::variable_list forward(
2006      torch::autograd::AutogradContext* ctx,
2007      const torch::Tensor& x,
2008      const torch::Tensor& y) {
2009    ctx->save_for_backward({x, y});
2010    ctx->saved_data["bool"] = true;
2011    ctx->saved_data["int"] = 1;
2012
2013    switch (iteration) {
2014        case 0: {
2015            break;
2016        }
2017        case 1: {
2018            // recompile
2019            ctx->saved_data["forces_recompile"] = iteration;
2020            break;
2021        }
2022        case 2: {
2023            // recompile
2024            ctx->set_materialize_grads(false);
2025            break;
2026        }
2027        case 3: {
2028            // reuse
2029            break;
2030        }
2031        default: {
2032            throw std::runtime_error("unexpected iteration");
2033        }
2034    }
2035    iteration++;
2036    return {x, y};
2037  }
2038
2039  static torch::autograd::variable_list backward(
2040      torch::autograd::AutogradContext *ctx,
2041      torch::autograd::variable_list grad_output) {
2042    const auto& saved_variables = ctx->get_saved_variables();
2043    assert(saved_variables.size() == 2);
2044    torch::Tensor x = saved_variables[0];
2045    torch::Tensor y = saved_variables[1];
2046    c10::SymInt i = ctx->saved_data["int"].toSymInt();
2047
2048    torch::autograd::variable_list grad_inputs(2);
2049    grad_inputs[0] = x + y + i;
2050    return grad_inputs;
2051  }
2052};
2053
2054int CustomOpAutogradFunction::iteration = 0;
2055
2056torch::autograd::variable_list custom_op_backed_by_autograd_fn(const torch::Tensor& x, const torch::Tensor& y) {
2057  return CustomOpAutogradFunction::apply(x, y);
2058}
2059
2060void reset() {
2061    CustomOpAutogradFunction::iteration = 0;
2062}
2063
2064TORCH_LIBRARY(test_autograd_cpp_node_data_dependent, m) {
2065    m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
2066    m.def("reset", reset);
2067}
2068        """
2069
2070        module = torch.utils.cpp_extension.load_inline(
2071            name="test_autograd_cpp_node_data_dependent",
2072            cpp_sources=cpp_source,
2073            functions="custom_op_backed_by_autograd_fn",
2074            verbose=True,
2075        )
2076
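            # Three captures expected: the initial graph, a recompile when iteration 1 adds a
            # new saved_data key, and another when iteration 2 disables materialize_grads;
            # iteration 3 hits the cache.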
2077        def fn():
2078            torch.ops.test_autograd_cpp_node_data_dependent.reset()
2079            for i in [10, 10, 10, 10]:
2080                x = torch.ones(i, i, requires_grad=True)
2081                y = torch.randn(i, i)
2082                (
2083                    out1,
2084                    out2,
2085                ) = torch.ops.test_autograd_cpp_node_data_dependent.custom_op_backed_by_autograd_fn(
2086                    x, y
2087                )
2088                loss = (out1 + out2).sum()
2089                loss.backward()
2090                yield x.grad
2091
2092        self.check_output_and_recompiles(fn, 3)
2093
2094    @unittest.skipIf(not HAS_CUDA, "requires cuda")
2095    def test_free_activation_memory(self):
2096        script = """
2097import torch
2098
2099def main():
2100    assert(torch.cuda.memory_allocated() == 0)
2101
2102    # Use an op to check that the memory is freed by the time the op is executed
2103    def assertion_impl(to_clone):
2104        mem_allocated = torch.cuda.memory_allocated()
2105        assert mem_allocated < 4000000  # some activations should be freed
2106        return to_clone.clone()
2107
2108    with torch.library._scoped_library("test_compiled_autograd", "FRAGMENT") as lib:
2109        lib.define(
2110            "assertion_op(Tensor x) -> Tensor", tags=(torch.Tag.pt2_compliant_tag,)
2111        )
2112        lib.impl("assertion_op", assertion_impl, "CPU")
2113        lib.impl("assertion_op", lambda x: x.clone(), "Meta")
2114
2115        # Create a graph that allows input stealing
2116        def forward(activations):
2117            add = activations[0] + 1
2118            out = add.cpu()
2119            cloned_out = torch.ops.test_compiled_autograd.assertion_op(out)
2120            return (cloned_out,)
2121
2122        gm = torch.fx.symbolic_trace(forward)
2123        torch._dynamo.utils.set_locals_to_steal(gm, ["activations"])
2124        compiled_fn = torch.compile(gm)
2125
2126        # allocate at least 4,000,000 bytes (1,000,000 * 4 bytes)
2127        activations = [torch.ones(1000000, dtype=torch.float32, device="cuda")]
2128        assert torch.cuda.memory_allocated() > 4000000
2129
2130        out = compiled_fn(activations)
2131        assert len(activations) == 0
2132
2133main()
2134        """
2135        self.run_as_subprocess(script)
2136
2137    @unittest.skipIf(not HAS_CUDA, "requires cuda")
2138    def test_free_activation_memory_subclass(self):
2139        # covers the case where AOT inputs are tensor subclasses, which results in a different runtime wrapper
2140
2141        script = """
2142import torch
2143
2144def main():
2145    assert torch.cuda.memory_allocated() == 0
2146
2147    # Use an op to check that the memory is freed by the time the op is executed
2148    def assertion_impl(to_clone):
2149        mem_allocated = torch.cuda.memory_allocated()
2150        assert mem_allocated < 1200000  # some activations should be freed
2151        assert mem_allocated > 800000  # currently subclasses don't seem to be freed in inductor
2152        return to_clone.clone()
2153
2154    with torch.library._scoped_library("test_compiled_autograd", "FRAGMENT") as lib:
2155        lib.define(
2156            "assertion_op(Tensor x) -> Tensor", tags=(torch.Tag.pt2_compliant_tag,)
2157        )
2158        lib.impl("assertion_op", assertion_impl, "CPU")
2159        lib.impl("assertion_op", lambda x: x.clone(), "Meta")
2160        lib.impl("assertion_op", lambda x: x.clone(), "NestedTensor")
2161
2162        def fn(inputs):
2163            _, y = inputs
2164            out = y.cpu()
2165            cloned_out = torch.ops.test_compiled_autograd.assertion_op(out)
2166            return cloned_out
2167
2168        gm = torch.fx.symbolic_trace(fn)
2169        torch._dynamo.utils.set_locals_to_steal(gm, ["inputs"])
2170        compiled_fn = torch.compile(gm)
2171
2172        from torch.nested._internal.nested_tensor import jagged_from_list
2173
2174        activations = [
2175            jagged_from_list(
2176                [
2177                    torch.ones((1, 100000), device="cuda"),  # 400,000 bytes
2178                    torch.ones((1, 100000), device="cuda"),  # 400,000 bytes
2179                ],
2180                None,
2181            )[
2182                0
2183            ],  # NestedTensor
2184            torch.ones((1, 100000), device="cuda"),  # 400,000 bytes
2185        ]
2186        # 1,200,000 bytes (3 * 4 * 100,000 bytes)
2187        assert torch.cuda.memory_allocated() > 1200000
2188
2189        out = compiled_fn(activations)
2190        assert len(activations) == 0
2191
2192main()
2193        """
2194
2195    def test_callback_graph_break_throws_error(self):
2196        called = [0]
2197
2198        def callback_final():
2199            called[0] += 1
2200
2201        class MyFunc(torch.autograd.Function):
2202            @staticmethod
2203            def forward(ctx, input):
2204                return input
2205
2206            @staticmethod
2207            @torch.autograd.function.once_differentiable
2208            def backward(ctx, grad):
2209                torch.autograd.Variable._execution_engine.queue_callback(callback_final)
2210                torch._dynamo.graph_break()
2211                return grad
2212
2213        a = torch.rand((3, 3), requires_grad=True)
2214        with self.assertRaisesRegex(
2215            AssertionError,
2216            "only supported when Compiled Autograd is enabled with fullgraph=True",
2217        ):
2218            with compiled_autograd.enable(make_compiler_fn(fullgraph=False)):
2219                b = MyFunc.apply(a)
2220                b.sum().backward()
2221
2222    @unittest.skipIf(not HAS_CUDA, "requires cuda")
2223    def test_cudagraphs_cpu_division(self):
2224        from torch._dynamo.testing import reduce_to_scalar_loss
2225
2226        model = torch.nn.Linear(10, 10, dtype=torch.float16).cuda()
2227        inputs = torch.randn(10, 10, dtype=torch.float16).cuda()
2228        out = model(inputs)
2229        loss = reduce_to_scalar_loss(out)
2230
2231        stderr_msgs = io.StringIO()
2232        with mock.patch("sys.stderr", stderr_msgs), compiled_autograd.enable(
2233            compiler_fn
2234        ):
2235            torch._inductor.config.triton.cudagraphs = True
2236            loss.backward()
2237            torch._inductor.config.triton.cudagraphs = False
2238
2239        self.assertFalse("skipping cudagraphs" in stderr_msgs.getvalue())
2240
2241    def test_cudagraphs_cpu_graph(self):
2242        from torch._dynamo.testing import reduce_to_scalar_loss
2243
2244        model = torch.nn.Linear(10, 10, dtype=torch.float16)
2245        inputs = torch.randn(10, 10, dtype=torch.float16)
2246        out = model(inputs)
2247        loss = reduce_to_scalar_loss(out)
2248
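            # Everything here stays on CPU, so inductor is expected to skip cudagraphs exactly once.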
2249        with compiled_autograd.enable(compiler_fn):
2250            torch._inductor.config.triton.cudagraphs = True
2251            loss.backward()
2252            torch._inductor.config.triton.cudagraphs = False
2253
2254        self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
2255
2256    @unittest.skipIf(not HAS_CUDA, "requires cuda")
2257    def test_cudagraphs_sdpa(self):
2258        query = torch.rand(
2259            32, 8, 128, 64, dtype=torch.float16, device="cuda", requires_grad=True
2260        )
2261        key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
2262        value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
2263        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
2264
2265        with config.patch(compiled_autograd=True), inductor_config.patch(
2266            "triton.cudagraphs", True
2267        ):
2268            opt_bwd = torch.compile(lambda: out.sum().backward())
2269            opt_bwd()
2270
2271        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
2272        self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
2273
2274    @unittest.skipIf(not HAS_CUDA, "requires cuda")
2275    def test_cudagraphs_cpu_scalar_used_in_python_custom_op(self):
2276        class MyFn(torch.autograd.Function):
2277            @staticmethod
2278            def forward(ctx, x):
2279                cpu_tensor = torch.tensor(5)
2280                ctx.save_for_backward(x, cpu_tensor)  # visible to c++/autograd
2281                ctx.cpu_scalar = 5  # opaque to c++/autograd
2282                return x.sum()
2283
2284            @staticmethod
2285            def backward(ctx, gO):
2286                x, cpu_tensor = ctx.saved_tensors
2287                expand = gO * torch.ones_like(x)
2288                return expand * cpu_tensor * ctx.cpu_scalar
2289
2290        x = torch.randn(10, requires_grad=True, device="cuda")
2291        out = MyFn.apply(x)
2292        with config.patch(compiled_autograd=True), inductor_config.patch(
2293            "triton.cudagraphs", True
2294        ):
2295            opt_bwd = torch.compile(lambda: out.backward())
2296            opt_bwd()
2297
2298        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
2299        # Compiled autograd lifts custom autograd.Function bwd instead of tracing it.
2300        # Must skip since we do not know if the cpu scalar will be used only in ATen/prim ops.
2301        self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
2302
2303    @unittest.skipIf(not HAS_CUDA, "requires cuda")
2304    def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self):
2305        cpp_source = """
2306struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
2307  static constexpr bool is_traceable = true;
2308
2309  static torch::Tensor forward(
2310      torch::autograd::AutogradContext* ctx,
2311      const torch::Tensor& x) {
2312    const auto& cpu_tensor = torch::tensor(1);
2313    ctx->save_for_backward({x, cpu_tensor});
2314    ctx->saved_data["cpu_scalar"] = 1;
2315    return x;
2316  }
2317
2318  static torch::autograd::variable_list backward(
2319      torch::autograd::AutogradContext *ctx,
2320      torch::autograd::variable_list grad_output) {
2321    const auto& saved_variables = ctx->get_saved_variables();
2322    assert(saved_variables.size() == 2);
2323    torch::Tensor x = saved_variables[0];
2324    torch::Tensor cpu_tensor = saved_variables[1];
2325    int cpu_scalar = ctx->saved_data["cpu_scalar"].toInt();
2326    auto expand = grad_output[0] * torch::ones_like(x);
2327    torch::autograd::variable_list grad_inputs(1);
2328    grad_inputs[0] = expand * cpu_tensor * cpu_scalar;  // autograd engine asserts that tensors are on same device
2329    return grad_inputs;
2330  }
2331};
2332
2333torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x) {
2334  return CustomOpAutogradFunction::apply(x);
2335}
2336
2337TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
2338    m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
2339}
2340        """
2341
2342        module = torch.utils.cpp_extension.load_inline(
2343            name="test_cudagraphs_cpu_scalar_used_in_cpp_custom_op",
2344            cpp_sources=cpp_source,
2345            functions="custom_op_backed_by_autograd_fn",
2346            verbose=True,
2347        )
2348
2349        x = torch.randn(2, 2, requires_grad=True, device="cuda")
2350        with config.patch(compiled_autograd=True), inductor_config.patch(
2351            "triton.cudagraphs", True
2352        ):
2353            out = torch.ops.test_cudagraphs_cpu_scalar_used_in_cpp_custom_op.custom_op_backed_by_autograd_fn(
2354                x
2355            )
2356            opt_bwd = torch.compile(lambda: out.sum().backward())
2357            opt_bwd()
2358
2359        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
2360        # always safe to move, since we trace into the autograd::function bwd and can see if it's only used by aten ops
2361        self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
2362
2363    def test_logs(self):
2364        logs, ctx = logs_to_string(
2365            torch._dynamo.compiled_autograd.__name__, "compiled_autograd"
2366        )
2367        with compiled_autograd.enable(compiler_fn), ctx():
2368            torch.randn(4, 4, requires_grad=True).sum().backward()
2369
2370        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
2371        self.assertEqual(counters["compiled_autograd"]["compiles"], 1)
2372        assert "torch::autograd::AccumulateGrad (NodeCall" in logs.getvalue()
2373        assert (
2374            "Cache miss due to new autograd node: torch::autograd::GraphRoot"
2375            not in logs.getvalue()
2376        )
2377
2378    def test_verbose_logs_graph(self):
2379        def fn():
2380            model = torch.nn.Sequential(
2381                torch.nn.Linear(4, 4),
2382                torch.nn.ReLU(),
2383                torch.nn.Linear(4, 4),
2384                torch.nn.ReLU(),
2385            )
2386            x = torch.randn([2, 4])
2387            result = model(x).sum()
2388            result.backward()
2389            yield model[0].weight.grad
2390            yield model[0].bias.grad
2391            yield model[2].weight.grad
2392            yield model[2].bias.grad
2393
2394        logs, ctx = logs_to_string(
2395            torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose"
2396        )
2397        with ctx():
2398            self.check_output_and_recompiles(fn)
2399
2400        expected_logs = [
2401            "SumBackward0 (NodeCall 1)",
2402            "ReluBackward0 (NodeCall 2)",
2403            "AddmmBackward0 (NodeCall 3)",
2404            "TBackward0 (NodeCall 4)",
2405            "torch::autograd::AccumulateGrad (NodeCall 5)",
2406            "ReluBackward0 (NodeCall 6)",
2407            "AddmmBackward0 (NodeCall 7)",
2408            "TBackward0 (NodeCall 8)",
2409            "torch::autograd::AccumulateGrad (NodeCall 9)",
2410            "torch::autograd::AccumulateGrad (NodeCall 10)",
2411            "torch::autograd::AccumulateGrad (NodeCall 11)",
2412        ]
2413
2414        self.assertEqual(
2415            sum(1 for e in expected_logs if e in logs.getvalue()), len(expected_logs)
2416        )
2417
2418    @mock.patch(
2419        "torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count
2420    )
2421    @mock.patch("torch._dynamo.config.inline_inbuilt_nn_modules", True)
2422    def test_verbose_logs_aot_id(self, _):
2423        def fn():
2424            model = torch.nn.Sequential(
2425                torch.nn.Linear(4, 4),
2426                torch.nn.ReLU(),
2427                torch.nn.Linear(4, 4),
2428                torch.nn.ReLU(),
2429            )
2430            x = torch.randn([2, 4])
2431
2432            @torch.compile
2433            def forward(model, x):
2434                return model(x)
2435
2436            result = forward(model, x).sum()
2437            result.backward()
2438            yield model[0].weight.grad
2439            yield model[0].bias.grad
2440            yield model[2].weight.grad
2441            yield model[2].bias.grad
2442
2443        logs, ctx = logs_to_string(
2444            torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose"
2445        )
2446        with ctx():
2447            self.check_output_and_recompiles(fn)
2448
2449        self.assertTrue("CompiledFunctionBackward0" in logs.getvalue())
2450
2451    @mock.patch(
2452        "torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count
2453    )
2454    def test_verbose_logs_aot_dispatcher_nodes(self, _):
2455        def fn():
2456            @torch.compile
2457            def f(x):
2458                tmp1 = x.sin()
2459                tmp2 = x.cos()
2460                torch._dynamo.graph_break()
2461                return tmp1.sin() + tmp2.cos()
2462
2463            x = torch.randn(4, requires_grad=True)
2464            out = f(x)
2465            out.sum().backward()
2466            yield x.grad
2467
2468        logs, ctx = logs_to_string(
2469            torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose"
2470        )
2471        with ctx():
2472            self.check_output_and_recompiles(fn)
2473
2474        expected_logs = [
2475            "CompiledFunctionBackward1",
2476            "aot1_tangents_1",
2477            "aot1_sin_1",
2478            "aot1_primals_2",
2479            "aot1_neg",
2480            "aot0_tangents_2",
2481            "aot1_cos_1",
2482            "aot1_primals_1",
2483            "aot0_tangents_1",
2484            "CompiledFunctionBackward0",
2485            "aot0_neg",
2486            "aot0_sin",
2487            "aot0_mul",
2488            "aot0_mul_1",
2489            "aot0_cos",
2490            "aot0_add",
2491        ]
2492
2493        self.assertEqual(
2494            sum(1 for e in expected_logs if e in logs.getvalue()), len(expected_logs)
2495        )
2496
2497    @mock.patch(
2498        "torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count
2499    )
2500    def test_verbose_logs_aot_dispatcher_nodes_hop(self, _):
2501        @dataclasses.dataclass
2502        class CustomObj:
2503            val: torch.Tensor
2504
2505        def fn(x, obj):
2506            y = x.sin()
2507            closure_var = y + 1
2508            y.register_hook(lambda grad: grad + obj.val + closure_var)
2509            z = y.sin()
2510            return z
2511
2512        opt_fn = torch.compile(fn)
2513
2514        x = torch.ones(4, requires_grad=True)
2515        y = torch.ones(4, requires_grad=True)
2516        obj = CustomObj(torch.tensor(88))
2517        fn(x, obj).sum().backward()
2518
2519        logs, ctx = logs_to_string(
2520            torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose"
2521        )
2522        with ctx(), compiled_autograd.enable(compiler_fn):
2523            opt_fn(y, obj).sum().backward()
2524        self.assertEqual(x.grad, y.grad)
2525
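            # The registered tensor hook is lifted into the graph, which is why
            # aot0_trace_wrapped appears among the expected log entries below.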
2526        expected_logs = [
2527            "CompiledFunctionBackward0",
2528            "aot0_primals_2",
2529            "aot0_tangents_2",
2530            "aot0_tangents_1",
2531            "aot0_sin",
2532            "aot0_cos",
2533            "aot0_mul",
2534            "aot0_add_1",
2535            "aot0_trace_wrapped",
2536            "aot0_cos_1",
2537            "aot0_mul_1",
2538        ]
2539
2540        self.assertEqual(
2541            sum(1 for e in expected_logs if e in logs.getvalue()), len(expected_logs)
2542        )
2543
2544    @skipIfWindows(msg="AssertionError: Scalars are not equal!")
2545    def test_verbose_logs_cpp(self):
2546        torch._logging.set_logs(compiled_autograd_verbose=True)
2547
2548        def fn():
2549            model = torch.nn.Sequential(
2550                torch.nn.Linear(4, 4),
2551                torch.nn.ReLU(),
2552                torch.nn.Linear(4, 4),
2553                torch.nn.ReLU(),
2554            )
2555            for i in [10, 11, 12]:
2556                model.zero_grad()
2557                x = torch.randn([i, 4])
2558                result = model(x).sum()
2559                result.backward()
2560                yield model[0].weight.grad
2561                yield model[0].bias.grad
2562                yield model[2].weight.grad
2563                yield model[2].bias.grad
2564
2565        logs, ctx = logs_to_string(
2566            torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose"
2567        )
2568        with ctx():
2569            self.check_output_and_recompiles(fn, count=2)
2570
2571        patterns1 = [
2572            r".*Cache miss due to new autograd node: torch::autograd::GraphRoot \(NodeCall 0\) with key size (\d+), "
2573            r"previous key sizes=\[\]\n",
2574        ]
2575
2576        # recompile
2577        patterns2 = [
2578            r".*Cache miss due to changed shapes: marking size idx (\d+) of torch::autograd::GraphRoot \(NodeCall 0\) as dynamic\n",
2579            r".*Cache miss due to changed shapes: marking size idx (\d+) of SumBackward0 \(NodeCall 1\) as dynamic\n",
2580            r".*Cache miss due to changed shapes: marking size idx (\d+) of SumBackward0 \(NodeCall 1\) as dynamic\n",
2581            r".*Cache miss due to changed shapes: marking size idx (\d+) of ReluBackward0 \(NodeCall 2\) as dynamic\n",
2582            r".*Cache miss due to changed shapes: marking size idx (\d+) of AddmmBackward0 \(NodeCall 3\) as dynamic\n",
2583            r".*Cache miss due to changed shapes: marking size idx (\d+) of torch::autograd::AccumulateGrad "
2584            r"\(NodeCall 5\) as dynamic\n",
2585            r".*Cache miss due to changed shapes: marking size idx (\d+) of ReluBackward0 \(NodeCall 6\) as dynamic\n",
2586        ]
2587
2588        all_logs = logs.getvalue()
2589
2590        pattern1 = r"".join(patterns1)
2591        matches1 = re.findall(pattern1, all_logs)
2592        self.assertEqual(len(matches1), 1)
2593        assert isinstance(
2594            matches1[0], str
2595        )  # with one capture group, re.findall returns strings: ['match']; with multiple groups it returns tuples: [('match1', 'match2'), ...]
2596        self.assertEqual(len(matches1), len(patterns1))
2597
2598        pattern2 = r"".join(patterns2)
2599        matches2 = re.findall(pattern2, all_logs)
2600        self.assertEqual(len(matches2), 1)
2601        self.assertEqual(len(matches2[0]), len(patterns2))
2602
2603    def test_verbose_logs_snapshot(self):
2604        def fn():
2605            model = torch.nn.Sequential(
2606                torch.nn.Linear(4, 4),
2607                torch.nn.ReLU(),
2608                torch.nn.Linear(4, 4),
2609                torch.nn.ReLU(),
2610            )
2611            x = torch.randn([2, 4])
2612            result = model(x).sum()
2613            result.backward()
2614            yield model[0].weight.grad
2615            yield model[0].bias.grad
2616            yield model[2].weight.grad
2617            yield model[2].bias.grad
2618
2619        logs, ctx = logs_to_string(
2620            torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose"
2621        )
2622        with ctx():
2623            with compiled_autograd.enable(compiler_fn):
2624                # no effect here: the verbose level was already snapshotted when the context manager was entered
2625                torch._logging.set_logs(compiled_autograd_verbose=True)
2626                fn()
2627
2628        unexpected_logs = [
2629            "Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0)"
2630        ]
2631
2632        self.assertEqual(sum(1 for e in unexpected_logs if e in logs.getvalue()), 0)
2633
2634    @unittest.expectedFailure
2635    def test_saved_tensor_unpack_hook_ordering(self):
2636        # not the correct behaviour, I'm just preventing this from changing silently
2637        def f(x, y):
2638            return x * y
2639
2640        pack_count = 0
2641        unpack_count = 0
2642
2643        def pack_hook(x):
2644            nonlocal pack_count
2645            pack_count += 1
2646            return x
2647
2648        def unpack_hook(x):
2649            nonlocal unpack_count
2650            unpack_count += 1
2651            return x
2652
2653        def tensor_hook(_):
2654            # in eager, tensor_hook is fired before unpack_hook
2655            # but in compiled autograd, tensor_hook is lifted whereas unpack_hook is not
2656            self.assertEqual(unpack_count, 0)
2657
2658        x = torch.ones(4, requires_grad=True)
2659        y = torch.ones(4, requires_grad=False)
2660        with torch.autograd.graph.saved_tensors_hooks(
2661            pack_hook, unpack_hook
2662        ), compiled_autograd.enable(make_compiler_fn(fullgraph=False)):
2663            out_test = f(x, y)
2664            self.assertEqual(pack_count, 1)
2665            self.assertEqual(unpack_count, 0)
2666            loss = out_test.sum()
2667            loss.register_hook(tensor_hook)
2668            loss.backward()
2669            self.assertEqual(pack_count, 1)
2670            self.assertEqual(unpack_count, 1)
2671
2672    def test_reentrant_checkpointing(self):
2673        def fn(x):
2674            y = x.sin()
2675            z = y.cos()
2676            return (y * z).sum()
2677
2678        inp = torch.rand(10, 10, requires_grad=True)
2679        out = torch.utils.checkpoint.checkpoint(fn, inp, use_reentrant=True)
2680        with self.assertRaisesRegex(
2681            RuntimeError,
2682            r"\(e.g. reentrant checkpointing\), this is not supported yet\.",
2683        ), torch._dynamo.compiled_autograd.enable(torch.compile):
2684            out.backward()
2685
2686
2687def load_test_module(name):
2688    testdir = Path(__file__).absolute().parent.parent
2689    with mock.patch("sys.path", [*sys.path, str(testdir)]):
2690        return SourceFileLoader(
2691            name, str(testdir / f"{name.replace('.', '/')}.py")
2692        ).load_module()
2693
2694
2695def make_wrapped(fn, ctxs):
2696    @functools.wraps(fn)
2697    def wrapped(self):
2698        torch._dynamo.reset()
2699        stack = contextlib.ExitStack()
2700        for ctx in ctxs:
2701            stack.enter_context(ctx)
2702        out = fn(self)
2703        stack.close()
2704        return out
2705
2706    return wrapped
2707
2708
2709def wrap_test_class(orig_cls):
2710    dct = orig_cls.__dict__.copy()
2711    for name in list(dct.keys()):
2712        fn = dct[name]
2713        if not callable(fn) or name in skipped_tests:
2714            continue
2715        elif known_failures_re.match(name) or name in known_failing_tests:
2716            dct[name] = unittest.expectedFailure
2717        elif name.startswith("test_"):
2718            fullgraph = name not in known_graph_breaks_tests
2719            ctxs = [
2720                compiled_autograd.enable(make_compiler_fn(fullgraph=fullgraph)),
2721                test_contexts.get(name, contextlib.nullcontext()),
2722            ]
2723            dct[name] = make_wrapped(fn, ctxs)
2724
2725    cls = type(
2726        orig_cls.__name__ + "WithCompiledAutograd",
2727        orig_cls.__bases__,
2728        dct,
2729    )
2730    cls.__file__ = __file__
2731    return cls
2732
2733
2734known_graph_breaks_tests = {
2735    "test_hook_none",  # uses assert in hook
2736    "test_post_accumulate_grad_hook_e2e",  # optim.Adam manually graph breaks
2737    "test_tensor_hooks_inplace",  # uses assert in hook
2738    "test_tensor_hooks_inplace_over_view",  # uses assert in hook
2739    "test_grad_fn_prehooks",  # uses assert in hook
2740    "test_grad_fn_prehooks_multiple_outputs",  # uses assert in hook
2741    "test_grad_fn_prehooks_remove_hooks",  # uses handle.remove() in hook
2742    "test_tensor_hooks_inplace_multiple_outputs",  # uses assert in hook
2743    "test_hooks",  # uses assert in hook
2744    "test_accumulate_grad_posthooks_can_observe_tensor_prehook",  # allclose
2745    "test_saved_tensors_hook_version_counter_not_shared",  # assertEqual
2746    "test_post_accumulate_grad_hook_returns_not_None",  # throws
2747    "test_custom_function_cycle",  # assertEqual
2748    "test_mark_non_differentiable_mixed",  # assertTrue
2749    "test_materialize_grads",  # assertEqual
2750    "test_return_leaf",  # assertEqual
2751    "test_save_none_for_backward",  # assertIsNone
2752    "test_saved_variables_deprecated",  # warnings.warn
2753    "test_autograd_node_isinstance",  # assertIsInstance
2754    "test_set_materialize_non_diff_grads",  # assertIsNone
2755    "test_backward_dict_grad_for_nontensor",  # torch/_custom_op/autograd.py in skip files
2756    "test_backward_dict_invalid_keys",  # torch/_custom_op/autograd.py in skip files
2757    "test_backward_dict_requires_keys_for_input_optional_tensors",  # torch/_custom_op/autograd.py in skip files
2758    "test_backward_dict_requires_keys_for_input_tensors",  # torch/_custom_op/autograd.py in skip files
2759    "test_backward_grads_are_tensor_or_none",  # torch/_custom_op/autograd.py in skip files
2760    "test_backward_impl_on_existing_op",  # torch/_custom_op/autograd.py in skip files
2761    "test_backward_returns_dict",  # torch/_custom_op/autograd.py in skip files
2762    "test_backward_tensorlist_input_requires_list_grads",  # torch/_custom_op/autograd.py in skip files
2763    "test_backward_tensorlist_input_requires_list_grads_none_or_Tensor",  # torch/_custom_op/autograd.py in skip files
2764    "test_backward_tensorlist_input_requires_list_grads_with_same_numel",  # torch/_custom_op/autograd.py in skip files
2765    "test_save_for_backward_inputs_are_namedtuple",  # torch/_custom_op/autograd.py in skip files
2766}
2767
2768test_contexts = {
2769    "test_setitem_mask": config.patch(capture_dynamic_output_shape_ops=True),
2770    "test_index_backward_does_not_save_tensor": config.patch(
2771        capture_dynamic_output_shape_ops=True
2772    ),
2773}
2774
2775# These groups of tests aren't supported yet
2776known_failures_re = re.compile(
2777    r"^test_(sparse|profiler|gradcheck|checkpoint|named_tensor)"
2778)
2779
2780# Bugs needing investigation:
2781skipped_tests = {
2782    "test_callback_propagates_errors_from_device_thread",  # fullgraph for queue_callback, but graph break for RuntimeError
2783}
2784
2785known_failing_tests = {
2786    # Category: Compiled autograd
2787    "test_current_graph_task_execution_order",  # nodes are already freed by the time dynamo traces the lifted hook
2788    "test_reentrant_with_leaf_variable_hook",  # hangs when enabled with graph breaks
2789    "test_reentrant_with_non_leaf_variable_hook",  # hangs when enabled with graph breaks
2790    "test_anomaly_grad_warnings",  # does not support anomaly mode
2791    "test_autograd_inplace_views_cross_dtype",  # view_fn not supported by compiled autograd
2792    "test_current_node",  # TorchDispatchMode not yet implemented for compiled autograd
2793    "test_post_accumulate_grad_hook_ordering",  # accuracy error
2794    "test_retain_grad_cycle",  # retains_grad_hooks
2795    "test_retain_grad_inplace",  # retains_grad_hooks
2796    "test_retain_grad_inplace_over_view",  # retains_grad_hooks
2797    "test_retains_grad_can_always_observe_tensor_prehook",  # retains_grad_hooks
2798    "test_retains_grad_inplace_multiple_outputs",  # retains_grad_hooks
2799    "test_reentrant_child_error",  # hangs when enabled with graph breaks
2800    "test_accumulate_grad",  # create_graph
2801    "test_anomaly_assign_parent_cleanup",  # create_graph
2802    "test_anomaly_mode_no_check_nan",  # anomaly mode
2803    "test_backward_create_graph_warns",  # create_graph
2804    "test_backward_with_nonleaf_inputs",  # create_graph
2805    "test_create_graph_and_full_backward_hook_cycle",  # create_graph
2806    "test_current_graph_task_id",  # autograd state already cleared once dynamo is called
2807    "test_custom_autograd_repeated_grad_grad",  # create_graph
2808    "test_custom_function_forward_mode_forward_is_no_op",  # forward AD
2809    "test_custom_function_forward_mode_inplace_checks",  # forward AD
2810    "test_custom_function_forward_mode_view_checks",  # forward AD
2811    "test_custom_function_forward_mode_wrong_formula",  # forward AD
2812    "test_default_saved_tensors_hooks_double_backward",  # create_graph
2813    "test_node_post_hook_registered_during_unpack_hook",  # 'NoneType' object has no attribute 'register_hook'
2814    "test_full_backward_hook_double_backward",  # create_graph
2815    "test_function",  # create_graph
2816    "test_grad",  # create_graph
2817    "test_grad_materialize_grads",  # create_graph
2818    "test_grad_nonleaf",  # create_graph
2819    "test_grad_nonleaf_many_outputs",  # create_graph
2820    "test_hessian_vector",  # create_graph
2821    "test_hook_edge_case_when_called_with_grad",  # retains_grad_hooks
2822    "test_inplace_on_view_backward",  # create_graph
2823    "test_multi_grad_any_hooks",  # register_multi_grad_hook
2824    "test_multi_grad_all_hooks",  # retains_grad_hooks
2825    "test_nested_anomaly_detect_nan",  # create_graph
2826    "test_nested_anomaly_printstack_cleanup",  # create_graph
2827    "test_once_differentiable",  # create_graph
2828    "test_prehook_ordering",  # retains_grad_hooks
2829    "test_retain_grad",  # retains_grad_hooks
2830    "test_saved_variable_packing_unpacking_saved_original_with_hooks",  # create_graph
2831    "test_select_sum",  # create_graph, also needs graph breaks
2832    "test_will_engine_execute_node",  # retains_grad_hooks
2833    "test_backward_to_node",  # retains_grad_hooks NYI
2834    "test_anomaly_detect_nan",  # anomaly mode
2835    "test_custom_autograd_no_early_free",  # create_graph
2836    "test_custom_function_error",  # vjp
2837    "test_custom_function_save_for_forward",  # vjp
2838    "test_deep_reentrant",  # hangs with graph breaks
2839    "test_dont_materialize_grads",  # undefined grad
2840    "test_grad_mode_restored_reentrant",  # hangs with graph breaks
2841    "test_no_grad_copy",  # setting static member in lifted backward
2842    "test_no_grad_copy_sparse",  # setting static member in lifted backward
2843    "test_reentrant_priority",  # hangs with graph breaks
2844    "test_reentrant_with_callbacks_both_depths",  # hangs with graph breaks
2845    "test_reentrant_with_callbacks_depth_0",  # probably hangs with graph breaks
2846    "test_reentrant_with_callbacks_depth_1",  # probably hangs with graph breaks
2847    "test_save_output_nr",  # output_nr grad passed as None
2848    "test_setup_context_when_forward_has_default_args",  # autograd.Function with class methods
2849    "test_simple_reentrant",  # hangs with graph breaks
2850    "test_lobpcg",  # create_graph
2851    "test_grad_nonleaf_register_hook",  # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors)
2852    "test_backward_twice_without_saved_values",  # https://github.com/pytorch/pytorch/issues/129938
2853    # Category: Dynamo
2854    "test_accumulate_grad_tensor_reference",  # Out of bounds: frame_state_entry.stride[i] is None
2855    "test_custom_function_exception",  # torch.no_grad(), torch._dynamo.exc.Unsupported: missing: WITH_EXCEPT_START
2856    "test_to_sparse_backward",  # Out of bounds: frame_state_entry.stride[i] is None
2857    "test_autograd_simple_views_python",  # gradient is None
2858    "test_function_returns_undefined_tensor",  # gradient is None
2859    "test_naughty_autograd_function_stashing_ctx",  # bytecode issue
2860    "test_unrelated_inputs",  # gradient batching rule not implemented for aten::sym_size.int
2861    "test_custom_function_non_tensor_inputs_outputs",  # gradient batching rule not implemented for aten::sym_size.int
2862    "test_return_duplicate",  # gradient batching rule not implemented for aten::sym_size.int
2863    "test_return_duplicate_inplace",  # gradient batching rule not implemented for aten::sym_size.int
2864    "test_setitem",  # CopySlices accuracy error
2865    # Category: Inductor
2866    "test_input_buffer_accum",  # does not support sparse_grad=True: https://github.com/pytorch/pytorch/issues/120267
2867    "test_graph_save_on_cpu",  # does not support pin_memory: https://github.com/pytorch/pytorch/issues/134173
2868    # Category: FakeTensor
2869    "test_saving_variable_to_disk",  # torch.save should no-op and be recorded in the graph
2870    "test_wrapped_number_saved_tensors_hooks",  # Proxy tensor should carryover is_wrapped_number_ of its original
2871    "test_grad_batched_grad",  # torch._subclasses.fake_tensor.UnsupportedFakeTensorException: meta converter nyi
2872    "test_scalar_grad_mixed_device",  # Fake Tensors aren't propagating device properly for 0-dim grads
2873    # Category: Divergence from eager
2874    "test_invalid_gradients",  # can't give autograd error due to inaccurate output metadata of lifted backward
2875    "test_autograd_node_isinstance",  # backward ctx is a fake cls and not directly a Node instance
2876    # Uncategorized
2877}
2878
2879if not HAS_CUDA:
2880    # Found Tesla M60 which is too old to be supported by the triton GPU compiler
2881    known_failing_tests.add("test_type_conversions")
2882
2883test_autograd = load_test_module("test_autograd")
2884test_custom_ops = load_test_module("test_custom_ops")
2885
2886TestAutogradWithCompiledAutograd = wrap_test_class(test_autograd.TestAutograd)
2887TestCustomOpWithCompiledAutograd = wrap_test_class(test_custom_ops.TestCustomOp)
2888
2889if __name__ == "__main__":
2890    if HAS_CPU:
2891        run_tests(needs="filelock")
2892