# Owner(s): ["oncall: distributed"]

import sys

import test_c10d_spawn
from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions

import torch
import torch.distributed as c10d
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)


NO_NCCL = not hasattr(c10d, "ProcessGroupNCCL")

# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
if sys.version_info < (3, 9):

    class ProcessGroupShareTensorTest(
        test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase
    ):
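        # Helper run inside each spawned worker: creates a FileStore-backed NCCL
        # process group for the given rank.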
        @classmethod
        def _init_pg_nccl(cls, rank, filename, world_size):
            store = c10d.FileStore(filename, world_size)
            return c10d.ProcessGroupNCCL(store, rank, world_size)

        @skip_but_pass_in_sandcastle_if(
            not TEST_MULTIGPU, "At least 2 CUDA GPUs needed"
        )
        @skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
        def test_shared_broadcast_nccl(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_broadcast_process,
                [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_nccl,
                1,
            )

        @skip_but_pass_in_sandcastle_if(
            not TEST_MULTIGPU, "At least 2 CUDA GPUs needed"
        )
        @skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
        def test_shared_allreduce_nccl(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_allreduce_process,
                [torch.ones(2, 2).to(i) for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_nccl,
                1,
            )

        @classmethod
        def _test_reduce_process(
            cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c
        ):
            pg = init_pg(rank, filename, world_size)
            x = shared_tensors[rank]
            pg.reduce(x, root=0, op=c10d.ReduceOp.SUM).wait()
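            # After the SUM reduce to root 0, rank 0's tensor holds the element-wise sum
            # across ranks, while the other ranks' tensors are left unchanged; that is
            # what the expected values below encode (a two-rank, all-ones run).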
            if rank == 0:
                c2p.put((rank, torch.ones(2, 2) * 2, x.to("cpu")))
            else:
                c2p.put((rank, torch.ones(2, 2), x.to("cpu")))
            p2c.get()

        @skip_but_pass_in_sandcastle_if(
            not TEST_MULTIGPU, "At least 2 CUDA GPUs needed"
        )
        @skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
        def test_shared_reduce_nccl(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_reduce_process,
                [torch.ones(2, 2).to(i) for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_nccl,
                1,
            )

        @skip_but_pass_in_sandcastle_if(
            not TEST_MULTIGPU, "At least 2 CUDA GPUs needed"
        )
        @skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
        def test_shared_allgather_nccl(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_allgather_process,
                [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_nccl,
                self.world_size,
            )


# Skip dev-asan as torch + multiprocessing spawn have known issues
if not TEST_WITH_DEV_DBG_ASAN:

    class TestDistributedNNFunctionsNccl(TestDistributedNNFunctions):
        # Test ops common to all backends first; the shared test bodies live in
        # TestDistributedNNFunctions.
        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_broadcast(self):
            self._test_broadcast("nccl")

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_reduce(self):
            self._test_reduce("nccl")

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_allreduce(self):
            self._test_allreduce("nccl")

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_all_gather(self):
            self._test_all_gather("nccl")

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_all_to_all(self):
            self._test_all_to_all("nccl")

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_all_to_all_single(self):
            self._test_all_to_all_single("nccl")

        # Test ops that are only supported by the NCCL backend.
        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_reduce_scatter(self):
            store = c10d.FileStore(self.file_name, self.world_size)
            # This is required because these functions call into torch.distributed
            # directly and need the default process group (the "world") to be initialized.
            c10d.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
            )
            device = torch.device(f"cuda:{self.rank}")
            x0 = torch.ones(5, 5, device=device) + self.rank
            x1 = torch.ones(5, 5, device=device) + self.rank + 1
            x0.requires_grad = True
            x1.requires_grad = True
            y = torch.empty_like(x0)
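            # Rank r receives the sum over all ranks i of the r-th list entry, i.e.
            # sum_i (1 + i + r) = world_size * (world_size + 1) / 2 + world_size * r.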
            expected = (
                1 + self.world_size
            ) * self.world_size / 2 + self.world_size * self.rank
            y = torch.distributed.nn.reduce_scatter(y, [x0, x1])
            self.assertEqual(y, torch.ones(5, 5, device=device) * expected)
            z = y.sin().sum()
            z.backward()
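            # z = sum(sin(y)), so dz/dy = cos(y); the gradient flowing back to input x_k
            # comes from rank k's output, so x_k.grad should equal cos of the y value
            # held on rank k (expected_0 on rank 0, expected_1 on rank 1).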
            expected_0 = (1 + self.world_size) * self.world_size / 2
            expected_1 = expected_0 + self.world_size
            x_s_0 = (expected_0 * torch.ones(5, 5, device=device)).cos()
            x_s_1 = (expected_1 * torch.ones(5, 5, device=device)).cos()
            self.assertEqual(x0.grad, x_s_0)
            self.assertEqual(x1.grad, x_s_1)

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_reduce_scatter_non_contiguous(self):
            store = c10d.FileStore(self.file_name, self.world_size)
            # This is required because these functions call into torch.distributed
            # directly and need the default process group (the "world") to be initialized.
            c10d.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
            )
            device = torch.device(f"cuda:{self.rank}")

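            # Identity autograd Function whose backward returns a transposed (hence
            # non-contiguous) gradient, exercising reduce_scatter's backward with
            # non-contiguous grads.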
            class NonContiguousGrad(torch.autograd.Function):
                @staticmethod
                def forward(ctx, input):
                    return input

                @staticmethod
                def backward(ctx, grad_output):
                    # Make grad non-contiguous
                    return grad_output.clone().transpose(0, 1)

            x0 = torch.rand(5, 5, device=device, requires_grad=True)
            x1 = torch.rand(5, 5, device=device, requires_grad=True)
            y = torch.empty(5, 5, device=device)

            y = torch.distributed.nn.reduce_scatter(y, [x0, x1])
            NonContiguousGrad.apply(y).sum().backward()

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_all_gather_base(self):
            store = c10d.FileStore(self.file_name, self.world_size)
            c10d.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
            )

            device = torch.device(f"cuda:{self.rank}")
            x = torch.ones(5, 5, device=device) + self.rank
            x.requires_grad = True

            output = torch.empty(5 * self.world_size, 5, device=device)
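            # _all_gather_base gathers each rank's (5, 5) tensor into the preallocated
            # (5 * world_size, 5) output buffer.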
            output = torch.distributed.nn.functional._all_gather_base(output, x)
            self.assertEqual(output.size(), torch.Size((5 * self.world_size, 5)))

            for idx in range(self.world_size):
                self.assertEqual(
                    output[5 * idx : 5 * (idx + 1)],
                    torch.ones(5, 5, device=device) + idx,
                )

            y = torch.sum(output.view(self.world_size, 5, 5), axis=0)
            z = y.sin().sum()
            z.backward()

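            # The hard-coded expectation below assumes a 2-rank world: each element of y
            # is (1 + 0) + (1 + 1) = 3, so dz/dy = cos(3), and the gradients gathered
            # back from both ranks sum to 2 * cos(3).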
            x_s = 2 * (3 * torch.ones(5, 5, device=device)).cos()
            self.assertEqual(x.grad, x_s)


if __name__ == "__main__":
    run_tests()