# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import copy
import functools
import unittest
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.testing
import torch.distributed as dist
import torch.nn as nn
from torch._C import FileCheck
from torch._inductor.utils import run_and_get_triton_code
from torch.distributed._tensor import (
    DeviceMesh,
    DTensor,
    init_device_mesh,
    Partial,
    Replicate,
    Shard,
)
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    PrepareModuleOutput,
    RowwiseParallel,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    with_comms,
)
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._triton import has_triton
from torch.utils.checkpoint import checkpoint


class SimpleModel(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.mlp_0 = MLPModule(device)
        self.mlp_1 = MLPModule(device)

    def forward(self, input):
        return self.mlp_1(self.mlp_0(input))


def extract_graph(fx_g, _, graph_cell):
    graph_cell[0] = fx_g.code
    return fx_g


# Make a custom compiler that runs aot autograd and extracts the forward/backward graphs
fw_graph_cell = [None]
bw_graph_cell = [None]
fw_compiler = functools.partial(extract_graph, graph_cell=fw_graph_cell)
bw_compiler = functools.partial(extract_graph, graph_cell=bw_graph_cell)

from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd


aot_eager_graph = aot_autograd(
    fw_compiler=fw_compiler,
    bw_compiler=bw_compiler,
    partition_fn=min_cut_rematerialization_partition,
)
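# Usage sketch (not an API contract, just how the tests below consume it):
#   opt_fn = torch.compile(fn, backend=aot_eager_graph, fullgraph=True)
#   out = opt_fn(inp)
#   out.sum().backward()
# After the compiled fn has run, fw_graph_cell[0] (and, once backward has
# executed, bw_graph_cell[0]) hold the code of the partitioned forward/backward
# graphs for inspection; see test_graph_input_is_async below.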


class TestDTensorCompile(torch._dynamo.test_case.TestCase):
    def setUp(self):
        super().setUp()
        fake_store = FakeStore()
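        # The "fake" backend registers a process group that does no real
        # communication, letting this single-process test pretend to run with
        # `world_size` ranks while compiling/tracing DTensor code.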
        dist.init_process_group(
            "fake", store=fake_store, rank=0, world_size=self.world_size
        )

    def tearDown(self):
        super().tearDown()
        dist.destroy_process_group()

    @property
    def device_type(self) -> str:
        return "cuda" if torch.cuda.is_available() else "cpu"

    @property
    def world_size(self) -> int:
        return 2

    def test_placement_compile(self):
        def fn(x):
            a = 0
            if x.is_replicate():
                a += 1
            if x.is_shard():
                a += 2
                if x.dim < 0:
                    raise RuntimeError("dim < 0")
            if x.is_shard(0):
                a += 2
            if x.is_shard(dim=0):
                a += 2
            if x.is_shard(dim=None):
                a += 2
            if x.is_partial():
                a += 3
            return a

        compiled_fn = torch.compile(backend="aot_eager", fullgraph=True)(fn)

        for x in [Shard(0), Replicate(), Partial()]:
            opt_fn = fn(x)
            compiled_out = compiled_fn(x)
            self.assertEqual(opt_fn, compiled_out)

    def test_device_mesh_compile(self):
        def fn(x):
            # test size()
            a = x.size()
            b = x.size(0)
            c = x.size(mesh_dim=0)
            size = a + b + c
            # test get_coordinate()
            coord = x.get_coordinate()
            # test get_group()
            group = x.get_group()
            return size, coord, group

        compiled_fn = torch.compile(backend="aot_eager", fullgraph=True)(fn)

        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        opt_fn = fn(mesh)
        compiled_out = compiled_fn(mesh)
        self.assertEqual(opt_fn, compiled_out)

    def test_fakify_dtensor(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # pass in DTensor as inputs/outputs to the function
        def fn(x):
            return x

        x = DTensor.from_local(torch.rand(1), mesh, [Shard(0)], run_check=False)
        ref = fn(x)

        opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(res, ref)

    def test_dynamo_dtensor(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # test passing DTensors as inputs/outputs and running some tensor computation
        def fn(x):
            return x * x + 2

        x = DTensor.from_local(torch.rand(1), mesh, [Shard(0)], run_check=False)
        ref = fn(x)

        opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(res, ref)

    def test_dtensor_attribute_access_on_intermediate(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fn(x):
            tmp = x * 2
            if tmp.placements[0].is_shard():
                return tmp._local_tensor + 2
            else:
                return tmp._local_tensor + 3

        x = DTensor.from_local(torch.ones(4), mesh, [Shard(0)], run_check=False)
        ref = fn(x)

        opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(res, ref)

    def test_dtensor_constructor_w_graph_break(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        x = torch.randn(64, 32, requires_grad=True)
        spec = DTensorSpec(
            mesh,
            (Replicate(), Shard(0)),
            tensor_meta=TensorMeta(
                shape=torch.Size([128, 32]), stride=(32, 1), dtype=x.dtype
            ),
        )

        # construct a DTensor directly inside the fn; the print() below forces a graph break
        def fn(x):
            print("graph break!")
            return DTensor(
                x,
                spec,
                requires_grad=x.requires_grad,
            )

        out = fn(x)
        out2 = torch.compile(fn, backend="eager")(x)

    def test_dtensor_constructor_w_dynamo_disable(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        x = torch.randn(32, requires_grad=True)
        spec = DTensorSpec(
            mesh,
            (Replicate(),),
            tensor_meta=TensorMeta(shape=torch.Size([32]), stride=(1,), dtype=x.dtype),
        )

        @torch._dynamo.disable(recursive=False)
        def fn(x):
            print("foo")
            return DTensor(
                x,
                spec,
                requires_grad=x.requires_grad,
            )

        out = fn(x)
        out2 = torch.compile(fn, backend="eager")(x)
        self.assertEqual(out, out2)

    def test_dtensor_noncontiguous_output(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # test DTensor inputs/outputs where the compiled fn returns a non-contiguous output
        def fn(x, y, z):
            x_transposed = x.permute(0, 2, 1).contiguous()
            tmp = torch._C._nn.linear(x_transposed, y, z)
            return tmp.permute(0, 2, 1)

        x_inner = torch.randn(4, 16, 4, requires_grad=True)
        y_inner = torch.randn(4, 16, requires_grad=True)
        z_inner = torch.randn(4, requires_grad=True)
        x = DTensor.from_local(x_inner, mesh, [Shard(1)], run_check=False)
        y = DTensor.from_local(y_inner, mesh, [Shard(1)], run_check=False)
        z = DTensor.from_local(z_inner, mesh, [Replicate()], run_check=False)
        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(x, y, z)
        out.contiguous().sum().backward()

    def test_dynamo_dtensor_from_local(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # create DTensor inside fn and run some compute
        def fn(x):
            dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False)
            return dt.to_local() + 2

        # below is the op approach for reference
        # from torch.distributed._tensor.api import _FromTorchTensor
        # def from_local_tensor(x):
        #     return _FromTorchTensor.apply(x, mesh, [Replicate()], False)

        # _dt_lib_def = torch.library.Library("dtensor", "DEF")
        # _dt_lib_def.define("from_local(Tensor self) -> Tensor")

        # _dt_lib_impl = torch.library.Library("dtensor", "IMPL")
        # _dt_lib_impl.impl("from_local", from_local_tensor, "Autograd")

        x = torch.ones(1, requires_grad=True)
        ref = fn(x)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
        res = opt_fn(x)
        # backward should work as well
        res.sum().backward()

        self.assertEqual(res, ref)
        self.assertEqual(cnt.frame_count, 1)

        # calling from_local with mesh/placements passed as kwargs should still work
        def from_local_kwargs_fn(x):
            dt = DTensor.from_local(
                x, device_mesh=mesh, placements=[Replicate()], run_check=False
            )
            return dt.to_local() + 2

        ref = from_local_kwargs_fn(x)
        opt_kwargs_fn = torch.compile(from_local_kwargs_fn, backend=cnt, fullgraph=True)
        res = opt_kwargs_fn(x)
        self.assertEqual(res, ref)
        self.assertEqual(cnt.frame_count, 2)

    def test_dynamo_dtensor_from_local_dynamic_shapes(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # Case 1: all dims dynamic
        def fn(x):
            dt = DTensor.from_local(
                x,
                mesh,
                [Replicate()],
                run_check=False,
                shape=x.shape,
                stride=x.stride(),
            )
            return dt.to_local() + 2

        inp = torch.randn(4, 6, requires_grad=True)
        ref = fn(inp)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        res = torch.compile(fn, backend=cnt, fullgraph=True, dynamic=True)(inp)
        res.sum().backward()

        self.assertEqual(res, ref)
        self.assertEqual(cnt.frame_count, 1)

        # Case 2: only sizes are dynamic, strides are static
        def fn(x):
            dt = DTensor.from_local(
                x, mesh, [Replicate()], run_check=False, shape=x.shape, stride=(1,)
            )
            return dt.to_local() + 2

        inp = torch.randn(4, requires_grad=True)
        torch._dynamo.mark_dynamic(inp, 0)
        ref = fn(inp)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        res = torch.compile(fn, backend=cnt, fullgraph=True)(inp)
        res.sum().backward()

        self.assertEqual(res, ref)
        self.assertEqual(cnt.frame_count, 1)

        # Case 3: both sizes and strides have a mix of dynamic and static dims
        def fn(x):
            dt = DTensor.from_local(
                x,
                mesh,
                [Replicate()],
                run_check=False,
                shape=(x.shape[0], x.shape[1], 2),
                stride=(x.stride()[0], 2, 1),
            )
            return dt.to_local() + 2

        inp = torch.randn(4, 6, 2, requires_grad=True)
        torch._dynamo.mark_dynamic(inp, 0)
        torch._dynamo.mark_dynamic(inp, 1)
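        # dims 0 and 1 are marked dynamic, so they should be traced as symbolic
        # sizes, while dim 2 and the static entries of shape/stride above stay
        # compile-time constants.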
        ref = fn(inp)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        res = torch.compile(fn, backend=cnt, fullgraph=True)(inp)
        res.sum().backward()

        self.assertEqual(res, ref)
        self.assertEqual(cnt.frame_count, 1)

    def test_dynamo_dtensor_recompile(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # recompile only when the input DTensor's placements change
        def fn(x):
            return torch.mul(x, x)

        x = DTensor.from_local(torch.rand(2, 2), mesh, [Shard(0)], run_check=False)
        x2 = DTensor.from_local(torch.rand(2, 2), mesh, [Shard(0)], run_check=False)
        x3 = DTensor.from_local(torch.rand(2, 2), mesh, [Shard(1)], run_check=False)

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnt, fullgraph=True, dynamic=False)
        self.assertEqual(fn(x), opt_fn(x))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(fn(x2), opt_fn(x2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(fn(x3), opt_fn(x3))
        self.assertEqual(cnt.frame_count, 2)

    def test_dtensor_partial_placement_redistribute_unbalanced_correct_strides(self):
        # Partial -> Shard on an unbalanced tensor results in:
        # - A contiguous DTensor
        # - where the inner _local_tensor is noncontiguous
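        # (e.g. with world_size=2, Shard(1) splits the 257-sized dim below into
        # uneven local shards of 129 and 128 elements)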
        placement = Shard(1)

        def fn(x):
            out = x.redistribute(mesh, [placement])
            return out

        # Temporarily ignore setUp(), and use rank3 graphs during tracing
        dist.destroy_process_group()
        fake_store = FakeStore()
        dist.init_process_group("fake", store=fake_store, rank=3, world_size=2)
        mesh = DeviceMesh(self.device_type, [1, 3])

        x = torch.randn(10, 257, 160, requires_grad=True)
        x_dt = DTensor.from_local(
            x,
            mesh,
            [Partial()],
            run_check=False,
            shape=(10, 257, 160),
            stride=(41120, 160, 1),
        )
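        # stride=(41120, 160, 1) is simply the contiguous stride for the global
        # shape (10, 257, 160), since 257 * 160 == 41120.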

        # tmp_dt has an inner, non-contiguous tensor, and is an autograd non-leaf
        tmp_dt = fn(x_dt)
        fake_mode = torch._subclasses.FakeTensorMode()
        tmp_dt_fake = fake_mode.from_tensor(tmp_dt)
        self.assertEqual(tmp_dt.shape, tmp_dt_fake.shape)
        self.assertEqual(tmp_dt.stride(), tmp_dt_fake.stride())
        self.assertEqual(tmp_dt._local_tensor.shape, tmp_dt_fake._local_tensor.shape)
        self.assertEqual(
            tmp_dt._local_tensor.stride(), tmp_dt_fake._local_tensor.stride()
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_dtensor_contiguous_dtensor_noncontiguous_local_as_tangent(self):
        # Partial -> Shard on an unbalanced tensor results in:
        # - A contiguous DTensor
        # - where the inner _local_tensor is noncontiguous
        # When this tensor is a fwd graph output,
        # AOTAutograd needs to make sure we trace the backward
        # with a contiguous tangent
        placement = Shard(1)

        def fn(x):
            out = x.redistribute(mesh, [placement])
            return out

        # Temporarily ignore setUp(), and use rank3 graphs during tracing
        dist.destroy_process_group()
        fake_store = FakeStore()
        dist.init_process_group("fake", store=fake_store, rank=3, world_size=2)
        mesh = DeviceMesh(self.device_type, [1, 3])

        x = torch.randn(10, 257, 160, requires_grad=True)
        x_dt = DTensor.from_local(
            x,
            mesh,
            [Partial()],
            run_check=False,
            shape=(10, 257, 160),
            stride=(41120, 160, 1),
        )

        out_dt = torch.compile(fn)(x_dt)
        # If we don't properly contiguify our traced tangents,
        # this fails with an inductor stride assert
        out_dt.to_local().sum().backward()

    def test_dynamo_to_local_kwargs(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fn(x):
            return dt.to_local(grad_placements=[Shard(0)]) + 2

        fn_opt = torch.compile(fn, backend="aot_eager", fullgraph=True)
        x = torch.ones(4)
        dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False)

        out_ref = fn(dt)
        out_test = fn_opt(dt)
        self.assertEqual(out_ref, out_test)

    def test_dynamo_to_local_kwargs_forward_hook(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fw_hook(module, inp, out):
            tmp = out.to_local(grad_placements=out.placements) + 2
            return DTensor.from_local(tmp, mesh, out.placements, run_check=False)

        mod = torch.nn.Linear(4, 4)
        mod.register_forward_hook(fw_hook)
        mod.weight = torch.nn.Parameter(
            DTensor.from_local(mod.weight, mesh, [Replicate()], run_check=False)
        )
        mod.bias = torch.nn.Parameter(
            DTensor.from_local(mod.bias, mesh, [Replicate()], run_check=False)
        )
        opt_mod = torch.compile(mod, backend="aot_eager", fullgraph=True)

        x = torch.ones(4, 4)
        dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False)

        out_ref = mod(dt)
        out_test = opt_mod(dt)
        self.assertEqual(out_ref, out_test)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_dtensor_different_gradient_placement(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fn(x, y, z):
            permute = x.permute(0, 2, 1)
            permute2 = permute.contiguous()
            layer_norm = torch.nn.functional.layer_norm(permute2, (4,), y, z, 1e-05)
            out = layer_norm.permute(0, 2, 1)
            return out

        x = torch.randn(4, 2, 4, requires_grad=True, device="cuda")
        x_dt = DTensor.from_local(x, mesh, [Shard(1)], run_check=False)

        y = torch.randn(4, requires_grad=True, device="cuda")
        y_dt = DTensor.from_local(y, mesh, [Replicate()], run_check=False)

        z = torch.randn(4, requires_grad=True, device="cuda")
        z_dt = DTensor.from_local(z, mesh, [Replicate()], run_check=False)

        opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
        tmp_dt = opt_fn(x_dt, y_dt, z_dt)
        out_dt = torch.matmul(tmp_dt, x_dt).permute(0, 2, 1)
        out_dt.sum().backward()

    def test_dynamo_dtensor_from_local_redistribute(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # pass in tensor as inputs/outputs, create DTensor and run redistribute
        # (allgather collective) inside the fn
        def fn(x):
            dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
            return dt.redistribute(mesh, [Replicate()]).to_local() + 2

        x = torch.ones(1)
        ref = fn(x)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(res, ref)

        def redistribute_kwargs_fn(x):
            dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
            return (
                dt.redistribute(device_mesh=mesh, placements=[Replicate()]).to_local()
                + 2
            )

        x = torch.ones(1)
        ref = redistribute_kwargs_fn(x)
        opt_kwargs_fn = torch.compile(
            redistribute_kwargs_fn, backend=cnt, fullgraph=True
        )
        res = opt_kwargs_fn(x)
        self.assertEqual(res, ref)

    def test_dtensor_dont_recompile_on_same_placement_devicemesh(self):
        cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")

        @torch.compile(backend=cnt)
        def fn(x):
            dt = DTensor.from_local(x, mesh, [placement], run_check=False)

        x = torch.ones(4, 4, requires_grad=True)

        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        placement = Shard(1)
        fn(x)

        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        placement = Shard(1)
        # no recompile, placement is unchanged
        fn(x)

        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        placement = Partial()
        # recompile since placement is different
        fn(x)

        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        placement = Partial()
        # no recompile, placement is unchanged
        fn(x)

        # 2 total frames (one for Partial(), one for Shard())
        self.assertEqual(cnt.frame_count, 2)

    def test_dtensor_dynamo_device_mesh_attrs(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # branch on the DTensor input's device_mesh attributes inside the compiled fn
        def fn(x_dt):
            if x_dt.device_mesh.device_type == "cuda":
                return x_dt + 1
            else:
                return x_dt + 2

        x = torch.ones(4, 4)
        x_dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        ref = fn(x_dt)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x_dt)
        self.assertEqual(ref, res)

    def test_graph_input_is_async(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fn(x):
            return x.sin().sin()

        opt_fn = torch.compile(fn, backend=aot_eager_graph, fullgraph=True)

        x = torch.randn(4, 4, requires_grad=True)
        x_dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        x2 = x_dt.redistribute(mesh, [Replicate()], async_op=True)
        x2 = x2.to_local()
        out = opt_fn(x2)
        # The important part: we get a wait_tensor() in the graph.
        # At runtime, the input to the graph is an AsyncCollectiveTensor,
        # and inside the graph we need to issue a wait() to synchronize.
        self.assertExpectedInline(
            str(fw_graph_cell[0]).strip(),
            """\
def forward(self, primals_1):
    wait_tensor = torch.ops._c10d_functional.wait_tensor.default(primals_1)
    sin = torch.ops.aten.sin.default(wait_tensor)
    sin_1 = torch.ops.aten.sin.default(sin);  sin = None
    return (sin_1, primals_1, wait_tensor)""",
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_dtensor_partial_placement_graph_output(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fn(x):
            return x + x

        x = torch.randn(4, 4, requires_grad=True)
        x_dt = DTensor.from_local(x, mesh, [Partial()], run_check=False)

        y = torch.randn(4, 4, requires_grad=True)
        y_dt = DTensor.from_local(y, mesh, [Replicate()], run_check=False)

        opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
        tmp_dt = opt_fn(x_dt)
        out_dt = torch.matmul(tmp_dt, y_dt)
        out_dt.sum().backward()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(1)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    def test_tp_compile_comm_reordering(self):
        class FakeAttention(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.wq = nn.Linear(16, 16)
                self.wk = nn.Linear(16, 16)
                self.wv = nn.Linear(16, 16)
                self.wo = nn.Linear(16, 16)

            def forward(self, x):
                xq = self.wq(x)
                xk = self.wk(x)
                xv = self.wv(x)
                # fake attention:
                xo = xq + xk + xv
                return self.wo(xo)

        class FakeTransformerBlock(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attn = FakeAttention()

            def forward(self, x):
                return self.attn(x)

        class FakeTransformer(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.block = FakeTransformerBlock()

            def forward(self, input):
                return self.block(input)

        model = FakeTransformer().to(self.device_type)

        tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

        # apply sequence parallel
        parallel_plan = {
            "attn": PrepareModuleInput(
                input_layouts=Shard(0), desired_input_layouts=Replicate()
            ),
            "attn.wq": ColwiseParallel(),
            "attn.wk": ColwiseParallel(),
            "attn.wv": ColwiseParallel(),
            "attn.wo": RowwiseParallel(output_layouts=Shard(0)),
        }
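        # Rough intuition for the plan above (sequence parallel): the module
        # input arrives as Shard(0) and PrepareModuleInput all-gathers it to
        # Replicate; wq/wk/wv are column-wise sharded; wo is row-wise sharded
        # and its partial output is reduce-scattered back to Shard(0). The
        # FileCheck below expects inductor to emit that all_gather (and its
        # wait) before the first matmul.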

        parallelize_module(
            module=model.block,
            device_mesh=tp_mesh,
            parallelize_plan=parallel_plan,
        )

        cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
        compiled_model = torch.compile(model, backend=cnt, fullgraph=True)
        inp = torch.rand(20, 16).to(self.device_type)
        out = compiled_model(inp)
        out.sum().backward()
        self.assertEqual(cnt.frame_count, 1)

        code = run_and_get_triton_code(compiled_model, inp)
        FileCheck().check(
            "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(primal"
        ).check("torch.ops._c10d_functional.wait_tensor.default(buf0").check(
            "extern_kernels.mm(buf0,"
        ).run(
            code
        )


@instantiate_parametrized_tests
class TestDTensorCompileE2E(DTensorTestBase):
    @property
    def world_size(self):
        return 4

    @with_comms
    @parametrize("is_seq_parallel", [True, False])
    def test_tp_compile_fullgraph(self, is_seq_parallel):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        model = SimpleModel(self.device_type)

        colwise_style = (
            ColwiseParallel(input_layouts=Shard(0))
            if is_seq_parallel
            else ColwiseParallel()
        )
        rowwise_style = (
            RowwiseParallel(output_layouts=Shard(0))
            if is_seq_parallel
            else RowwiseParallel()
        )

        if is_seq_parallel:
            # use module input/output preparation so that compiling it is exercised too
            prepare_module_input = PrepareModuleInput(
                input_layouts=Shard(0),
                desired_input_layouts=Replicate(),
            )
            prepare_module_out = PrepareModuleOutput(
                output_layouts=Replicate(),
                desired_output_layouts=Shard(0),
            )
            plan = {
                "mlp_0": prepare_module_input,
                "mlp_0.net1": ColwiseParallel(),
                "mlp_0.net2": rowwise_style,
                "mlp_1.net1": colwise_style,
                "mlp_1.net2": RowwiseParallel(),
                "mlp_1": prepare_module_out,
            }
        else:
            plan = {
                "mlp_0.net1": colwise_style,
                "mlp_0.net2": rowwise_style,
                "mlp_1.net1": colwise_style,
                "mlp_1.net2": rowwise_style,
            }

        model = parallelize_module(
            model,
            mesh,
            parallelize_plan=plan,
        )
        rng_seed = self.rank if is_seq_parallel else 0
        torch.manual_seed(rng_seed)
        inp = torch.rand(20, 10, device=self.device_type)
        out = model(inp)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        compiled_mod = torch.compile(model, backend=cnt, fullgraph=True)
        compiled_out = compiled_mod(inp)
        compiled_out.sum().backward()
        self.assertEqual(compiled_out, out)
        self.assertEqual(cnt.frame_count, 1)

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_2d_fsdp_tp_compile(self):
        data_parallel_size = 2
        model = SimpleModel(self.device_type)
        model_copy = copy.deepcopy(model)

        # 2-D mesh is [dp, tp]
        twod_mesh = init_device_mesh(
            "cuda",
            (data_parallel_size, self.world_size // data_parallel_size),
            mesh_dim_names=["dp", "tp"],
        )

        fsdp_pg = twod_mesh.get_group(mesh_dim=0)

        inp = torch.rand(20, 10, device=self.device_type)
        parallelize_plan = {
            "mlp_0.net1": ColwiseParallel(),
            "mlp_0.net2": RowwiseParallel(),
            "mlp_1.net1": ColwiseParallel(),
            "mlp_1.net2": RowwiseParallel(),
        }
        tp_model = parallelize_module(model, twod_mesh["tp"], parallelize_plan)
        eager_2d = FSDP(
            tp_model,
            device_id=self.rank,
            use_orig_params=True,
            device_mesh=twod_mesh["dp"],
        )
        out = eager_2d(inp)
        tp_model2 = parallelize_module(
            model_copy,
            twod_mesh["tp"],
            parallelize_plan,
        )
        fsdp_2d = FSDP(
            tp_model2,
            device_id=self.rank,
            use_orig_params=True,
            device_mesh=twod_mesh["dp"],
        )

        # TODO: once aot autograd support is ready we can just use default backend
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        compiled_2d = torch.compile(fsdp_2d, backend=cnt)
        compiled_output = compiled_2d(inp)

        self.assertEqual(out, compiled_output)
        self.assertEqual(cnt.frame_count, 1)

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_2d_fsdp_tp_ac_compile(self):
        dp_degree = 2
        tp_degree = self.world_size // dp_degree
        model = SimpleModel(self.device_type)
        model_copy = copy.deepcopy(model)

        # 2-D mesh is [dp, tp]
        mesh_2d = init_device_mesh(
            "cuda", mesh_shape=(dp_degree, tp_degree), mesh_dim_names=("dp", "tp")
        )

        inp = torch.rand(20, 10, device=self.device_type)
        parallelize_plan = {
            "mlp_0.net1": ColwiseParallel(),
            "mlp_0.net2": RowwiseParallel(),
            "mlp_1.net1": ColwiseParallel(),
            "mlp_1.net2": RowwiseParallel(),
        }
        tp_model = parallelize_module(model, mesh_2d["tp"], parallelize_plan)
        tp_model = checkpoint_wrapper(
            tp_model,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            checkpoint_fn=checkpoint,
            use_reentrant=False,
        )
        eager_2d = FSDP(tp_model, device_mesh=mesh_2d["dp"], use_orig_params=True)

        tp_model2 = parallelize_module(model_copy, mesh_2d["tp"], parallelize_plan)
        fsdp_2d = FSDP(
            tp_model2,
            device_mesh=mesh_2d["dp"],
            use_orig_params=True,
        )
        # TODO: once aot autograd support is ready we can just use default backend
        compiled_2d = torch.compile(fsdp_2d, backend="aot_eager")

        # forward pass
        out = eager_2d(inp)
        compiled_output = compiled_2d(inp)
        self.assertEqual(out, compiled_output)

        # backward pass
        out.sum().backward()
        compiled_output.sum().backward()

        # compare the gradients:
        for n, p in zip(fsdp_2d.parameters(), compiled_2d.parameters()):
            self.assertEqual(n.grad, p.grad)

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_compile_dtensor_redistribute_backward(self):
        mesh = DeviceMesh(device_type="cuda", mesh=torch.arange(self.world_size))

        def fn(x, y):
            dt = DTensor.from_local(x.reshape(2, 4), mesh, [Shard(0)], run_check=False)
            dt2 = DTensor.from_local(y.reshape(4, 2), mesh, [Shard(1)], run_check=False)
            dt_out = torch.matmul(dt, dt2)
            dt_out_redistribute = dt_out.redistribute(mesh, [Replicate()])
            return dt_out_redistribute.to_local()

        opt_fn = torch.compile(fn, backend=aot_eager_graph, fullgraph=True)

        x_ref = torch.arange(8, requires_grad=True, dtype=torch.float32)
        y_ref = torch.arange(8, requires_grad=True, dtype=torch.float32)
        ref = fn(x_ref, y_ref)

        x = torch.arange(8, requires_grad=True, dtype=torch.float32)
        y = torch.arange(8, requires_grad=True, dtype=torch.float32)
        res = opt_fn(x, y)

        self.assertEqual(res, ref)

        # Now run and assert the backward + gradients
        ref.sum().backward()
        res.sum().backward()

        self.assertEqual(x_ref.grad, x.grad)
        self.assertEqual(y_ref.grad, y.grad)


if __name__ == "__main__":
    run_tests()