# Owner(s): ["oncall: distributed"]


import contextlib
import copy
import functools
import unittest
from unittest import mock

import torch
import torch._dynamo.testing
import torch.distributed._composable.fsdp._fsdp_param
import torch.nn.functional as F
from torch import nn
from torch._dynamo import compiled_autograd
from torch._inductor import comms
from torch._inductor.utils import is_fallback_op, run_and_get_code
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable.fsdp._fsdp_common import TrainingState
from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
from torch.distributed._tensor import init_device_mesh
from torch.testing import FileCheck
from torch.testing._internal.common_distributed import at_least_x_gpu, skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
)
from torch.utils._triton import has_triton


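# Helper predicates for inspecting compiled output: `_is_op_in_graph` checks whether an op
# appears as a node target in an FX graph, and `_is_fallback_op_in_snodes` checks whether an
# op shows up as a fallback kernel among Inductor scheduler nodes.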
def _is_op_in_graph(graph, op):
    return any(node.target is op for node in graph.nodes)


def _is_fallback_op_in_snodes(snodes, op):
    return any(is_fallback_op(snode.node, op) for snode in snodes)


class TestFullyShardCompileCompute(FSDPTest):
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_disable_compiling_hooks(self):
        self.run_subtests(
            {
                "skip_fsdp_hooks": [False, True],
            },
            self._test_disable_compiling_hooks,
        )

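    # Patches torch._dynamo.trace_rules.check to count how often Dynamo consults trace rules
    # for code objects defined in _fsdp_state.py (excluding the hook wrapper itself). With
    # skip_fsdp_hooks=True the FSDP hooks should never reach that check, so the count stays
    # at zero; with skip_fsdp_hooks=False the hooks are traced and the count must be > 0.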
    def _test_disable_compiling_hooks(
        self,
        skip_fsdp_hooks: bool,
    ):
        torch._dynamo.reset()
        trace_rules_check_count = 0
        HOOKS_FILE_NAME = "torch/distributed/_composable/fsdp/_fsdp_state.py"
        HOOK_WRAPPER_NAME = "fsdp_hook_wrapper"

        def patched_trace_rules_check(*args, **kwargs):
            nonlocal trace_rules_check_count
            f_code = args[0]
            if (
                hasattr(f_code, "co_filename")
                and f_code.co_filename.endswith(HOOKS_FILE_NAME)
                and f_code.co_name != HOOK_WRAPPER_NAME
            ):
                trace_rules_check_count += 1
            return orig_trace_rules_check(*args, **kwargs)

        original_skip_fsdp_hooks = torch._dynamo.config.skip_fsdp_hooks
        orig_trace_rules_check = torch._dynamo.trace_rules.check
        torch.distributed.barrier()
        torch._dynamo.config.skip_fsdp_hooks = skip_fsdp_hooks
        torch._dynamo.trace_rules.check = patched_trace_rules_check
        model = MLP(4)
        fully_shard(model)
        model.compile()
        model(torch.randn((4, 4), device="cuda"))
        torch.distributed.barrier()
        torch._dynamo.config.skip_fsdp_hooks = original_skip_fsdp_hooks
        torch._dynamo.trace_rules.check = orig_trace_rules_check
        if skip_fsdp_hooks:
            self.assertEqual(trace_rules_check_count, 0)
        else:
            self.assertTrue(trace_rules_check_count > 0)


class TestFullyShardCompile(FSDPTest):
    fake_pg = not at_least_x_gpu(2)

    @property
    def world_size(self) -> int:
        return 2

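    # Checks that Dynamo can trace FSDPParamGroup.use_training_state as a context manager:
    # the compiled function (fullgraph=True, aot_eager backend) must match eager output,
    # restore the training state to IDLE afterwards, and compile to a single graph.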
    def test_dynamo_trace_use_training_state(self):
        torch._dynamo.reset()
        # Construct a dummy FSDPParamGroup, since we just want to test the `use_training_state` ctx manager.
        param_group = FSDPParamGroup(
            [],  # params: List[nn.Parameter],
            (torch.nn.Linear(1, 1),),  # module: Tuple[nn.Module, ...],
            None,  # mesh_info: FSDPMeshInfo,
            None,  # post_forward_mesh_info: Optional[FSDPMeshInfo],
            None,  # device: torch.device,
            None,  # mp_policy: MixedPrecisionPolicy,
            None,  # offload_policy: OffloadPolicy,
        )

        def f(x):
            param_group._training_state = TrainingState.IDLE
            with param_group.use_training_state(TrainingState.FORWARD):
                if param_group._training_state == TrainingState.FORWARD:
                    return x + 1
                else:
                    return x

        inp = torch.zeros(1)
        self.assertEqual(param_group._training_state, TrainingState.IDLE)

        eager_out = f(inp)
        self.assertEqual(param_group._training_state, TrainingState.IDLE)
        self.assertEqual(eager_out, inp + 1)

        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        compiled_out = torch.compile(f, backend=cnt, fullgraph=True)(inp)
        self.assertEqual(param_group._training_state, TrainingState.IDLE)
        self.assertEqual(eager_out, compiled_out)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)
        self.assertEqual(len(cnt.graphs), 1)

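    # Checks that torch.ops.fsdp.set_ (an in-place op that makes `x` alias the given buffer)
    # composes with a custom mutating out= op under torch.compile: the compiled run must
    # leave `x` in the same state as the eager run.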
    def test_trace_fsdp_set_(self):
        @torch.library.custom_op("mylib::add_one_out", mutates_args={"out"})
        def add_one_out(x: torch.Tensor, out: torch.Tensor) -> None:
            torch.add(x, 1, out=out)

        def f(x):
            buf = torch.zeros(2)
            buf_view = buf.view(-1)
            torch.ops.mylib.add_one_out(x, out=buf_view)
            buf_view2 = buf.view(-1)
            torch.ops.fsdp.set_(x, buf_view2)

        ref_x = torch.zeros(2)
        x = copy.deepcopy(ref_x)
        f(ref_x)
        torch.compile(f, backend="aot_eager")(x)
        self.assertEqual(x, ref_x)

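    # When fullgraph=True, wraps Inductor's reinplace_fsdp_all_gather pass with assertions
    # that the functional all_gather_into_tensor op is present before the pass and is
    # rewritten into the out-variant (all_gather_into_tensor_out) afterwards.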
    def _reinplace_all_gather_with_optional_checks(self, fullgraph):
        def _run_with_checks(graph, orig_fn):
            self.assertTrue(
                _is_op_in_graph(
                    graph,
                    torch.ops._c10d_functional.all_gather_into_tensor.default,
                )
            )
            orig_fn(graph)
            self.assertFalse(
                _is_op_in_graph(
                    graph,
                    torch.ops._c10d_functional.all_gather_into_tensor.default,
                )
            )
            self.assertTrue(
                _is_op_in_graph(
                    graph,
                    torch.ops._c10d_functional.all_gather_into_tensor_out.default,
                )
            )

        if fullgraph:
            return mock.patch.object(
                comms,
                "reinplace_fsdp_all_gather",
                functools.partial(
                    _run_with_checks,
                    orig_fn=comms.reinplace_fsdp_all_gather,
                ),
            )
        else:
            return contextlib.nullcontext()

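    # Heuristic: treat a graph as the forward graph if its first all_gather_copy_in fallback
    # node reads a "primals_" input (AOTAutograd names forward-graph inputs `primals_*`).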
    def _is_fwd_graph(self, snodes):
        ag_copy_in_snode = None
        for snode in snodes:
            if is_fallback_op(snode.node, torch.ops.fsdp.all_gather_copy_in.default):
                ag_copy_in_snode = snode
                break
        self.assertTrue(ag_copy_in_snode is not None)
        if any(
            dep.name.startswith("primals_")
            for dep in ag_copy_in_snode.read_writes.reads
        ):
            return True
        else:
            return False

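    # When fullgraph=True, wraps Inductor's decide_global_ordering_of_comms pass with
    # assertions that the FSDP collective fallback ops (all_gather_copy_in,
    # all_gather_into_tensor_out, split_with_sizes_copy, plus chunk_cat and
    # reduce_scatter_tensor in the backward graph) appear as top-level scheduler nodes
    # before the pass and no longer appear afterwards (presumably because the pass groups
    # them).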
    def _maybe_run_decide_global_ordering_of_comms_with_checks(self, fullgraph):
        def _check_fsdp_ops_in_snodes(snodes, is_fwd_graph, expect=True):
            assert_method = self.assertTrue if expect else self.assertFalse
            common_ops = {
                torch.ops.fsdp.all_gather_copy_in.default,
                torch.ops._c10d_functional.all_gather_into_tensor_out.default,
                torch.ops.fsdp.split_with_sizes_copy.default,
            }
            bwd_only_ops = {
                torch.ops.fsdp.chunk_cat.default,
                torch.ops._c10d_functional.reduce_scatter_tensor.default,
            }
            for op in common_ops:
                assert_method(
                    _is_fallback_op_in_snodes(
                        snodes,
                        op,
                    ),
                    msg=f"{op}",
                )
            if not is_fwd_graph:
                for op in bwd_only_ops:
                    assert_method(
                        _is_fallback_op_in_snodes(
                            snodes,
                            op,
                        ),
                        msg=f"{op}",
                    )

        def _decide_global_ordering_of_comms_with_checks(
            snodes, name_to_buf, name_to_fused_node, orig_fn
        ):
            is_fwd_graph = self._is_fwd_graph(snodes)
            _check_fsdp_ops_in_snodes(snodes, is_fwd_graph, expect=True)
            new_snodes = orig_fn(snodes, name_to_buf, name_to_fused_node)
            _check_fsdp_ops_in_snodes(new_snodes, is_fwd_graph, expect=False)
            return new_snodes

        if fullgraph:
            return mock.patch.object(
                comms,
                "decide_global_ordering_of_comms",
                functools.partial(
                    _decide_global_ordering_of_comms_with_checks,
                    orig_fn=comms.decide_global_ordering_of_comms,
                ),
            )
        else:
            return contextlib.nullcontext()

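    # FileCheck helper: asserts that no compute op (aten / extern_kernels / triton /
    # torch.ops / inductor_ops calls) appears between the previously matched pattern and
    # the next checked pattern.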
    def inductor_code_check_no_compute_op(self, file_check):
        return (
            file_check.check_not(" = aten.")
            .check_not(" = extern_kernels.")
            .check_not(" = triton_")
            .check_not(" = torch.ops.")
            .check_not(" = inductor_ops.")
            .check_not("    aten.")
            .check_not("    extern_kernels.")
            .check_not("    triton_")
            .check_not("    torch.ops.")
            .check_not("    inductor_ops.")
        )

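    # FileCheck helper for one FSDP all-gather block in the generated code. Expected shape:
    # all_gather_copy_in -> all_gather_into_tensor_out -> (optional overlapped compute op)
    # -> resize_storage_bytes_ x num_resize -> wait_tensor -> split_with_sizes_copy
    # -> aten.set_ x num_set, with no compute ops in between other than the overlapped one.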
    def inductor_code_check_fsdp_all_gather(
        self,
        file_check,
        overlapped_compute_op_str,
        num_resize,
        num_set,
        last_all_gather=False,
    ):
        file_check = file_check.check("torch.ops.fsdp.all_gather_copy_in.")
        file_check = self.inductor_code_check_no_compute_op(file_check)
        file_check = file_check.check(
            "torch.ops._c10d_functional.all_gather_into_tensor_out."
        )
        # Checks that AGWait is delayed, making the AG overlap with some compute op.
        if overlapped_compute_op_str is not None:
            file_check = file_check.check(f"{overlapped_compute_op_str}")
        file_check = file_check.check_count(
            "inductor_ops.resize_storage_bytes_(", num_resize, exactly=True
        )
        file_check = file_check.check("torch.ops._c10d_functional.wait_tensor.")
        file_check = self.inductor_code_check_no_compute_op(file_check)
        file_check = file_check.check("torch.ops.fsdp.split_with_sizes_copy.")
        file_check = self.inductor_code_check_no_compute_op(file_check)
        file_check = file_check.check_count(
            "torch.ops.aten.set_.", num_set, exactly=True
        )
        if not last_all_gather:
            # Checks that there is no compute op between this AGWait and next AG.
            file_check = self.inductor_code_check_no_compute_op(file_check)
        return file_check

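    # FileCheck helper for one FSDP reduce-scatter block in the generated code:
    # chunk_cat -> reduce_scatter_tensor -> (optional overlapped compute op) -> wait_tensor.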
    def inductor_code_check_fsdp_reduce_scatter(
        self, file_check, overlapped_compute_op_str
    ):
        file_check = file_check.check("torch.ops.fsdp.chunk_cat.")
        file_check = self.inductor_code_check_no_compute_op(file_check)
        file_check = file_check.check(
            "torch.ops._c10d_functional.reduce_scatter_tensor."
        )
        # Checks that RSWait is delayed, making the RS overlap with some compute op.
        if overlapped_compute_op_str is not None:
            file_check = file_check.check(f"{overlapped_compute_op_str}")
        file_check = file_check.check("torch.ops._c10d_functional.wait_tensor.")
        return file_check

    @torch._dynamo.config.patch(
        inline_inbuilt_nn_modules=True,
        skip_fsdp_hooks=False,
    )
    @torch._functorch.config.patch(recompute_views=True)
    @torch._functorch.config.patch(cse=False)
    @torch._inductor.config.patch(
        reorder_for_compute_comm_overlap=True,
        reorder_for_compute_comm_overlap_passes=[
            "sink_waits",
            "raise_comms",
            "reorder_compute_for_overlap",
        ],
    )
    def _test_traceable_fsdp(
        self, model_init_fn, input_creation_fn, backend, fullgraph
    ):
        def compiler_fn(compiled_autograd_backend):
            def _fn(gm):
                # fullgraph=True because graph breaks in the Compiled Autograd BWD graph are
                # not yet supported by Traceable FSDP2 (the main difficulty is that
                # queue_callback does not work well when the BWD graph has a graph break).
                return torch.compile(
                    gm, backend=compiled_autograd_backend, fullgraph=True
                )

            return _fn

        def run_iters(model, optim, n_iter=10, compiled_autograd_backend=None):
            torch.manual_seed(42)
            losses = []
            for i in range(n_iter):
                inp = input_creation_fn()
                if compiled_autograd_backend is not None:
                    maybe_compiled_autograd_ctx = compiled_autograd.enable(
                        compiler_fn(compiled_autograd_backend)
                    )
                else:
                    maybe_compiled_autograd_ctx = contextlib.nullcontext()
                with maybe_compiled_autograd_ctx:
                    out = model(inp)
                    loss = out.sum()
                    losses.append(loss.item())
                    loss.backward()
                optim.step()
                optim.zero_grad(set_to_none=True)
            return losses

        def test_compiled():
            model, optim = model_init_fn()
            # FSDP2 does lazy init on the 1st run, so run it once in eager mode to initialize.
            run_iters(model, optim, n_iter=1)

            model_compiled = torch.compile(model, backend=backend, fullgraph=fullgraph)
            res = run_iters(model_compiled, optim, compiled_autograd_backend=backend)
            return res

        def test_eager():
            model, optim = model_init_fn()
            # FSDP2 does lazy init on the 1st run, so run it once in eager mode to initialize.
            run_iters(model, optim, n_iter=1)

            res = run_iters(model, optim)
            return res

        losses_compiled = test_compiled()
        losses_eager = test_eager()
        if not self.fake_pg:
            for loss_compiled, loss_eager in zip(losses_compiled, losses_eager):
                self.assertTrue(
                    torch.allclose(
                        torch.tensor(loss_compiled),
                        torch.tensor(loss_eager),
                        rtol=1e-5,
                        atol=1e-8,
                    ),
                    f"{loss_compiled} vs {loss_eager}",
                )

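    # Factory for a small 3-layer MLP wrapped in a single top-level fully_shard call, plus a
    # matching random-input creator (both seeded per rank).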
    def _create_simple_mlp_factory_fns(self):
        hidden_dim = 16

        def model_init_fn():
            torch.manual_seed(self.rank)
            fsdp_config = {}
            model = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim, device="cuda"),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim, device="cuda"),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim, device="cuda"),
            )
            fully_shard(model, reshard_after_forward=True, **fsdp_config)
            optim = torch.optim.SGD(model.parameters(), lr=1e-4)
            return model, optim

        def input_creation_fn():
            torch.manual_seed(self.rank)
            inp = torch.randn((2, hidden_dim), device="cuda", requires_grad=False)
            return inp

        return model_init_fn, input_creation_fn

    @skipIfRocm
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_simple_mlp_fullgraph_backend_aot_eager(self):
        self._test_traceable_fsdp(
            *self._create_simple_mlp_factory_fns(), "aot_eager", fullgraph=True
        )

    @skipIfRocm
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition(self):
        self._test_traceable_fsdp(
            *self._create_simple_mlp_factory_fns(),
            "aot_eager_decomp_partition",
            fullgraph=True,
        )

    @skipIfRocm
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_simple_mlp_fullgraph_backend_inductor(self):
        self._test_traceable_fsdp(
            *self._create_simple_mlp_factory_fns(), "inductor", fullgraph=True
        )

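    # Factory for a module whose submodules are each wrapped with fully_shard (nested under a
    # root fully_shard). Each submodule optionally graph-breaks when fullgraph=False, and the
    # forward reuses every layer three times to exercise repeated all-gathers of the same
    # parameters.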
    def _create_nested_fully_shard_factory_fns(self, fullgraph):
        hidden_dim = 16

        class TestSubmodule(nn.Module):
            def __init__(self, hidden_dim):
                super().__init__()
                self.param1 = nn.Parameter(
                    torch.zeros(
                        hidden_dim, hidden_dim, dtype=torch.float, device="cuda"
                    )
                )
                self.param2 = nn.Parameter(
                    torch.zeros(hidden_dim, dtype=torch.float, device="cuda")
                )

            def forward(self, x):
                if not fullgraph:
                    torch._dynamo.graph_break()
                ret = torch.matmul(x, self.param1)
                ret = ret * self.param2
                ret = torch.relu(ret)
                return ret

        class TestModule(nn.Module):
            def __init__(self, n_layers):
                super().__init__()
                self.layers = torch.nn.ModuleList()
                for layer_id in range(n_layers):
                    self.layers.append(TestSubmodule(hidden_dim))

            def forward(self, x):
                # Intentionally reuse all layers a few times
                # to test the "multiple all-gathers for the same parameter" case.
                for layer in self.layers:
                    x = layer(x)
                for layer in self.layers:
                    x = layer(x)
                for layer in self.layers:
                    x = layer(x)
                return x

        def model_init_fn():
            torch.manual_seed(self.rank)
            fsdp_config = {}
            mesh = init_device_mesh("cuda", (self.world_size,))
            model = TestModule(n_layers=3)
            for layer_id, mod in enumerate(model.layers):
                fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
            model = fully_shard(
                model, mesh=mesh, reshard_after_forward=True, **fsdp_config
            )
            optim = torch.optim.SGD(model.parameters(), lr=1e-4)
            return model, optim

        def input_creation_fn():
            torch.manual_seed(self.rank)
            inp = torch.randn((2, hidden_dim), device="cuda", requires_grad=False)
            return inp

        return model_init_fn, input_creation_fn

    @skipIfRocm
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_nested_fully_shard_backend_aot_eager(self):
        for fullgraph in [True, False]:
            self._test_traceable_fsdp(
                *self._create_nested_fully_shard_factory_fns(fullgraph=fullgraph),
                "aot_eager",
                fullgraph=fullgraph,
            )

    @skipIfRocm
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_nested_fully_shard_backend_aot_eager_decomp_partition(self):
        for fullgraph in [True, False]:
            self._test_traceable_fsdp(
                *self._create_nested_fully_shard_factory_fns(fullgraph=fullgraph),
                "aot_eager_decomp_partition",
                fullgraph=fullgraph,
            )

    @skipIfRocm
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_nested_fully_shard_backend_inductor(self):
        for fullgraph in [True, False]:
            with self._reinplace_all_gather_with_optional_checks(
                fullgraph
            ), self._maybe_run_decide_global_ordering_of_comms_with_checks(fullgraph):
                _, triton_codes = run_and_get_code(
                    lambda: self._test_traceable_fsdp(
                        *self._create_nested_fully_shard_factory_fns(
                            fullgraph=fullgraph
                        ),
                        "inductor",
                        fullgraph=fullgraph,
                    )
                )
            if fullgraph:
                self.assertTrue(
                    len(triton_codes) == 2,
                    "Expected two separate lowerings to Triton code, one from FWD graph and one from Compiled Autograd BWD graph",
                )
                fwd_code = triton_codes[0]
                file_check = FileCheck().check("def call(args):")
                for fwd_ag_block_info in [
                    dict(overlapped_compute_op_str=None, num_resize=0, num_set=2),
                    dict(
                        overlapped_compute_op_str="extern_kernels.mm(",
                        num_resize=2,
                        num_set=2,
                    ),
                    dict(
                        overlapped_compute_op_str="extern_kernels.mm(",
                        num_resize=2,
                        num_set=2,
                    ),
                    dict(
                        overlapped_compute_op_str="extern_kernels.mm(",
                        num_resize=2,
                        num_set=2,
                    ),
                    dict(
                        overlapped_compute_op_str="extern_kernels.mm(",
                        num_resize=2,
                        num_set=2,
                    ),
                    dict(
                        overlapped_compute_op_str="extern_kernels.mm(",
                        num_resize=2,
                        num_set=2,
                    ),
                    dict(
                        overlapped_compute_op_str="extern_kernels.mm(",
                        num_resize=2,
                        num_set=2,
                    ),
                    dict(
                        overlapped_compute_op_str="extern_kernels.mm(",
                        num_resize=2,
                        num_set=2,
                    ),
                    dict(
                        overlapped_compute_op_str="extern_kernels.mm(",
                        num_resize=2,
                        num_set=2,
                        last_all_gather=True,
                    ),
                ]:
                    file_check = self.inductor_code_check_fsdp_all_gather(
                        file_check, **fwd_ag_block_info
                    )
                file_check.run(fwd_code)

                bwd_code = triton_codes[1]
                file_check = FileCheck().check("def call(args):")
                for bwd_ag_block_info in [
                    dict(overlapped_compute_op_str=None, num_resize=0, num_set=2),
                    dict(
                        overlapped_compute_op_str="extern_kernels.mm(",
                        num_resize=0,
                        num_set=2,
                    ),
                    dict(
                        overlapped_compute_op_str="extern_kernels.mm(",
                        num_resize=0,
                        num_set=2,
                        last_all_gather=True,
                    ),
                ]:
                    file_check = self.inductor_code_check_fsdp_all_gather(
                        file_check, **bwd_ag_block_info
                    )
                for bwd_rs_block_info in [
                    dict(overlapped_compute_op_str="extern_kernels.mm("),
                    dict(
                        overlapped_compute_op_str=None
                    ),  # TODO: improve compute/comm overlap, so that `overlapped_compute_op_str` is not None
                    dict(overlapped_compute_op_str=None),
                ]:
                    file_check = self.inductor_code_check_fsdp_reduce_scatter(
                        file_check, **bwd_rs_block_info
                    )
                file_check.run(bwd_code)
            else:
                # TODO: when fullgraph=False and there is a graph break in the FWD graph,
                # there are several recompiles; need to figure out why.
                self.assertTrue(
                    len(triton_codes) > 2,
                    "Expected at least 3 separate lowerings to Triton code, which means at least 1 graph break in FWD graph",
                )

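    # Factory for a small Transformer (built from common_dtensor's ModelArgs) with per-layer
    # fully_shard plus a root fully_shard, and a random token-id input creator.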
    def _create_transformer_factory_fns(self):
        seq_len = 16
        vocab_size = 8

        def model_init_fn():
            torch.manual_seed(self.rank)
            fsdp_config = {}
            mesh = init_device_mesh("cuda", (self.world_size,))
            model_args = ModelArgs(
                vocab_size=vocab_size,
                n_layers=3,
            )
            model = Transformer(model_args)
            for layer_id, mod in enumerate(model.layers):
                fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
            model = fully_shard(
                model, mesh=mesh, reshard_after_forward=True, **fsdp_config
            )
            optim = torch.optim.SGD(model.parameters(), lr=1e-4)
            return model, optim

        def input_creation_fn():
            torch.manual_seed(self.rank)
            inp = torch.randint(
                0, vocab_size, (2, seq_len), device="cuda", requires_grad=False
            )
            return inp

        return model_init_fn, input_creation_fn

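    # Patches F.scaled_dot_product_attention so that it triggers a Dynamo graph break when
    # fullgraph=False, to exercise the graph-break code path inside the transformer tests.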
    def _maybe_add_graph_break_to_sdpa(self, fullgraph):
        def _sdpa_with_graph_break(orig_fn, fullgraph, *args, **kwargs):
            if not fullgraph:
                torch._dynamo.graph_break()
            return orig_fn(*args, **kwargs)

        return mock.patch.object(
            F,
            "scaled_dot_product_attention",
            functools.partial(
                _sdpa_with_graph_break,
                F.scaled_dot_product_attention,
                fullgraph,
            ),
        )

    @skipIfRocm
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_transformer_backend_aot_eager(self):
        for fullgraph in [True, False]:
            with self._maybe_add_graph_break_to_sdpa(
                fullgraph
            ), self._reinplace_all_gather_with_optional_checks(fullgraph):
                self._test_traceable_fsdp(
                    *self._create_transformer_factory_fns(),
                    "aot_eager",
                    fullgraph=fullgraph,
                )

    @skipIfRocm
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    # TODO: native_dropout has worse accuracy after decomp, need to figure out why
    @torch._inductor.config.patch(fallback_random=True)
    def test_transformer_backend_aot_eager_decomp_partition(self):
        for fullgraph in [True, False]:
            with self._maybe_add_graph_break_to_sdpa(fullgraph):
                self._test_traceable_fsdp(
                    *self._create_transformer_factory_fns(),
                    "aot_eager_decomp_partition",
                    fullgraph=fullgraph,
                )

    @skipIfRocm
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    # TODO: native_dropout causes CUDA IMA error, need to figure out why
    @torch._inductor.config.patch(fallback_random=True)
    def test_transformer_backend_inductor(self):
        for fullgraph in [True, False]:
            with self._maybe_add_graph_break_to_sdpa(
                fullgraph
            ), self._reinplace_all_gather_with_optional_checks(
                fullgraph
            ), self._maybe_run_decide_global_ordering_of_comms_with_checks(
                fullgraph
            ):
                _, triton_codes = run_and_get_code(
                    lambda: self._test_traceable_fsdp(
                        *self._create_transformer_factory_fns(),
                        "inductor",
                        fullgraph=fullgraph,
                    )
                )
            if fullgraph:
                self.assertTrue(
                    len(triton_codes) == 2,
                    "Expected two separate lowerings to Triton code, one from FWD graph and one from Compiled Autograd BWD graph",
                )
                fwd_code = triton_codes[0]
                file_check = FileCheck().check("def call(args):")
                for fwd_ag_block_info in [
                    dict(overlapped_compute_op_str="triton_", num_resize=0, num_set=4),
                    dict(
                        overlapped_compute_op_str="aten.native_dropout.",
                        num_resize=0,
                        num_set=12,
                    ),
                    dict(
                        overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention.",
                        num_resize=12,
                        num_set=12,
                    ),
                    dict(
                        overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention.",
                        num_resize=12,
                        num_set=12,
                        last_all_gather=True,
                    ),
                ]:
                    file_check = self.inductor_code_check_fsdp_all_gather(
                        file_check, **fwd_ag_block_info
                    )
                file_check.run(fwd_code)

                bwd_code = triton_codes[1]
                file_check = FileCheck().check("def call(args):")
                for bwd_ag_block_info in [
                    dict(
                        overlapped_compute_op_str="extern_kernels.mm(",
                        num_resize=0,
                        num_set=12,
                    ),
                    dict(
                        overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention_backward.",
                        num_resize=0,
                        num_set=12,
                    ),
                    dict(
                        overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention_backward.",
                        num_resize=0,
                        num_set=12,
                        last_all_gather=True,
                    ),
                ]:
                    file_check = self.inductor_code_check_fsdp_all_gather(
                        file_check, **bwd_ag_block_info
                    )
                for bwd_rs_block_info in [
                    dict(overlapped_compute_op_str="extern_kernels.mm("),
                    dict(
                        overlapped_compute_op_str=None
                    ),  # TODO: improve compute/comm overlap, so that `overlapped_compute_op_str` is not None
                    dict(overlapped_compute_op_str=None),
                    dict(overlapped_compute_op_str=None),
                ]:
                    file_check = self.inductor_code_check_fsdp_reduce_scatter(
                        file_check, **bwd_rs_block_info
                    )
                file_check.run(bwd_code)
            else:
                # TODO: when fullgraph=False and there is a graph break in the FWD graph,
                # there are several recompiles; need to figure out why.
                self.assertTrue(
                    len(triton_codes) > 2,
                    "Expected at least 3 separate lowerings to Triton code, which means at least 1 graph break in FWD graph",
                )


if __name__ == "__main__":
    run_tests()