# Owner(s): ["oncall: distributed"]

import sys
import unittest

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, init_device_mesh, Shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

HAS_CUDA = torch.cuda.is_available()


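# These tests run distributed collectives and wrappers (FSDP, tensor
# parallelism) against the "fake" process group backend, which performs no
# real communication. A single process can therefore act as any rank of an
# arbitrary world size, and the tests mostly assert on output shapes or on
# successful construction/tracing rather than on numerical results.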
class TestFakePG(TestCase):
    def tearDown(self):
        super().tearDown()
        dist.destroy_process_group()

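    # all_reduce through the fake backend completes without any peers; the
    # test only checks that the tensor shape is unchanged.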
    def test_all_reduce(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=1, world_size=2, store=store)

        output = torch.ones(3, 3) * dist.get_rank()
        dist.all_reduce(output)
        self.assertEqual(tuple(output.shape), (3, 3))

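    # all_gather into a pre-allocated list of per-rank outputs; only the
    # output shapes are asserted.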
    def test_allgather(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=1, world_size=2, store=store)

        input_tensor = torch.ones(3, 3) * dist.get_rank()
        output_tensors = [torch.empty_like(input_tensor) for _ in range(2)]
        dist.all_gather(output_tensors, input_tensor)
        for out_tensor in output_tensors:
            self.assertEqual(tuple(out_tensor.shape), (3, 3))

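    # reduce_scatter from a list of per-rank inputs into a single output tensor.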
    def test_reduce_scatter(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=1, world_size=2, store=store)

        to_reduce_scatter = [torch.ones(3, 3) * rank for rank in range(2)]
        output_tensor = torch.empty(3, 3)

        dist.reduce_scatter(output_tensor, to_reduce_scatter)
        self.assertEqual(tuple(output_tensor.shape), (3, 3))

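    # Wrapping a module in FSDP should succeed even though the process group
    # has no real peers.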
    @unittest.skipIf(not HAS_CUDA, "No CUDA")
    def test_construct_fsdp(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
        FSDP(nn.Linear(2, 3, device="cuda"))

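    # End-to-end forward/backward/optimizer step through an FSDP-wrapped
    # module, with every collective handled by the fake backend.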
    @unittest.skipIf(not HAS_CUDA, "No CUDA")
    def test_fsdp_fake_e2e(self):
        store = dist.HashStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
        my_module = nn.Sequential(
            nn.Linear(2, 3, device="cuda"),
            nn.ReLU(),
            nn.Linear(3, 2, device="cuda"),
        )
        sharded_module = FSDP(my_module, use_orig_params=True)
        optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
        input = torch.randn(2, 2)
        x = sharded_module(input)
        loss = x.sum()
        loss.backward()
        optim.step()

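    # make_fx should trace a functional all_gather issued against the fake
    # default group and record all_gather / wait_tensor nodes in the graph.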
    @unittest.skipIf(not HAS_CUDA, "No CUDA")
    def test_fake_pg_tracing(self):
        store = dist.HashStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        default_pg = dist.distributed_c10d._get_default_group()

        def allgather_fn(tensor):
            return funcol.all_gather_tensor(tensor, 0, default_pg)

        gm = make_fx(allgather_fn)(torch.randn(2, 2, device="cuda"))
        FileCheck().check("all_gather").check("wait_tensor").run(str(gm.graph))

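    # broadcast is exercised both as the source rank (src == rank) and as a
    # receiver (src != rank).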
    def test_broadcast(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        # src == rank
        output = torch.ones(3, 3)
        dist.broadcast(output, src=0)
        self.assertEqual(tuple(output.shape), (3, 3))

        # src != rank
        output = torch.ones(3, 3)
        dist.broadcast(output, src=1)
        self.assertEqual(tuple(output.shape), (3, 3))

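    # scatter as the source rank (with a list of tensors to scatter) and as a
    # receiver (scatter list is None).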
    def test_scatter(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        # src == rank
        output = torch.ones(3, 3)
        to_scatter = [torch.ones(3, 3) * rank for rank in range(2)]
        dist.scatter(output, to_scatter)
        self.assertEqual(tuple(output.shape), (3, 3))

        # src != rank
        output = torch.ones(3, 3)
        dist.scatter(output, None, src=1)
        self.assertEqual(tuple(output.shape), (3, 3))

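    # all_to_all with explicit per-rank input and output tensor lists.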
    def test_alltoall(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        output_list = [torch.ones(3, 3) for _ in range(2)]
        input_list = [torch.ones(3, 3) for _ in range(2)]
        dist.all_to_all(output_list, input_list)
        self.assertEqual(len(output_list), 2)
        for output in output_list:
            self.assertEqual(tuple(output.shape), (3, 3))

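    # all_to_all_single with explicit input/output split sizes.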
    def test_alltoall_base(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        out_tensor = torch.ones(3, 3)
        in_tensor = torch.ones(3, 3)
        output_split = [1, 1]
        input_split = [1, 1]
        dist.all_to_all_single(out_tensor, in_tensor, output_split, input_split)
        self.assertEqual(tuple(out_tensor.shape), (3, 3))

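    # Point-to-point send to another rank completes without a real peer.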
    def test_send(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        tensor = torch.ones(3, 3)
        dist.send(tensor, 1)
        self.assertEqual(tuple(tensor.shape), (3, 3))

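    # Point-to-point recv likewise completes without a real peer; the buffer
    # keeps its shape.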
    def test_recv(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        output = torch.ones(3, 3)
        dist.recv(output, 1)
        self.assertEqual(tuple(output.shape), (3, 3))

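    # 2D parallelism smoke test: tensor parallelism via parallelize_module
    # composed with FSDP on a ("dp", "tp") device mesh, run end to end for a
    # few iterations with fake collectives.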
    @unittest.skipIf(not HAS_CUDA, "No CUDA or TP+FSDP")
    def test_fsdp_tp_fake_e2e(self):
        world_size = 4
        tp_size = 2

        store = dist.HashStore()
        dist.init_process_group(
            backend="fake", rank=0, world_size=world_size, store=store
        )

        # Build a DeviceMesh directly, then replace it with the named 2D
        # ("dp", "tp") mesh that the rest of the test indexes into.
        device_mesh = DeviceMesh("cuda", torch.arange(0, world_size).view(-1, tp_size))
        device_mesh = init_device_mesh(
            "cuda", (world_size // tp_size, tp_size), mesh_dim_names=["dp", "tp"]
        )

        sequence_parallelize_plan = {
            "net1": ColwiseParallel(input_layouts=Shard(0)),
            "net2": RowwiseParallel(output_layouts=Shard(0)),
        }
        pairwise_parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        for parallel_plan in [sequence_parallelize_plan, pairwise_parallelize_plan]:
            my_module = parallelize_module(
                MLPModule(device="cuda"),
                device_mesh["tp"],
                parallel_plan,
            )

            sharded_module = FSDP(
                my_module, use_orig_params=True, device_mesh=device_mesh["dp"]
            )
            optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)

            for i in range(10):
                dp_rank = dist.get_rank()
                torch.manual_seed(i + dp_rank)
                input = torch.randn(20, 10).cuda(dp_rank)
                x = sharded_module(input)
                loss = x.sum()
                loss.backward()
                optim.step()


if __name__ == "__main__":
    run_tests()