# xref: /aosp_15_r20/external/pytorch/test/distributed/test_c10d_spawn_ucc.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: distributed"]
2
3import sys
4
5import test_c10d_spawn
6from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions
7
8import torch
9import torch.distributed as c10d
10from torch.testing._internal.common_cuda import TEST_MULTIGPU
11from torch.testing._internal.common_distributed import requires_ucc, skip_if_lt_x_gpu
12from torch.testing._internal.common_utils import (
13    run_tests,
14    skip_but_pass_in_sandcastle,
15    skip_but_pass_in_sandcastle_if,
16    TEST_WITH_DEV_DBG_ASAN,
17    TestCase,
18)
19
20
21NO_UCC = not hasattr(c10d, "ProcessGroupUCC")
22
# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
if sys.version_info < (3, 9):

    class ProcessGroupShareTensorTest(
        test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase
    ):
        """Shared-tensor collective tests driven through the UCC backend.

        Each test hands the abstract base's multiprocess driver a list of
        CUDA tensors (one per rank), the UCC process-group initializer
        below, and the per-rank worker function to run.
        """

        @classmethod
        def _init_pg_ucc(cls, rank, filename, world_size):
            # Rendezvous via a file-backed store shared by all spawned
            # workers, then return the default group that
            # init_process_group installed.
            file_store = c10d.FileStore(filename, world_size)
            c10d.init_process_group(
                backend="ucc", store=file_store, rank=rank, world_size=world_size
            )
            return c10d.distributed_c10d._get_default_group()

        @skip_but_pass_in_sandcastle_if(
            not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
        )
        @skip_but_pass_in_sandcastle_if(NO_UCC, "UCC needed")
        def test_shared_broadcast_ucc(self):
            # One tensor per rank, placed on that rank's GPU and scaled by
            # the rank index so ranks start with distinct values.
            shared = [torch.ones(2, 2).to(r) * r for r in range(self.world_size)]
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_broadcast_process,
                shared,
                ProcessGroupShareTensorTest._init_pg_ucc,
                1,
            )

        @skip_but_pass_in_sandcastle_if(
            not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
        )
        @skip_but_pass_in_sandcastle_if(NO_UCC, "UCC needed")
        def test_shared_allreduce_ucc(self):
            # Identical all-ones tensors, one on each rank's GPU.
            shared = [torch.ones(2, 2).to(r) for r in range(self.world_size)]
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_allreduce_process,
                shared,
                ProcessGroupShareTensorTest._init_pg_ucc,
                1,
            )

        @skip_but_pass_in_sandcastle_if(
            not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
        )
        @skip_but_pass_in_sandcastle_if(NO_UCC, "UCC needed")
        def test_shared_allgather_ucc(self):
            # Rank-scaled tensors again; the trailing argument is
            # world_size here (vs. 1 above) — presumably the number of
            # outputs the worker checks per rank; see the abstract base.
            shared = [torch.ones(2, 2).to(r) * r for r in range(self.world_size)]
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_allgather_process,
                shared,
                ProcessGroupShareTensorTest._init_pg_ucc,
                self.world_size,
            )
73
# Skip dev-asan as torch + multiprocessing spawn have known issues
if not TEST_WITH_DEV_DBG_ASAN:

    class TestDistributedNNFunctionsUcc(TestDistributedNNFunctions):
        """Runs the shared torch.distributed.nn functional tests with UCC.

        Every test simply forwards the backend name to the corresponding
        `_test_*` helper inherited from TestDistributedNNFunctions; the
        decorators gate on UCC availability, GPU count, and the presence
        of torch.distributed.nn.
        """

        # Test Common Ops First.
        @requires_ucc()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_broadcast(self):
            backend = "ucc"
            self._test_broadcast(backend)

        @requires_ucc()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_reduce(self):
            backend = "ucc"
            self._test_reduce(backend)

        @requires_ucc()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_allreduce(self):
            backend = "ucc"
            self._test_allreduce(backend)

        @requires_ucc()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        @skip_but_pass_in_sandcastle(
            "runs into illegal memory access on first assertEqual check when run locally"
        )
        def test_all_gather(self):
            backend = "ucc"
            self._test_all_gather(backend)

        @requires_ucc()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_all_to_all(self):
            backend = "ucc"
            self._test_all_to_all(backend)

        @requires_ucc()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_all_to_all_single(self):
            backend = "ucc"
            self._test_all_to_all_single(backend)
130
if __name__ == "__main__":
    # Hand off to torch.testing._internal.common_utils.run_tests, which
    # discovers and runs the test classes defined above.
    run_tests()
133