# Owner(s): ["oncall: distributed"]

import re
import sys

import torch
import torch.cuda
import torch.cuda.nccl as nccl
import torch.distributed as c10d
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    load_tests,
    NoTest,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_ROCM,
    TestCase,
)


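# Parse the major.minor ROCm version from torch.version.hip; 0.0 means this is
# not a ROCm build, so the HIP-specific skips below never trigger.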
HIP_VERSION = (
    0.0
    if torch.version.hip is None
    else float(re.search(r"^\d+\.\d+", torch.version.hip)[0])
)

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings.
load_tests = load_tests

nGPUs = torch.cuda.device_count()
if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


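# bfloat16 collectives require NCCL >= 2.10 on CUDA builds; on ROCm builds
# bfloat16 is enabled unconditionally.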
datatypes = [torch.float]
if (
    TEST_CUDA and c10d.is_nccl_available() and nccl.version() >= (2, 10)
) or TEST_WITH_ROCM:
    datatypes.append(torch.bfloat16)

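# A minimal sketch of how the raw bindings exercised below are used (an
# illustration only, assuming at least two visible GPUs):
#
#     tensors = [torch.ones(4, device=i) for i in range(torch.cuda.device_count())]
#     nccl.all_reduce(tensors)  # each tensor now holds the elementwise sum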

class TestNCCL(TestCase):
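    # nccl.unique_id() wraps ncclGetUniqueId; the id is an opaque byte string
    # used to rendezvous communicators across processes.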
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    def test_unique_id(self, device):
        uid = nccl.unique_id()
        self.assertIsInstance(uid, bytes)
        self.assertGreater(len(uid), 1)

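    # Broadcast from device 0: every participating tensor should end up equal
    # to the source tensor.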
    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_broadcast(self, device, dtype):
        expected = torch.zeros(128).uniform_().to(dtype=dtype)
        tensors = [expected.cuda()]
        for device in range(1, torch.cuda.device_count()):
            tensors.append(torch.zeros(128, dtype=dtype, device=device))

        nccl.broadcast(tensors)
        for i in range(torch.cuda.device_count()):
            self.assertEqual(tensors[i], expected)

        # Test with tuple
        tensors = [expected.cuda()]
        for device in range(1, torch.cuda.device_count()):
            tensors.append(torch.zeros(128, dtype=dtype, device=device))

        nccl.broadcast(tuple(tensors))
        for i in range(torch.cuda.device_count()):
            self.assertEqual(tensors[i], expected)

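    # Reduce (sum) onto the root: only tensors[0], which lives on device 0,
    # receives the summed result.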
    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_reduce(self, device, dtype):
        cpu_tensors = [
            torch.zeros(128).uniform_().to(dtype=dtype) for i in range(nGPUs)
        ]
        expected = torch.zeros(128, dtype=dtype)
        for t in cpu_tensors:
            expected.add_(t)

        tensors = [cpu_tensors[i].cuda(i) for i in range(nGPUs)]
        nccl.reduce(tensors)

        self.assertEqual(tensors[0], expected)

        # Test with tuple
        tensors = [cpu_tensors[i].cuda(i) for i in range(nGPUs)]
        nccl.reduce(tuple(tensors))

        self.assertEqual(tensors[0], expected)

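    # All-reduce (sum): every participating tensor receives the full sum, so
    # list, tuple, and set inputs are all exercised below.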
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_all_reduce(self, device, dtype):
        # `dtype` is not in scope when class decorators are evaluated (as a
        # decorator condition this short-circuited everywhere except ROCm <
        # 3.5, where it raised a NameError at import time), so the
        # bfloat16-on-old-ROCm skip has to happen inside the test body.
        if TEST_WITH_ROCM and HIP_VERSION < 3.5 and dtype == torch.bfloat16:
            self.skipTest("Skip bfloat16 test for ROCm < 3.5")

        cpu_tensors = [
            torch.zeros(128).uniform_().to(dtype=dtype) for i in range(nGPUs)
        ]
        expected = torch.zeros(128, dtype=dtype)
        for t in cpu_tensors:
            expected.add_(t)

        tensors = [cpu_tensors[i].cuda(i) for i in range(nGPUs)]
        nccl.all_reduce(tensors)

        for tensor in tensors:
            self.assertEqual(tensor, expected)

        # Test with tuple.
        tensors = tuple(cpu_tensors[i].cuda(i) for i in range(nGPUs))
        nccl.all_reduce(tensors)

        for tensor in tensors:
            self.assertEqual(tensor, expected)

        # Test with set.
        tensors = {cpu_tensors[i].cuda(i) for i in range(nGPUs)}
        nccl.all_reduce(tensors)

        for tensor in tensors:
            self.assertEqual(tensor, expected)

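    # Passing a bare tensor instead of a collection of tensors should raise a
    # TypeError for every collective.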
    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    def test_collective_errors(self, device):
        t = torch.rand(10).cuda(0)
        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.all_reduce(t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.reduce(t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.broadcast(t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.all_gather(t, t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.reduce_scatter(t, t)

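    # All-gather: each device contributes a 128-element tensor and every
    # device receives the concatenation of all contributions.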
    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_all_gather(self, device, dtype):
        cpu_inputs = [torch.zeros(128).uniform_().to(dtype=dtype) for i in range(nGPUs)]
        expected = torch.cat(cpu_inputs, 0)

        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [
            torch.zeros(128 * nGPUs, device=i, dtype=dtype) for i in range(nGPUs)
        ]
        nccl.all_gather(inputs, outputs)

        for tensor in outputs:
            self.assertEqual(tensor, expected)

        # Test with tuple.
        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [
            torch.zeros(128 * nGPUs, device=i, dtype=dtype) for i in range(nGPUs)
        ]
        nccl.all_gather(tuple(inputs), tuple(outputs))

        for tensor in outputs:
            self.assertEqual(tensor, expected)

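    # Reduce-scatter: inputs are summed elementwise and each device receives
    # its own contiguous 32-element slice of the result.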
    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_reduce_scatter(self, device, dtype):
        in_size = 32 * nGPUs
        out_size = 32

        cpu_inputs = [
            torch.zeros(in_size).uniform_().to(dtype=dtype) for i in range(nGPUs)
        ]
        expected = torch.zeros(in_size, dtype=dtype)
        for t in cpu_inputs:
            expected.add_(t)
        expected = expected.view(nGPUs, 32)

        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [torch.zeros(out_size, device=i, dtype=dtype) for i in range(nGPUs)]
        nccl.reduce_scatter(inputs, outputs)

        for i in range(nGPUs):
            self.assertEqual(outputs[i], expected[i])

        # Test with tuple
        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [torch.zeros(out_size, device=i, dtype=dtype) for i in range(nGPUs)]
        nccl.reduce_scatter(tuple(inputs), tuple(outputs))

        for i in range(nGPUs):
            self.assertEqual(outputs[i], expected[i])


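# Generate the device-specific test class (e.g. TestNCCLCUDA) for the CUDA
# backend only; the `device` argument each test receives comes from this
# machinery.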
instantiate_device_type_tests(TestNCCL, globals(), only_for="cuda")

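# Run this file directly to execute the suite, e.g.:
#     python test/distributed/test_nccl.py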
if __name__ == "__main__":
    run_tests()