# Owner(s): ["module: functorch"]
# flake8: noqa: B950
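#
# Exercises torch.ops.higher_order.with_effects: functionalization of
# side-effectful ops (aten._print, torchbind method calls, custom ops) by
# threading effect tokens through exported and compiled graphs.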
import unittest
from collections import deque
from functools import partial
from typing import List, TYPE_CHECKING

import torch
import torch._dynamo
import torch._functorch
import torch._inductor
import torch._inductor.decomposition
from functorch.compile import (
    aot_function,
    default_decompositions,
    min_cut_rematerialization_partition,
    nop,
)
from torch._functorch.aot_autograd import aot_export_module
from torch._higher_order_ops.effects import with_effects
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import (
    _get_torch_cuda_version,
    SM70OrLater,
    SM80OrLater,
)
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TEST_CUDA,
    TEST_WITH_ROCM,
    TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations


if TYPE_CHECKING:
    from torch.utils.hooks import RemovableHandle

from torch.testing._internal.two_tensor import TwoTensor


def extract_graph(fx_g, _, graph_cell):
    graph_cell[0] = fx_g
    return fx_g


def get_fw_bw_graph(
    f, inps, partitioner=min_cut_rematerialization_partition, dynamic=False
):
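    """Trace `f` with aot_function and return (fw_graph, bw_graph).

    The graphs are captured by the extract_graph compiler hooks; the backward
    graph is only produced (by running out.sum().backward()) when some input
    requires grad, otherwise it stays None.
    """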
    fw_graph_cell = [None]
    bw_graph_cell = [None]
    requires_grad = False

    def fn_req_grad(t):
        nonlocal requires_grad
        requires_grad = requires_grad or t.requires_grad
        return t

    torch.utils._pytree.tree_map_only(torch.Tensor, fn_req_grad, inps)

    out = aot_function(
        f,
        fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
        bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell)
        if requires_grad
        else nop,
        partition_fn=partitioner,
        decompositions=default_decompositions,
        dynamic=dynamic,
    )(*inps)

    if requires_grad:
        out.sum().backward()

    return (fw_graph_cell[0], bw_graph_cell[0])


def make_inputs_non_leaves(inps):
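    # t.add(1) returns a fresh non-leaf tensor, so tests can mutate
    # grad-requiring inputs in-place without tripping autograd's leaf checks.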
    return torch.utils._pytree.tree_map_only(torch.Tensor, lambda t: t.add(1), inps)


@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestWithEffects(TestCase):
    def setUp(self):
        init_torchbind_implementations()

    def test_print(self):
        class M(torch.nn.Module):
            def forward(self, x):
                torch.ops.aten._print("moo")
                res = x + x
                torch.ops.aten._print("moo")
                return (res,)

        inputs = (torch.randn(3),)

        # Without functionalization, print should just appear in the graph directly
        gm = make_fx(M())(*inputs)
        FileCheck().check_count("torch.ops.aten._print.default", 2, exactly=True).run(
            gm.code
        )

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'moo');  arg0_1 = None
    getitem = with_effects[0];  with_effects = None
    add = torch.ops.aten.add.Tensor(arg1_1, arg1_1);  arg1_1 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo');  getitem = None
    getitem_2 = with_effects_1[0];  with_effects_1 = None
    return (getitem_2, add)""",
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)

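        # With unlift_effect_tokens=True, the token is no longer a graph
        # input/output: it is created by prims._make_token and consumed by
        # prims._sink_tokens inside the graph.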
        with torch._functorch.config.patch(unlift_effect_tokens=True):
            gm, gs = aot_export_module(M(), inputs, trace_joint=False)
            self.assertExpectedInline(
                str(gm.code).strip(),
                """\
def forward(self, arg1_1):
    _make_token_default = torch.ops.prims._make_token.default()
    with_effects = torch.ops.higher_order.with_effects(_make_token_default, torch.ops.aten._print.default, 'moo');  _make_token_default = None
    getitem = with_effects[0];  with_effects = None
    add = torch.ops.aten.add.Tensor(arg1_1, arg1_1);  arg1_1 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo');  getitem = None
    getitem_2 = with_effects_1[0];  with_effects_1 = None
    _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem_2]);  getitem_2 = _sink_tokens_default = None
    return [add]""",  # noqa: B950
            )

    def test_torchbind_custom_op(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return (x + torch.ops._TorchScriptTesting.takes_foo(self.attr, x),)

        with enable_torchbind_tracing():
            gm, gs = aot_export_module(M(), (torch.ones(2, 3),), trace_joint=False)

        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    _torchbind_obj0 = self._torchbind_obj0
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops._TorchScriptTesting.takes_foo.default, _torchbind_obj0, arg1_1);  arg0_1 = _torchbind_obj0 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    add = torch.ops.aten.add.Tensor(arg1_1, getitem_1);  arg1_1 = getitem_1 = None
    return (getitem, add)""",  # noqa: B950
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)

    def test_print_with_buffer_mutations(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buf = torch.nn.Buffer(torch.ones(3))

            def forward(self, x):
                torch.ops.aten._print("moo")
                res = x + x
                self.buf.add_(res)
                res = self.buf + x
                torch.ops.aten._print("moo")
                return (res,)

        inputs = (torch.randn(3),)

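        # The in-place self.buf.add_(res) is functionalized: the buffer's new
        # value comes back as an extra graph output and is tracked in
        # gs.buffers_to_mutate.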
        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'moo');  arg0_1 = None
    getitem = with_effects[0];  with_effects = None
    add = torch.ops.aten.add.Tensor(arg2_1, arg2_1)
    add_1 = torch.ops.aten.add.Tensor(arg1_1, add);  arg1_1 = add = None
    add_2 = torch.ops.aten.add.Tensor(add_1, arg2_1);  arg2_1 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo');  getitem = None
    getitem_2 = with_effects_1[0];  with_effects_1 = None
    return (getitem_2, add_1, add_2)""",
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)
        self.assertEqual(len(gs.buffers_to_mutate), 1)

    def test_print_with_input_mutations(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                torch.ops.aten._print("moo")
                res = x + x
                x.add_(res)
                res = x + x
                torch.ops.aten._print("moo")
                return (res,)

        inputs = (torch.randn(3),)

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)
        self.assertEqual(len(gs.user_inputs_to_mutate), 1)

    def test_alias_op(self):
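        # aten.absolute_ aliases (mutates) its input, and with_effects rejects
        # ops whose schema contains aliasing, so tracing should raise.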
        def f(token, x):
            token, out = with_effects(token, torch.ops.aten.absolute_.default, x)
            return token, out

        with self.assertRaisesRegex(
            AssertionError, r"Ops with aliasing is not supported"
        ):
            make_fx(f)(torch.tensor([]), torch.tensor(4))

    def test_compile_aot_eager(self):
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        inputs = (torch.randn(2, 3),)

        res = torch.compile(f, backend="aot_eager")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))

    @unittest.skipIf(IS_WINDOWS, "triton")
    @unittest.skipIf(not SM70OrLater, "triton")
    def test_compile_inductor(self):
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        inputs = (torch.randn(2, 3),)

        res = torch.compile(f, backend="inductor")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))

    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
    @skipIfNoDynamoSupport
    def test_compile_inductor_external_op_return_none(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::inplace_add",
                "(Tensor input, Tensor(a!) output) -> ()",
                lib=lib,
            )

            def inplace_add(input: torch.Tensor, output: torch.Tensor) -> None:
                assert input.device == output.device
                output.add_(input)

            lib.impl("inplace_add", inplace_add, "CompositeExplicitAutograd")

            def f(x):
                out = torch.empty(3)
                out = torch.zeros_like(out)
                torch.ops.mylib.inplace_add(x, out)
                return out

            inputs = (torch.randn(3),)

            res = torch.compile(f, backend="inductor")(*inputs)
            self.assertTrue(torch.allclose(res, f(*inputs)))

    def test_compile_aot_eager_requires_grad(self):
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        inputs = (torch.randn(2, 3, requires_grad=True),)

        res = torch.compile(f, backend="aot_eager")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))

        res.sum().backward()

    @unittest.skipIf(IS_WINDOWS, "triton")
    @unittest.skipIf(TEST_WITH_ROCM, "triton")
    @unittest.skipIf(not SM80OrLater, "triton")
    @unittest.skipIf(_get_torch_cuda_version() >= (11, 7), "triton")
    @unittest.skipIf(not TEST_CUDA, "triton")
    @skipIfNoDynamoSupport
    def test_register_effectful_custom_op(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch._dynamo.config.capture_scalar_outputs = True
            torch._dynamo.config.capture_dynamic_output_shape_ops = True

            torch.library.define(
                "mylib::record_scalar_tensor",
                "(Tensor x, str prefix) -> ()",
                lib=lib,
            )

            # global variable to store the recorded tensor and prefix.
            recorded_dict = {}

            # PyTorch custom op implementation
            @torch.library.impl(
                "mylib::record_scalar_tensor",
                "CompositeExplicitAutograd",
                lib=lib,
            )
            def record_scalar_tensor(x, prefix):
                recorded_dict[prefix] = x.clone()
                return

            # Meta function of the custom op
            @torch.library.impl_abstract(
                "mylib::record_scalar_tensor",
                lib=lib,
            )
            def record_scalar_tensor_meta(x, prefix):
                return

            from torch._higher_order_ops.effects import (
                _EffectType,
                _register_effectful_op,
            )

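            # Registering the op as ORDERED makes tracing thread an effect
            # token through every call so the recordings keep their order.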
            _register_effectful_op(
                torch.ops.mylib.record_scalar_tensor.default, _EffectType.ORDERED
            )

            my_config = {}
            my_config["MockModule"] = "mean"
            my_config["MockModule.linear"] = "mean"
            my_config["MockModule.relu"] = "mean"

            class MyLinear(torch.nn.Module):
                def __init__(self, in_features, out_features):
                    super().__init__()
                    self.weight = torch.nn.Parameter(
                        torch.randn(out_features, in_features), requires_grad=True
                    )
                    self.bias = torch.nn.Parameter(
                        torch.randn(out_features), requires_grad=True
                    )

                def forward(self, x):
                    return torch.nn.functional.linear(x, self.weight, self.bias)

            class MockModule(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.linear = MyLinear(10, 10)
                    self.register_buffer(
                        "buf0", torch.randn(10, 10, requires_grad=True)
                    )

                def forward(self, x):
                    return torch.nn.functional.relu(self.linear(x) + self.buf0)

            def forward_hook(
                module: torch.nn.Module,
                inputs: torch.Tensor,
                output: torch.Tensor,
                prefix: str,
                aggregate_method: str,
            ) -> torch.Tensor:
                if aggregate_method == "mean":
                    torch.ops.mylib.record_scalar_tensor(output.mean(), prefix)
                elif aggregate_method == "max":
                    torch.ops.mylib.record_scalar_tensor(output.max(), prefix)
                else:
                    # demo fallback: record the sum
                    torch.ops.mylib.record_scalar_tensor(output.sum(), prefix)
                return output

            def add_hooks(module, config):
                handles: List[RemovableHandle] = []
                q = deque([(module.__class__.__name__, module)])
                while q:
                    name, m = q.pop()
                    children = [(name + "." + n, y) for (n, y) in m.named_children()]
                    q.extend(children)
                    aggregate_method = config.get(name, "mean")
                    prefix = name + ":" + aggregate_method
                    handle = m.register_forward_hook(
                        partial(
                            forward_hook,
                            prefix=prefix,
                            aggregate_method=aggregate_method,
                        )
                    )
                    if handle:
                        handles.append(handle)
                return handles

            x = torch.randn(10, 10, device="cuda")
            mod = MockModule().to("cuda")

            add_hooks(mod, my_config)

            opt_mod = torch.compile(backend="inductor")(mod)
            y = opt_mod(x)

            self.assertTrue(torch.allclose(y, mod(x)))
            # Ensure it works well with backward
            y.sum().backward()
            # Ensure the gradient exists
            self.assertTrue(isinstance(opt_mod.linear.weight.grad, torch.Tensor))

            self.assertEqual(len(recorded_dict), 2)
            self.assertTrue("MockModule.linear:mean" in recorded_dict)
            self.assertTrue("MockModule:mean" in recorded_dict)

    @skipIfNoDynamoSupport
    def test_effectful_custom_op_with_subclasses(self):
        with torch.library._scoped_library("_mylib", "FRAGMENT") as lib:
            lib.define("zoo(Tensor x) -> Tensor")
            lib.define("zoo2(Tensor x) -> Tensor")

            d = {"fw": 0, "bw": 0}

            def reset_counter():
                d["fw"] = 0
                d["bw"] = 0

            def assert_counter(fw, bw):
                self.assertEqual(d["fw"], fw)
                self.assertEqual(d["bw"], bw)

            def foo_impl(a):
                d["fw"] = d["fw"] + 1
                return 2 * a.clone()

            def foo_meta(a):
                return a.clone()

            def foo2_impl(x):
                d["bw"] = d["bw"] + 1
                return x.clone()

            def foo2_meta(a):
                return a.clone()

            for backend in ["CPU", "CUDA"]:
                lib.impl("zoo", foo_impl, backend)
                lib.impl("zoo2", foo2_impl, backend)
            lib.impl("zoo", foo_meta, "Meta")
            lib.impl("zoo2", foo2_meta, "Meta")

            def foo_bwd(ctx, grad):
                torch.ops._mylib.zoo2(grad)
                return grad.clone()

            torch.library.register_autograd("_mylib::zoo", foo_bwd, lib=lib)

            from torch._higher_order_ops.effects import (
                _EffectType,
                _register_effectful_op,
            )

            _register_effectful_op(torch.ops._mylib.zoo.default, _EffectType.ORDERED)
            _register_effectful_op(torch.ops._mylib.zoo2.default, _EffectType.ORDERED)

            def fn(x, y):
                return torch.ops._mylib.zoo(x) + y

            def ins_sc():
                return (
                    TwoTensor(
                        torch.tensor([1.0, 2.0, 3.0]), torch.tensor([1.0, 2.0, 3.0])
                    ),
                    torch.tensor([4.0, 5.0, 6.0]),
                )

            def ins_dense():
                return torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])

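            # TwoTensor inputs desugar into one op call per constituent tensor,
            # so the forward counter is expected to reach 2; dense inputs hit 1.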
            for i, (ins_fn, expected_fw_count) in enumerate(
                zip([ins_sc, ins_dense], [2, 1])
            ):
                reset_counter()
                ref_out = fn(*ins_fn())
                assert_counter(expected_fw_count, 0)

                compiled_fn = torch.compile(fn, backend="aot_eager")
                out = compiled_fn(*ins_fn())
                reset_counter()
                out = compiled_fn(*ins_fn())
                assert_counter(expected_fw_count, 0)

                self.assertEqual(ref_out, out)

            def ins_dense_req_grad():
                return (
                    torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                    torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                )

            def ins_sc_req_grad():
                return (
                    TwoTensor(
                        torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                    ),
                    TwoTensor(
                        torch.tensor([7.0, 8.0, 9.0], requires_grad=True),
                        torch.tensor([10.0, 11.0, 12.0], requires_grad=True),
                    ),
                )

            for i, (
                ins_fn_req_grad,
                (
                    expected_fw_count,
                    expected_fw_count_after_bw,
                    expected_bw_count_after_bw,
                ),
            ) in enumerate(
                zip([ins_dense_req_grad, ins_sc_req_grad], [(1, 1, 1), (2, 2, 2)])
            ):
                ref_ins = ins_fn_req_grad()
                reset_counter()
                ref_out = fn(*ref_ins)
                assert_counter(expected_fw_count, 0)
                ref_out.sum().backward()
                assert_counter(expected_fw_count_after_bw, expected_bw_count_after_bw)

                compiled_fn = torch.compile(fn, fullgraph=True)

                ins = ins_fn_req_grad()
                out = compiled_fn(*ins)
                reset_counter()
                out = compiled_fn(*ins)
                assert_counter(expected_fw_count, 0)
                self.assertEqual(ref_out, out)
                out.sum().backward()
                assert_counter(expected_fw_count_after_bw, expected_bw_count_after_bw)
                self.assertEqual(ref_ins[1].grad, ins[1].grad)
                self.assertEqual(ref_ins[0].grad, ins[0].grad)

            fw_graph, bw_graph = get_fw_bw_graph(fn, ins_sc_req_grad())
            self.assertExpectedInline(
                fw_graph.code.strip(),
                """\
def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5):
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.zoo.default, primals_2);  primals_1 = primals_2 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._mylib.zoo.default, primals_3);  getitem = primals_3 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1];  with_effects_1 = None
    add = torch.ops.aten.add.Tensor(getitem_1, primals_4);  getitem_1 = primals_4 = None
    add_1 = torch.ops.aten.add.Tensor(getitem_3, primals_5);  getitem_3 = primals_5 = None
    return (getitem_2, add, add_1)""",
            )
            self.assertExpectedInline(
                bw_graph.code.strip(),
                """\
def forward(self, tangents_1, tangents_2, tangents_token):
    with_effects_2 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.zoo2.default, tangents_1);  tangents_token = None
    getitem_4 = with_effects_2[0];  with_effects_2 = None
    with_effects_3 = torch.ops.higher_order.with_effects(getitem_4, torch.ops._mylib.zoo2.default, tangents_2);  getitem_4 = None
    getitem_6 = with_effects_3[0];  with_effects_3 = None
    clone = torch.ops.aten.clone.default(tangents_1)
    clone_1 = torch.ops.aten.clone.default(tangents_2)
    return (clone, clone_1, tangents_1, tangents_2, getitem_6)""",
            )

    def test_effects_and_input_mutation_return(self):
        def fn(a, b):
            torch.ops.aten._print("effect")
            return torch.sin(a, out=b)

        inp = [torch.randn(3, 3), torch.ones(3, 3)]
        ref_out = fn(*inp)
        out = torch.compile(fn, fullgraph=True)(*inp)
        self.assertEqual(ref_out, out)

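        # The out= mutation is functionalized: the forward graph returns the
        # effect token plus sin twice, once as the mutated-input value and once
        # as the user output.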
        fw_graph, bw_graph = get_fw_bw_graph(fn, inp)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'effect');  arg0_1 = None
    getitem = with_effects[0];  with_effects = None
    sin = torch.ops.aten.sin.default(arg1_1);  arg1_1 = None
    return (getitem, sin, sin)""",
        )

    def test_effects_and_input_output_view_simple(self):
        def fn(a):
            return a.view(-1)

        inp = [torch.ones(2, 2, requires_grad=False).add(1)]
        ref_out = fn(*inp)
        out = torch.compile(fn, fullgraph=True)(*inp)
        self.assertEqual(ref_out, out)

        inp = [torch.ones(2, 2, requires_grad=True).add(1)]
        ref_out = fn(*inp)
        out = torch.compile(fn, fullgraph=True)(*inp)
        self.assertEqual(ref_out, out)

        fw_graph, bw_graph = get_fw_bw_graph(fn, inp)

        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, arg0_1):
    view = torch.ops.aten.view.default(arg0_1, [-1]);  arg0_1 = None
    return (view,)""",
        )

    def test_effects_and_aliased_outputs(self):
        def fn(a):
            b = a.mul(2)
            torch.ops.aten._print("effect")
            c = b.view(-1)
            return b, c

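        # Even with an effect token in the graph, the compiled outputs should
        # keep their aliasing relationship (c is a view of b).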
        f_compiled = aot_function(fn, nop)
        for req_grad in [True, False]:
            inp = torch.ones(3, requires_grad=req_grad)
            out_ref = fn(inp)
            out_test = f_compiled(inp)
            self.assertEqual(out_ref[0], out_test[0])
            self.assertEqual(out_ref[1], out_test[1])
            # Try mutating one of the outputs, which is aliased.
            out_ref[0].mul_(3)
            out_test[0].mul_(3)
            # Assert that the aliasing relationship was preserved
            self.assertEqual(out_ref[0], out_test[0])
            self.assertEqual(out_ref[1], out_test[1])

    def test_effects_and_input_mutation_is_output(self):
        def fn(a):
            a.mul_(2)
            torch.ops.aten._print("effect")
            return a

        inp = make_inputs_non_leaves([torch.ones(3, 3, requires_grad=True)])
        ref_out = fn(*inp)
        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(*inp)
        self.assertEqual(ref_out, out)

        inp = [torch.ones(3, 3, requires_grad=False)]
        ref_out = fn(*inp)
        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(*inp)
        self.assertEqual(ref_out, out)

        fw_graph, bw_graph = get_fw_bw_graph(fn, inp)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, arg0_1, arg1_1):
    mul = torch.ops.aten.mul.Tensor(arg1_1, 2);  arg1_1 = None
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'effect');  arg0_1 = None
    getitem = with_effects[0];  with_effects = None
    return (getitem, mul, mul)""",
        )

    @skipIfTorchDynamo()
    def test_effectful_op_in_backward(self):
        with torch.library._scoped_library("_mylib", "FRAGMENT") as lib:
            lib.define("foo(Tensor x) -> Tensor")

            def foo_impl(a):
                return a.clone()

            def foo_bwd(ctx, grad):
                return torch.ops._mylib.foo(grad)

            for backend in ["CPU", "CUDA", "Meta"]:
                lib.impl("foo", foo_impl, backend)

            torch.library.register_autograd("_mylib::foo", foo_bwd, lib=lib)

            from torch._higher_order_ops.effects import (
                _deregister_effectful_op,
                _EffectType,
                _register_effectful_op,
            )

            _register_effectful_op(torch.ops._mylib.foo.default, _EffectType.ORDERED)
            try:

                def fn(x, y):
                    return torch.ops._mylib.foo(x) + y

                def ins_dense_req_grad():
                    return (
                        torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                    )

                def ins_sc_req_grad():
                    return (
                        TwoTensor(
                            torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                            torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                        ),
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                    )

                for i, ins_fn in enumerate([ins_dense_req_grad, ins_sc_req_grad]):
                    ref_ins = ins_fn()

                    ref_out = fn(*ref_ins)
                    ref_out.sum().backward()

                    compiled_fn = torch.compile(fn, backend="inductor", fullgraph=True)
                    ins = ins_fn()
                    out = compiled_fn(*ins)
                    self.assertEqual(ref_out, out)
                    out.sum().backward()
                    self.assertEqual(ref_ins[1].grad, ins[1].grad)
                    self.assertEqual(ref_ins[0].grad, ins[0].grad)

                    fw_graph, bw_graph = get_fw_bw_graph(fn, ins)
                    if i == 0:
                        self.assertExpectedInline(
                            fw_graph.code.strip(),
                            """\
def forward(self, primals_1, primals_2, primals_3):
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.foo.default, primals_2);  primals_1 = primals_2 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    add = torch.ops.aten.add.Tensor(getitem_1, primals_3);  getitem_1 = primals_3 = None
    return (getitem, add)""",
                        )
                        self.assertExpectedInline(
                            bw_graph.code.strip(),
                            """\
def forward(self, tangents_1, tangents_token):
    with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.foo.default, tangents_1);  tangents_token = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1];  with_effects_1 = None
    return (getitem_3, tangents_1, getitem_2)""",
                        )
                    elif i == 1:
                        self.assertExpectedInline(
                            fw_graph.code.strip(),
                            """\
def forward(self, primals_1, primals_2, primals_3, primals_4):
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.foo.default, primals_2);  primals_1 = primals_2 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._mylib.foo.default, primals_3);  getitem = primals_3 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1];  with_effects_1 = None
    add = torch.ops.aten.add.Tensor(getitem_1, primals_4);  getitem_1 = None
    add_1 = torch.ops.aten.add.Tensor(getitem_3, primals_4);  getitem_3 = primals_4 = None
    return (getitem_2, add, add_1)""",
                        )
                        self.assertExpectedInline(
                            bw_graph.code.strip(),
                            """\
def forward(self, tangents_1, tangents_2, tangents_token):
    with_effects_2 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.foo.default, tangents_1);  tangents_token = None
    getitem_4 = with_effects_2[0]
    getitem_5 = with_effects_2[1];  with_effects_2 = None
    with_effects_3 = torch.ops.higher_order.with_effects(getitem_4, torch.ops._mylib.foo.default, tangents_2);  getitem_4 = None
    getitem_6 = with_effects_3[0]
    getitem_7 = with_effects_3[1];  with_effects_3 = None
    return (getitem_5, getitem_7, tangents_1, tangents_2, getitem_6)""",
                        )
                    else:
                        raise NotImplementedError
            finally:
                _deregister_effectful_op(torch.ops._mylib.foo.default)

    @skipIfNoDynamoSupport
    def test_regular_effectful_op_only_in_backward(self):
        from torch._higher_order_ops.effects import (
            _deregister_effectful_op,
            _EffectType,
            _register_effectful_op,
        )

        _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
        try:

            def fn(x):
                return x.sin()

            def inps_fn():
                return (torch.tensor([1.0, 2.0, 3.0], requires_grad=True),)

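            # sin's derivative is computed with cos, which is registered as
            # effectful above, so with_effects should only show up in the
            # backward graph.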
            torch.compile(fn, backend="inductor", fullgraph=True)(*inps_fn())

            fw_graph, bw_graph = get_fw_bw_graph(fn, inps_fn())
            self.assertExpectedInline(
                fw_graph.code.strip(),
                """\
def forward(self, primals_1):
    sin = torch.ops.aten.sin.default(primals_1)
    return (sin, primals_1)""",
            )
            self.assertExpectedInline(
                bw_graph.code.strip(),
                """\
def forward(self, primals_1, tangents_1, tangents_token):
    with_effects = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, primals_1);  tangents_token = primals_1 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_1);  tangents_1 = getitem_1 = None
    return (mul, getitem)""",
            )

            def inps_fn_sc():
                return (
                    TwoTensor(
                        torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                    ),
                )

            torch.compile(fn, backend="inductor", fullgraph=True)(*inps_fn_sc())
            fw_graph, bw_graph = get_fw_bw_graph(fn, inps_fn_sc())
            self.assertExpectedInline(
                fw_graph.code.strip(),
                """\
def forward(self, primals_1, primals_2):
    sin = torch.ops.aten.sin.default(primals_1)
    sin_1 = torch.ops.aten.sin.default(primals_2)
    return (sin, sin_1, primals_1, primals_2)""",
            )
            self.assertExpectedInline(
                bw_graph.code.strip(),
                """\
def forward(self, primals_1, primals_2, tangents_1, tangents_2, tangents_token):
    with_effects = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, primals_1);  tangents_token = primals_1 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten.cos.default, primals_2);  getitem = primals_2 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1];  with_effects_1 = None
    mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_1);  tangents_1 = getitem_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(tangents_2, getitem_3);  tangents_2 = getitem_3 = None
    return (mul, mul_1, getitem_2)""",
            )
        finally:
            _deregister_effectful_op(torch.ops.aten.cos.default)

    @skipIfNoDynamoSupport
    def test_regular_effectful_op_in_forward_and_backward(self):
        from torch._higher_order_ops.effects import (
            _deregister_effectful_op,
            _EffectType,
            _register_effectful_op,
        )

        _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
        try:

            def fn(x):
                x = x.cos()
                return x.sin()

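            # cos is used in the forward pass and again for sin's derivative in
            # the backward pass, so the effect token should be threaded through
            # both graphs.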
            inps = (torch.tensor([1.0, 2.0, 3.0], requires_grad=True),)
            torch.compile(fn, backend="inductor", fullgraph=True)(*inps)

            fw_graph, bw_graph = get_fw_bw_graph(fn, inps)
            self.assertExpectedInline(
                fw_graph.code.strip(),
                """\
def forward(self, primals_1, primals_2):
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops.aten.cos.default, primals_2);  primals_1 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    sin = torch.ops.aten.sin.default(getitem_1)
    return (getitem, sin, primals_2, getitem_1)""",
            )
            self.assertExpectedInline(
                bw_graph.code.strip(),
                """\
def forward(self, primals_2, getitem_1, tangents_1, tangents_token):
    with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, getitem_1);  tangents_token = getitem_1 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1];  with_effects_1 = None
    mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_3);  tangents_1 = getitem_3 = None
    sin_1 = torch.ops.aten.sin.default(primals_2);  primals_2 = None
    neg = torch.ops.aten.neg.default(sin_1);  sin_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(mul, neg);  mul = neg = None
    return (mul_1, getitem_2)""",
            )
        finally:
            _deregister_effectful_op(torch.ops.aten.cos.default)


if __name__ == "__main__":
    run_tests()