# Owner(s): ["oncall: distributed"]
# This test file contains positive tests for c10d with the NCCL backend.
# During these tests, the ProcessGroup is expected not to be aborted, destroyed, or hit a fatal error.
# Please be mindful of this when adding tests here.
# If you need to add tests for group creation, abort, or destroy, please add them to test_c10d_nccl.py.

# There are two ways to launch the tests in this file:
# 1. Run this file directly with `python test_c10d_ops_nccl.py`
# 2. Use a multi-process launcher, e.g. `torchrun --standalone --nproc-per-node 2 test_c10d_ops_nccl.py`

import math
import os
import sys
import tempfile

import torch
import torch.distributed as c10d


if not c10d.is_available() or not c10d.is_nccl_available():
    print("c10d NCCL not available, skipping tests", file=sys.stderr)
    sys.exit(0)


import torch.distributed as dist
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
    init_multigpu_helper,
    MultiProcContinousTest,
    requires_nccl,
)
from torch.testing._internal.common_utils import (
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
    TEST_WITH_DEV_DBG_ASAN,
)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr
    )
    sys.exit(0)


class ProcessGroupNCCLOpTest(MultiProcContinousTest):
    @classmethod
    def backend_str(cls) -> str:
        return "nccl"

    @classmethod
    def opts(cls, high_priority_stream=False):
        opts = c10d.ProcessGroupNCCL.Options()
        opts.is_high_priority_stream = high_priority_stream
        return opts

    @property
    def rank_to_GPU(self):
        # Return a dict mapping each rank to the list of GPU indices it may use.
        return init_multigpu_helper(self.world_size, "nccl")

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_empty_tensors(self):
        pg = self.pg
        local_device_idx = self.rank_to_GPU[self.rank][0]

        xs = [torch.FloatTensor([]).cuda(local_device_idx)]
        pg.broadcast(xs).wait()
        self.assertEqual(0, xs[0].numel())

        pg.allreduce(xs).wait()
        self.assertEqual(0, xs[0].numel())

        pg.reduce(xs).wait()
        self.assertEqual(0, xs[0].numel())

        ys = [
            [
                torch.FloatTensor([]).cuda(local_device_idx)
                for _ in range(self.world_size)
            ]
        ]
        pg.allgather(ys, xs).wait()
        for y in ys[0]:
            self.assertEqual(0, y.numel())

        ys = [torch.FloatTensor([]).cuda(local_device_idx)]
        xs = [
            [
                torch.FloatTensor([]).cuda(local_device_idx)
                for _ in range(self.world_size)
            ]
        ]
        pg.reduce_scatter(ys, xs).wait()
        self.assertEqual(0, ys[0].numel())

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_broadcast_ops(self):
        pg = self.pg

        def broadcast(xs, rootRank, rootTensor):
            opts = c10d.BroadcastOptions()
            opts.rootRank = rootRank
            opts.rootTensor = rootTensor
            work = pg.broadcast(xs, opts)
            work.wait()
            return xs

        # Every rank is root once
        for i in range(self.world_size):
            # Run with 1 input tensor
            x = torch.tensor([self.rank]).cuda(self.rank_to_GPU[self.rank][0])
            output = broadcast([x], i, 0)
            self.assertEqual(torch.tensor([i]), output[0])

            expected_tensor = torch.empty([i + 1, i + 1]).fill_(i + 1)
            xs = [
                torch.empty([i + 1, i + 1]).fill_(-1).cuda(device=device_idx)
                for device_idx in self.rank_to_GPU[self.rank]
            ]

            # Test with multiple input tensors (multiple GPUs in one rank)
            for j in range(len(xs)):
                if self.rank == i:
                    xs[j] = expected_tensor.cuda(device=self.rank_to_GPU[self.rank][j])

                broadcast(xs, i, j)

                for tensor in xs:
                    self.assertEqual(tensor, expected_tensor)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_sparse_allreduce_ops(self):
        pg = self.pg

        indices = torch.tensor([[0, 1]])
        values = torch.tensor([[1, 2, 0], [4, 0, 6]])
        sparse_tensor = torch.sparse_coo_tensor(indices, values, size=(2, 3)).to(
            self.rank
        )

        # The sparse allreduce call is wrapped in try/except since the c10d API
        # is only available in the NCCL experimental branch.
        try:
            tensor_list = [sparse_tensor]
            work = pg.allreduce(tensor_list)
            work.wait()

            # tensor_list is a list of size 1, with the allreduce output as a dense tensor
            a = torch.tensor([[2, 4, 0], [8, 0, 12]]).to(self.rank)
            self.assertEqual(tensor_list[0], a)
        except RuntimeError as e:
            if "NCCL does not support all_reduce with sparse tensors" in str(e):
                pass
            else:
                # Rethrow the exception if it's a different error
                raise

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_allreduce_ops(self):
        device_count = torch.cuda.device_count()
        pg = self.pg
        local_device_id = self.rank_to_GPU[self.rank][0]

        def allreduce(tensors, op):
            opts = c10d.AllreduceOptions()
            opts.reduceOp = op
            work = pg.allreduce(tensors, opts)
            work.wait()

        # Sum
        tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id)]

        allreduce(tensors, c10d.ReduceOp.SUM)
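        # Each rank contributed rank + 1, so the reduced value is 1 + 2 + ... + world_size.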

        ndev = self.world_size
        self.assertEqual(
            torch.tensor([ndev * (ndev + 1) // 2]),
            tensors[0],
        )

        # Avg (only available for NCCL 2.10+)
        if torch.cuda.nccl.version() >= (2, 10, 0):
            tensors = [torch.tensor([self.rank + 1.0]).cuda(local_device_id)]

            allreduce(tensors, c10d.ReduceOp.AVG)
            ndev = self.world_size
            self.assertEqual(
                torch.tensor([ndev * (ndev + 1.0) / (2.0 * ndev)]),
                tensors[0],
            )

        # Premul Sum
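        # PreMulSum pre-multiplies each rank's contribution by `factor` before summing,
        # so with the same factor on every rank the result equals factor * (plain sum).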
        if torch.cuda.nccl.version() >= (2, 11, 1):
            for dtype in torch.half, torch.float, torch.double:
                for factor in (
                    3.0,
                    torch.tensor([5.0], device=local_device_id, dtype=dtype),
                ):
                    tensors = [
                        torch.tensor([self.rank + 1])
                        .cuda(local_device_id)
                        .to(dtype=dtype)
                    ]

                    allreduce(tensors, c10d._make_nccl_premul_sum(factor))

                    self.assertEqual(
                        factor
                        * torch.tensor(
                            [self.world_size * (self.world_size + 1) / 2],
                            dtype=dtype,
                            device=local_device_id,
                        ),
                        tensors[0],
                    )

        # Product
        tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id)]

        allreduce(tensors, c10d.ReduceOp.PRODUCT)
        self.assertEqual(torch.tensor([math.factorial(self.world_size)]), tensors[0])

        # Min
        tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id)]

        allreduce(tensors, c10d.ReduceOp.MIN)
        self.assertEqual(torch.tensor([1]), tensors[0])

        # Max
        tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id)]

        allreduce(tensors, c10d.ReduceOp.MAX)
        self.assertEqual(torch.tensor([self.world_size]), tensors[0])

        for op, err in zip(
            (c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR),
            ("ReduceOp.BAND", "ReduceOp.BOR", "ReduceOp.BXOR"),
        ):
            with self.assertRaisesRegex(ValueError, "Cannot use " + err + " with NCCL"):
                allreduce(tensors, op)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_alltoall_ops_with_cudafree_race(self):
        pg = self.pg
        opts = c10d.AllToAllOptions()
        local_device = f"cuda:{self.rank_to_GPU[self.rank][0]}"
        torch.cuda.set_device(local_device)
        input = torch.rand(1000, 1000, device=local_device)
        output = torch.rand(1000, 1000, device=local_device)
        race_tensors = []
        # create some tensors to race with alltoall collective
        for _ in range(10):
            tmp = []
            for i in range(5):
                tmp.append(torch.rand(10 ** (3 + i), device=local_device))
            race_tensors.append(tmp)

        for i in range(10):
            race_tensors.pop()
            work = pg.alltoall_base(output, input, [], [], opts)
            # this triggers cudaFree
            torch.cuda.empty_cache()
            work.wait()
        torch.cuda.synchronize(device=local_device)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_allreduce_in_cudagraph(self):
        pg = self.pg
        local_device_idx = self.rank_to_GPU[self.rank][0]
        with torch.cuda.device(local_device_idx):
            xs = [torch.FloatTensor([1]).cuda(local_device_idx)]

            # single warmup
            pg.allreduce(xs).wait()
            self.assertEqual(xs[0].item(), 2)

            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                pg.allreduce(xs).wait()
            self.assertEqual(xs[0].item(), 2)

            graph.replay()
            graph.replay()
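            # With the default world_size of 2, the warmup sums 1 + 1 = 2; capture leaves
            # the value untouched, and each replay doubles it again: 2 -> 4 -> 8.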
            self.assertEqual(xs[0].item(), 8)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @skipIfRocm()
    def test_nccl_watchdog_cudagraph(self):
        # test that the watchdog does not crash graphs with disallowed event query
        pg = self.pg
        rank = self.rank_to_GPU[self.rank][0]
        with torch.cuda.device(rank):
            for i in range(10):
                xs = [torch.FloatTensor([1]).cuda(rank)]
                ys = [torch.FloatTensor([4]).cuda(rank)]
                for _ in range(30):
                    pg.allreduce(xs[0]).wait()

                graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(graph):
                    xs[0] += 0.0
                    pg.allreduce(xs[0]).wait()
                    pg.allreduce(xs[0]).wait()
                    pg.allreduce(xs[0]).wait()
                    xs[0] += 0.0

                for _ in range(100):
                    graph.replay()

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_reduce_ops(self):
        pg = self.pg
        local_device_id = self.rank_to_GPU[self.rank][0]

        def reduce(xs, rootRank, rootTensor, op=None):
            opts = c10d.ReduceOptions()
            opts.rootRank = rootRank
            opts.rootTensor = rootTensor
            if op:
                opts.reduceOp = op
            work = pg.reduce(xs, opts)
            work.wait()

        # Every rank takes a turn as the root rank
        for rt in range(self.world_size):
            tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id)]

            reduce(tensors, rt, 0)

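            # Only the root receives the reduced sum 1 + 2 + ... + world_size;
            # non-root ranks keep their original input.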
            if self.rank == rt:
                self.assertEqual(
                    torch.tensor([self.world_size * (self.world_size + 1) // 2]),
                    tensors[0],
                )
            else:
                self.assertEqual(
                    torch.tensor([self.rank + 1]),
                    tensors[0],
                )

            for op, err in zip(
                (c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR),
                ("ReduceOp.BAND", "ReduceOp.BOR", "ReduceOp.BXOR"),
            ):
                with self.assertRaisesRegex(
                    ValueError, "Cannot use " + err + " with NCCL"
                ):
                    reduce(tensors, self.rank, rt, op)

            # Premul sum
            if torch.cuda.nccl.version() >= (2, 11, 1):
                for factor in (3.0, torch.tensor([5.0], device=local_device_id)):
                    if isinstance(factor, torch.Tensor):
                        factor_ref = factor.cpu().item()
                    else:
                        factor_ref = factor
                    float_tensors = [
                        torch.tensor(
                            [self.rank + 1.0], device=f"cuda:{local_device_id}"
                        )
                    ]
                    float_tensors_ref = [
                        torch.tensor(
                            [(self.rank + 1.0) * factor_ref],
                            device=f"cuda:{local_device_id}",
                        )
                    ]

                    reduce(float_tensors_ref, rt, 0)
                    reduce(float_tensors, rt, 0, c10d._make_nccl_premul_sum(factor))
                    if self.rank == rt:
                        self.assertEqual(float_tensors_ref[0], float_tensors[0])

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_allgather_ops(self):
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]

        def allgather(output_ts, input_ts):
            work = pg.allgather(output_ts, input_ts)
            return work.wait()

        tensors = [torch.empty(2, 2).fill_(2).cuda(device=i) for i in local_device_ids]
        output_tensors = []
        expected_output = []

        output_per_gpu = (
            [torch.empty(2, 2).fill_(-1)] * len(local_device_ids) * self.world_size
        )
        expected_per_gpu = (
            [torch.empty(2, 2).fill_(2)] * len(local_device_ids) * self.world_size
        )

        for gpu in local_device_ids:
            output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu])
            expected_output.append([t.cuda(device=gpu) for t in expected_per_gpu])

        result = allgather(output_tensors, tensors)

        # Verification
        self.assertEqual(output_tensors, expected_output)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_allgather_base_ops(self):
        pg = self.pg
        local_device_id = self.rank_to_GPU[self.rank][0]

        def allgather_base(output_t, input_t):
            work = pg._allgather_base(output_t, input_t)
            work.wait()

        # allgather_base is agnostic to the number of GPUs.
        # Each rank contributes one tensor regardless of GPU count.
        tensor = torch.tensor([self.rank]).cuda(local_device_id)
        output_t = torch.empty((self.world_size), dtype=tensor.dtype).cuda(
            local_device_id
        )

        allgather_base(output_t, tensor)

        # Verification
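        # Rank r contributes tensor([r]), so the gathered result equals arange(world_size).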
        self.assertEqual(torch.arange(self.world_size), output_t)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_allgather_base_basics(self):
        pg = self.pg
        local_device_id = self.rank_to_GPU[self.rank][0]

        def allgather_base(output_t, input_t):
            work = pg._allgather_base(output_t, input_t)
            work.wait()

        # anticipate an error
        with self.assertRaisesRegex(
            ValueError,
            "output tensor size must be equal to world_size times input tensor size",
        ):
            tensor = torch.tensor([self.rank]).cuda(local_device_id)
            output_t = torch.empty((self.world_size + 1), dtype=tensor.dtype).cuda(
                local_device_id
            )
            # fails the check because output_t is not correctly sized
            allgather_base(output_t, tensor)

        # anticipate an error
        with self.assertRaisesRegex(
            TypeError, "output tensor must have the same type as input tensor"
        ):
            tensor = torch.tensor([self.rank], dtype=torch.float).cuda(local_device_id)
            output_t = torch.empty((self.world_size + 1), dtype=torch.long).cuda(
                local_device_id
            )
            # fails the check because the dtype is different
            allgather_base(output_t, tensor)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_gather_ops(self):
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]
        num_gpus = len(local_device_ids)

        def gather(output_t, input_t, rootRank):
            opts = c10d.GatherOptions()
            opts.rootRank = rootRank
            if rootRank == self.rank:
                work = pg.gather(output_t, input_t, opts)
            else:
                work = pg.gather([], input_t, opts)
            work.wait()

        # init input
        tensors = []
        for device_id in local_device_ids:
            tensors.append(torch.tensor([self.rank]).cuda(device_id))

        # init output
        output_ts = []
        for idx in range(num_gpus):
            gpu_idx = local_device_ids[idx]
            output_ts.append([])
            for rank in range(self.world_size):
                output_ts[idx].append(torch.tensor([-1]).cuda(gpu_idx))

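        # With one GPU per rank, the root should gather [0, 1, ..., world_size - 1].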
        expected = [[torch.tensor([rank]) for rank in range(self.world_size)]]
        for rank in range(self.world_size):
            gather(output_ts, tensors, rank)
            if rank == self.rank:
                self.assertEqual(expected, output_ts)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_gather_stress(self):
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]
        num_gpus = len(local_device_ids)

        def gather(output_t, input_t, rootRank):
            opts = c10d.GatherOptions()
            opts.rootRank = rootRank
            if rootRank == self.rank:
                work = pg.gather(output_t, input_t, opts)
            else:
                work = pg.gather([], input_t, opts)
            work.wait()

        stress_length = 1000

        # init input
        tensors = []
        for i in range(stress_length):
            tensors.append([])
            for device_id in local_device_ids:
                tensors[i].append(torch.tensor([self.rank]).cuda(device_id))

        # init output
        output_ts = []
        for i in range(stress_length):
            output_ts.append([[] for _ in range(num_gpus)])
            for idx, ls in enumerate(output_ts[i]):
                gpu_idx = local_device_ids[idx]
                for _ in range(self.world_size):
                    ls.append(torch.tensor([-1]).cuda(gpu_idx))

        expected = [[torch.tensor([rank]) for rank in range(self.world_size)]]
        for i in range(stress_length):
            for rank in range(self.world_size):
                gather(output_ts[i], tensors[i], rank)
                # Verification
                if rank == self.rank:
                    self.assertEqual(output_ts[i], expected)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_gather_checks(self):
        pg = self.pg
        device_id = self.rank_to_GPU[self.rank][0]

        # init input
        tensor = torch.tensor([self.rank]).cuda(device_id)

        # init output
        output_ts = []
        for rank in range(self.world_size):
            output_ts.append(torch.tensor([-1]).cuda(device_id))

        with self.assertRaisesRegex(ValueError, "invalid root rank"):
            opts = c10d.GatherOptions()
            opts.rootRank = -1
            pg.gather([output_ts], [tensor], opts)

        with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
            pg.gather([output_ts], [tensor], 0)

        with self.assertRaisesRegex(ValueError, "invalid root rank"):
            opts = c10d.GatherOptions()
            opts.rootRank = self.world_size
            pg.gather([output_ts], [tensor], opts)

        with self.assertRaisesRegex(
            # throws error message from dispatcher
            RuntimeError,
            "There were no tensor arguments to this function",
        ):
            opts = c10d.GatherOptions()
            opts.rootRank = 0
            pg.gather([output_ts], [], opts)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_scatter_ops(self):
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]
        num_gpus = len(local_device_ids)

        def scatter(output_t, input_t, rootRank):
            opts = c10d.ScatterOptions()
            opts.rootRank = rootRank
            if rootRank == self.rank:
                work = pg.scatter(output_t, input_t, opts)
            else:
                work = pg.scatter(output_t, [], opts)
            work.wait()

        # init output
        tensors = []
        for device_id in local_device_ids:
            tensors.append(torch.tensor([-1]).cuda(device_id))

        # init input
        scatter_list = []
        for idx in range(num_gpus):
            gpu_idx = local_device_ids[idx]
            scatter_list.append([])
            for rank in range(self.world_size):
                scatter_list[idx].append(torch.tensor([rank]).cuda(gpu_idx))

        # Every rank takes a turn as root; each rank should receive its own rank value
        expected = [torch.tensor([self.rank])]
        for rank in range(self.world_size):
            scatter(tensors, scatter_list, rank)
            self.assertEqual(expected, tensors)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_scatter_stress(self):
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]
        num_gpus = len(local_device_ids)

        def scatter(output_t, input_t, rootRank):
            opts = c10d.ScatterOptions()
            opts.rootRank = rootRank
            if rootRank == self.rank:
                work = pg.scatter(output_t, input_t, opts)
            else:
                work = pg.scatter(output_t, [], opts)
            work.wait()

        stress_length = 1000

        # init output
        tensors = []
        for i in range(stress_length):
            tensors.append([])
            for device_id in local_device_ids:
                tensors[i].append(torch.tensor([-1]).cuda(device_id))

        # init input
        scatter_list = []
        for i in range(stress_length):
            scatter_list.append([[] for _ in range(num_gpus)])
            for idx, ls in enumerate(scatter_list[i]):
                gpu_idx = local_device_ids[idx]
                for rank in range(self.world_size):
                    ls.append(torch.tensor([rank]).cuda(gpu_idx))

        # Every rank takes a turn as root; each rank should receive its own rank value
        expected = [torch.tensor([self.rank])]
        for i in range(stress_length):
            for rank in range(self.world_size):
                scatter(tensors[i], scatter_list[i], rank)
                # Verification
                self.assertEqual(tensors[i], expected)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_scatter_checks(self):
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]
        num_gpus = len(local_device_ids)

        # init output
        tensors = []
        for device_id in local_device_ids:
            tensors.append(torch.tensor([-1]).cuda(device_id))

        # init input
        scatter_list = []
        for idx in range(num_gpus):
            gpu_idx = local_device_ids[idx]
            scatter_list.append([])
            for rank in range(self.world_size):
                scatter_list[idx].append(torch.tensor([rank]).cuda(gpu_idx))

        with self.assertRaisesRegex(ValueError, "invalid root rank"):
            opts = c10d.ScatterOptions()
            opts.rootRank = -1
            pg.scatter(tensors, scatter_list, opts)

        with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
            pg.scatter(tensors, scatter_list, 0)

        with self.assertRaisesRegex(ValueError, "invalid root rank"):
            opts = c10d.ScatterOptions()
            opts.rootRank = self.world_size
            pg.scatter(tensors, scatter_list, opts)

        with self.assertRaisesRegex(
            # throws error message from dispatcher
            RuntimeError,
            "There were no tensor arguments to this function",
        ):
            opts = c10d.ScatterOptions()
            opts.rootRank = 0
            pg.scatter([], scatter_list, opts)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_reduce_scatter_base_basics(self):
        pg = self.pg
        local_device_id = self.rank_to_GPU[self.rank][0]

        def reduce_scatter_base(output_t, input_t):
            work = pg._reduce_scatter_base(output_t, input_t)
            work.wait()

        # anticipate an error
        with self.assertRaisesRegex(
            ValueError,
            "input tensor must be the same size as output size times world size",
        ):
            input_t = torch.tensor([self.rank]).cuda(local_device_id)
            output_t = torch.empty((self.world_size + 1), dtype=input_t.dtype).cuda(
                local_device_id
            )
            # fails the check because output_t is not correctly sized
            reduce_scatter_base(output_t, input_t)

        # anticipate an error
        with self.assertRaisesRegex(
            TypeError, "input tensor must be the same type as the output tensor."
        ):
            tensor = torch.tensor([self.rank], dtype=torch.float).cuda(local_device_id)
            output_t = torch.empty((self.world_size + 1), dtype=torch.long).cuda(
                local_device_id
            )
            # fails the check because the dtype is different
            reduce_scatter_base(output_t, tensor)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_reduce_scatter_ops(self):
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]
        num_gpus = len(local_device_ids)

        def reduce_scatter(outputs, input_lists, op):
            opts = c10d.ReduceScatterOptions()
            opts.reduceOp = op
            work = pg.reduce_scatter(outputs, input_lists, opts)
            work.wait()

        output = [torch.tensor([0]).cuda(i) for i in local_device_ids]

        #  GPU/rank
        #   0         [1], [2], [3], [4]
        #   1         [2], [3], [4], [5]
        #   2         [3], [4], [5], [6]
        #   3         [4], [5], [6], [7]

        # Sum
        tensor_lists = []
        input_per_gpu = []

        for i in range(self.world_size):
            input_per_gpu.append(torch.tensor([self.rank + i + 1]))

        for gpu in local_device_ids:
            tensor_lists.append([t.cuda(device=gpu) for t in input_per_gpu])

        reduce_scatter(output, tensor_lists, c10d.ReduceOp.SUM)
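        # Rank k receives sum over ranks r of (r + k + 1), i.e.
        # world_size * (world_size + 1) // 2 + world_size * k.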

        for i in range(num_gpus):
            expected = torch.tensor(
                [
                    (1 + self.world_size) * self.world_size // 2
                    + self.world_size * self.rank
                ]
            )

            self.assertEqual(expected, output[i])

        # Min
        reduce_scatter(output, tensor_lists, c10d.ReduceOp.MIN)

        for i in range(num_gpus):
            expected = torch.tensor([self.rank + 1 + i])
            self.assertEqual(expected, output[i])

        # Max
        reduce_scatter(output, tensor_lists, c10d.ReduceOp.MAX)

        for i in range(num_gpus):
            expected = torch.tensor([self.rank + self.world_size + i])
            self.assertEqual(expected, output[i])

        # Product
        reduce_scatter(output, tensor_lists, c10d.ReduceOp.PRODUCT)

        # The math package doesn't have math.perm until Python 3.8, so
        # we implement a naive version here.
        def perm(n, k):
            prod_val = n
            for val in range(n - k + 1, n):
                prod_val *= val
            return prod_val

        for i in range(num_gpus):
            prod_val = perm(self.rank + self.world_size, self.world_size)

            expected = torch.tensor([prod_val])
            self.assertEqual(expected, output[i])

        # Test the overloaded signature where the input is a list of tensors
        # and the output is a single tensor.
        # Sum
        output_tensor = torch.empty_like(input_per_gpu[0][0]).cuda(self.rank)
        input_list = [tensor[0].cuda(self.rank) for tensor in input_per_gpu]
        pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.SUM).wait()
        expected = torch.tensor(
            (1 + self.world_size) * self.world_size // 2 + self.world_size * self.rank
        )
        self.assertEqual(expected, output_tensor)

        # Min
        pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MIN).wait()
        expected = torch.tensor(self.rank + 1)
        self.assertEqual(expected, output_tensor)

        # Max
        pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MAX).wait()
        expected = torch.tensor(self.rank + self.world_size)
        self.assertEqual(expected, output_tensor)

        # Product
        pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.PRODUCT).wait()
        prod_val = self.rank + 1
        for k in range(1, self.world_size):
            prod_val = prod_val * (self.rank + 1 + k)
        expected = torch.tensor(prod_val)
        self.assertEqual(expected, output_tensor)

        if torch.cuda.nccl.version() >= (2, 11, 1):
            for factor in (3.0, torch.tensor([5.0], device=self.rank)):
                if isinstance(factor, torch.Tensor):
                    factor_ref = factor.cpu().item()
                else:
                    factor_ref = factor
                output = [t.float() for t in output]
                tensor_lists = [[t.float() for t in tl] for tl in tensor_lists]
                output_ref = [t.float() for t in output]
                tensor_lists_ref = [
                    [t.float() * factor_ref for t in tl] for tl in tensor_lists
                ]
                reduce_scatter(output, tensor_lists, c10d._make_nccl_premul_sum(factor))
                reduce_scatter(output_ref, tensor_lists_ref, c10d.ReduceOp.SUM)
                self.assertEqual(output_ref, output)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_reduce_scatter_base_ops(self):
        pg = self.pg
        local_device_id = self.rank_to_GPU[self.rank][0]

        def reduce_scatter_base(output_t, input_t):
            work = pg._reduce_scatter_base(output_t, input_t)
            work.wait()

        # reduce_scatter_base is agnostic to the number of GPUs.
        # Each rank contributes one tensor regardless of GPU count.
        output_t = torch.empty([1]).cuda(local_device_id)
        tensor = torch.arange(self.world_size, dtype=output_t.dtype).cuda(
            local_device_id
        )

        reduce_scatter_base(output_t, tensor)

        # Verification
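        # Every rank feeds arange(world_size), so element k sums to k * world_size
        # and is scattered to rank k.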
        self.assertEqual(output_t[0], self.rank * self.world_size)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_barrier(self):
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]

        def allreduce(tensors):
            opts = c10d.AllreduceOptions()
            work = pg.allreduce(tensors, opts)
            return work

        # Make the collectives operate on 1, 2, 3, ..., len(local_device_ids) GPUs
        tensors_list = [[] for _ in range(len(local_device_ids))]

        for i in range(1, len(local_device_ids) + 1):
            for j in range(i):
                tensors_list[i - 1].append(
                    torch.tensor([j + 1]).cuda(local_device_ids[j])
                )

        works = []
        for tensors in tensors_list:
            work = allreduce(tensors)
            works.append(work)

        # Barrier will ensure that all previous work is completed
        pg.barrier().wait()

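        # Each entry started as j + 1 on every rank, so after the allreduce it holds
        # (j + 1) * world_size.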
        for i in range(1, len(local_device_ids) + 1):
            for j in range(i):
                self.assertEqual(
                    torch.tensor([(j + 1) * self.world_size]), tensors_list[i - 1][j]
                )

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_send_recv(self):
        pg = self.pg
        device = self.rank_to_GPU[self.rank][0]

        # Generate the same random tensor
        torch.manual_seed(0)
        send_tensor = torch.rand(10, 10, device=device)
        if self.rank == 0:
            dist.send(send_tensor, 1)
        if self.rank == 1:
            recv_tensor = torch.rand(10, 10, device=device)
            dist.recv(recv_tensor, 0)
            self.assertEqual(send_tensor, recv_tensor)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_send_recv_complex(self):
        pg = self.pg
        device = self.rank_to_GPU[self.rank][0]

        # Generate the same random tensor
        torch.manual_seed(0)
        send_tensor = torch.rand(10, 10, dtype=torch.cfloat, device=device)
        if self.rank == 0:
            dist.send(send_tensor, 1)
        if self.rank == 1:
            recv_tensor = torch.rand(10, 10, dtype=torch.cfloat, device=device)
            dist.recv(recv_tensor, 0)
            self.assertEqual(send_tensor, recv_tensor)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_send_recv_object_list(self):
        device = self.rank_to_GPU[self.rank][0]

        val = 99 if self.rank == 0 else None
        object_list = [val] * self.world_size
        if self.rank == 0:
            dist.send_object_list(object_list, 1, device=device)
        if self.rank == 1:
            dist.recv_object_list(object_list, 0, device=device)
            self.assertEqual(object_list[0], 99)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_tensor_register_hook(self):
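        # TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK makes ProcessGroupNCCL register
        # CUDA caching-allocator buffers with NCCL; the collective result should be
        # identical with the hook enabled.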
        os.environ["TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"] = "1"

        pg = self.pg
        local_device_id = self.rank_to_GPU[self.rank][0]

        def allgather_base(output_t, input_t):
            work = pg._allgather_base(output_t, input_t)
            work.wait()

        # allgather_base is agnostic to the number of GPUs.
        # Each rank contributes one tensor regardless of GPU count.
        tensor = torch.tensor([self.rank]).cuda(local_device_id)
        output_t = torch.empty((self.world_size), dtype=tensor.dtype).cuda(
            local_device_id
        )

        allgather_base(output_t, tensor)

        # Verification
        self.assertEqual(torch.arange(self.world_size), output_t)

        # Unset env
        del os.environ["TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"]


if __name__ == "__main__":
    rank = int(os.getenv("RANK", -1))
    world_size = int(os.getenv("WORLD_SIZE", 2))

    if rank != -1:
        # Launched by torchrun or another multi-process launcher; run the tests directly.
        ProcessGroupNCCLOpTest.run_rank(rank, world_size)
    else:
        # Launched as a single process; spawn subprocesses to run the tests.
        # A rendezvous file is also needed for `init_process_group`.
        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
        torch.multiprocessing.spawn(
            ProcessGroupNCCLOpTest.run_rank,
            nprocs=world_size,
            args=(world_size, rdvz_file),
        )
997