xref: /aosp_15_r20/external/pytorch/test/distributed/test_c10d_gloo.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: distributed"]
2
3import copy
4import logging
5import math
6import operator
7import os
8import random
9import sys
10import tempfile
11from datetime import timedelta
12from functools import reduce
13from itertools import groupby
14
15import torch
16import torch.distributed as c10d
17
18
19if not c10d.is_available() or not c10d.is_gloo_available():
20    print("c10d GLOO not available, skipping tests", file=sys.stderr)
21    sys.exit(0)
22
23import test_c10d_common
24from test_c10d_common import (
25    gpus_for_rank,
26    LOOPBACK,
27    ModuleForDdpCommHook,
28    SparseGradientModule,
29    Task,
30)
31
32import torch.distributed as dist
33import torch.nn.functional as F
34import torch.testing._internal.common_utils as common
35from torch import nn
36from torch.distributed._shard.sharded_tensor import (
37    init_from_local_shards,
38    Shard,
39    ShardedTensor,
40    ShardMetadata,
41)
42from torch.nn.parallel import DistributedDataParallel
43from torch.testing._internal.common_distributed import (
44    create_device,
45    MultiProcessTestCase,
46    requires_gloo,
47    simple_sparse_reduce_tests,
48    skip_if_lt_x_gpu,
49    verify_ddp_error_logged,
50)
51from torch.testing._internal.common_utils import (
52    retry_on_connect_failures,
53    run_tests,
54    skip_but_pass_in_sandcastle,
55    TestCase,
56)
57
58
59def simple_reduce_tests(rank, world_size):
60    tests = [
61        (
62            c10d.ReduceOp.SUM,
63            torch.tensor([rank + 1.0]),
64            torch.tensor([float(world_size * (world_size + 1) / 2)]),
65        ),
66        (
67            c10d.ReduceOp.PRODUCT,
68            torch.tensor([rank + 1.0]),
69            torch.tensor([float(math.factorial(world_size))]),
70        ),
71        (
72            c10d.ReduceOp.MIN,
73            torch.tensor([rank + 1.0]),
74            torch.tensor([1.0]),
75        ),
76        (
77            c10d.ReduceOp.MAX,
78            torch.tensor([rank + 1.0]),
79            torch.tensor([float(world_size)]),
80        ),
81    ]
82
83    # Generate tests for BAND.
84    # The bit that is set changes in every iteration to check
85    # that the output changes accordingly.
86    for i in range(4):
87        vin = rank | (1 << i)
88        vout = 1 << i
89        tests.append(
90            (
91                c10d.ReduceOp.BAND,
92                torch.tensor([vin], dtype=torch.int32),
93                torch.tensor([vout], dtype=torch.int32),
94            ),
95        )
96
97    # Generate tests for BOR.
98    # These emulate a larger world size per iteration by having every
99    # rank contribute multiple values that are pre-OR'ed.
100    for i in range(1, 5):
101        vin = reduce(operator.or_, [rank * i + j for j in range(i)])
102        vout = reduce(operator.or_, range(world_size * i))
103        tests.append(
104            (
105                c10d.ReduceOp.BOR,
106                torch.tensor([vin], dtype=torch.int32),
107                torch.tensor([vout], dtype=torch.int32),
108            ),
109        )
110
111    # Generate tests for XOR.
112    # These emulate a larger world size per iteration by having every
113    # rank contribute multiple values that are pre-XOR'ed.
114    for i in range(1, 5):
115        vin = reduce(operator.xor, [rank * i + j for j in range(i)])
116        vout = reduce(operator.xor, range(world_size * i))
117        tests.append(
118            (
119                c10d.ReduceOp.BXOR,
120                torch.tensor([vin], dtype=torch.int32),
121                torch.tensor([vout], dtype=torch.int32),
122            ),
123        )
124
125    return tests
126
127
128def simple_coalesced_reduce_tests(rank, world_size):
129    return [
130        (
131            c10d.ReduceOp.SUM,
132            [torch.tensor([rank + 1.0]), torch.tensor([(rank + 1.0) ** 2])],
133            [
134                torch.tensor([float(world_size * (world_size + 1) / 2)]),
135                torch.tensor(
136                    [float(world_size * (world_size + 1) * (2 * world_size + 1) / 6)]
137                ),
138            ],
139        ),
140        (
141            c10d.ReduceOp.PRODUCT,
142            [torch.tensor([rank + 1.0]), torch.tensor([rank + 2.0])],
143            [
144                torch.tensor([float(math.factorial(world_size))]),
145                torch.tensor([float(math.factorial(world_size + 1))]),
146            ],
147        ),
148        (
149            c10d.ReduceOp.MIN,
150            [torch.tensor([rank + x]) for x in [0.0, 1.0]],
151            [torch.tensor([0.0]), torch.tensor([1.0])],
152        ),
153        (
154            c10d.ReduceOp.MAX,
155            [torch.tensor([rank + x]) for x in [1.0, 2.0]],
156            [torch.tensor([float(world_size)]), torch.tensor([world_size + 1.0])],
157        ),
158    ]
159
160
161def simple_multi_input_reduce_tests(rank, world_size):
162    return [
163        (
164            c10d.ReduceOp.SUM,
165            [torch.tensor([2 * rank + 0.0]), torch.tensor([2 * rank + 1.0])],
166            torch.tensor([float(world_size * (2 * world_size - 1))]),
167        ),
168        (
169            c10d.ReduceOp.PRODUCT,
170            [torch.tensor([2 * rank + 1.0]), torch.tensor([2 * rank + 2.0])],
171            torch.tensor([float(math.factorial(2 * world_size))]),
172        ),
173        (
174            c10d.ReduceOp.MIN,
175            [torch.tensor([2 * rank + 1.0]), torch.tensor([2 * rank + 2.0])],
176            torch.tensor([1.0]),
177        ),
178        (
179            c10d.ReduceOp.MAX,
180            [torch.tensor([2 * rank + 1.0]), torch.tensor([2 * rank + 2.0])],
181            torch.tensor([2.0 * world_size]),
182        ),
183    ]
184
185
186class RendezvousEnvTest(TestCase):
187    @requires_gloo()
188    @retry_on_connect_failures
189    def test_logging_init(self):
190        os.environ["WORLD_SIZE"] = "1"
191        os.environ["MASTER_ADDR"] = "127.0.0.1"
192        os.environ["MASTER_PORT"] = str(common.find_free_port())
193        os.environ["RANK"] = "0"
194
195        previous_handlers = logging.root.handlers
196
197        c10d.init_process_group(backend="gloo", init_method="env://")
198
199        current_handlers = logging.root.handlers
200        self.assertEqual(len(previous_handlers), len(current_handlers))
201        for current, previous in zip(current_handlers, previous_handlers):
202            self.assertEqual(current, previous)
203
204        c10d.destroy_process_group()
205
206
207class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase):
208    @requires_gloo()
209    @retry_on_connect_failures
210    def test_default_store_timeout_gloo(self):
211        self._test_default_store_timeout("gloo")
212
213
214class ProcessGroupGlooTest(MultiProcessTestCase):
215    def _create_process_group_gloo(self, store, rank, world_size, opts):
216        pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, opts)
217        dist.barrier(group=pg)
218        return pg
219
220    def setUp(self):
221        super().setUp()
222        self._spawn_processes()
223
224    def opts(self, threads=2):
225        opts = c10d.ProcessGroupGloo._Options()
226        opts._timeout = 50.0
227        opts._devices = [create_device(interface=LOOPBACK)]
228        opts._threads = threads
229        return opts
230
231    @requires_gloo()
232    def test_multi_device_constructor(self):
233        store = c10d.FileStore(self.file_name, self.world_size)
234        opts = c10d.ProcessGroupGloo._Options()
235        opts._timeout = 5.0
236        opts._devices = [
237            create_device(interface=LOOPBACK),
238            create_device(interface=LOOPBACK),
239        ]
240        pg = self._create_process_group_gloo(store, self.rank, self.world_size, opts)
241
242        # Execute 2x the number of operations to ensure we use every device.
243        for fut in [pg.allreduce(torch.ones(i + 1)).get_future() for i in range(4)]:
244            fut.wait()
245
246    @requires_gloo()
247    def test_empty_tensors(self):
248        store = c10d.FileStore(self.file_name, self.world_size)
249        pg = self._create_process_group_gloo(
250            store, self.rank, self.world_size, self.opts()
251        )
252
253        xs = [torch.FloatTensor([])]
254        fut = pg.broadcast(xs).get_future()
255        fut.wait()
256        output = fut.value()
257        self.assertEqual(0, output[0].numel())
258        self.assertEqual(xs[0], output[0])
259
260    @requires_gloo()
261    def test_broadcast_checks(self):
262        store = c10d.FileStore(self.file_name, self.world_size)
263        pg = self._create_process_group_gloo(
264            store, self.rank, self.world_size, self.opts()
265        )
266
267        t1 = torch.zeros([1], dtype=torch.float32)
268        t2 = torch.zeros([1], dtype=torch.float64)
269        t3 = torch.zeros([2], dtype=torch.float32)
270
271        with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
272            opts = c10d.BroadcastOptions()
273            opts.rootRank = -1
274            opts.rootTensor = 0
275            pg.broadcast([t1], opts)
276
277        with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
278            opts = c10d.BroadcastOptions()
279            opts.rootRank = self.world_size
280            opts.rootTensor = 0
281            pg.broadcast([t1], opts)
282
283        with self.assertRaisesRegex(RuntimeError, "invalid root tensor"):
284            opts = c10d.BroadcastOptions()
285            opts.rootRank = self.rank
286            opts.rootTensor = -1
287            pg.broadcast([t1], opts)
288
289        with self.assertRaisesRegex(RuntimeError, "invalid root tensor"):
290            opts = c10d.BroadcastOptions()
291            opts.rootRank = self.rank
292            opts.rootTensor = 1
293            pg.broadcast([t1], opts)
294
295        with self.assertRaisesRegex(RuntimeError, "invalid root tensor"):
296            opts = c10d.BroadcastOptions()
297            opts.rootRank = self.rank
298            opts.rootTensor = 0
299            pg.broadcast([], opts)
300
301        with self.assertRaisesRegex(RuntimeError, "invalid tensor type"):
302            opts = c10d.BroadcastOptions()
303            opts.rootRank = self.rank
304            opts.rootTensor = 0
305            pg.broadcast([t1, t2], opts)
306
307        with self.assertRaisesRegex(RuntimeError, "invalid tensor size"):
308            opts = c10d.BroadcastOptions()
309            opts.rootRank = self.rank
310            opts.rootTensor = 0
311            pg.broadcast([t1, t3], opts)
312
313    def _test_broadcast_basics(self, fn):
314        store = c10d.FileStore(self.file_name, self.world_size)
315        pg = self._create_process_group_gloo(
316            store, self.rank, self.world_size, self.opts()
317        )
318
319        def broadcast(xs, rootRank, rootTensor):
320            opts = c10d.BroadcastOptions()
321            opts.rootRank = rootRank
322            opts.rootTensor = rootTensor
323            fut = pg.broadcast(xs, opts).get_future()
324            fut.wait()
325            return fut.value()
326
327        # Every rank is root once
328        for i in range(self.world_size):
329            # Run with 1 input tensor
330            x = fn(torch.tensor([self.rank]))
331            output = broadcast([x], i, 0)
332            self.assertEqual(torch.tensor([i]), output[0])
333
334            # Run with 2 input tensors
335            num = 2
336            for j in range(num):
337                xs = [
338                    fn(torch.tensor([self.rank * num + 0.0])),
339                    fn(torch.tensor([self.rank * num + 1.0])),
340                ]
341
342                output = broadcast(xs, i, j)
343                self.assertEqual(
344                    torch.tensor([i * num + j], dtype=torch.float32), output[0]
345                )
346                self.assertEqual(
347                    torch.tensor([i * num + j], dtype=torch.float32), output[1]
348                )
349
350        # Test overloaded convenience function
351        x = torch.tensor([self.rank + 1.0])
352        fut = pg.broadcast(x, root=0).get_future()
353        fut.wait()
354        result = fut.value()
355        self.assertEqual(torch.tensor([1.0]), result[0])
356
357    @requires_gloo()
358    def test_broadcast_basics(self):
359        self._test_broadcast_basics(lambda t: t.clone())
360
361    @skip_if_lt_x_gpu(2)
362    @requires_gloo()
363    def test_broadcast_basics_cuda(self):
364        self._test_broadcast_basics(lambda t: t.clone().cuda())
365
366    def _test_broadcast_stress(self, inputs):
367        store = c10d.FileStore(self.file_name, self.world_size)
368        pg = self._create_process_group_gloo(
369            store, self.rank, self.world_size, self.opts(threads=8)
370        )
371        work_handles = [
372            pg.broadcast(inputs[i], root=(i % self.world_size))
373            for i in range(len(inputs))
374        ]
375        for i, work_handle in enumerate(work_handles):
376            work_handle.wait()
377            self.assertEqual(
378                torch.tensor([(i * self.world_size) + (i % self.world_size)]),
379                inputs[i],
380                msg=("Mismatch in iteration %d" % i),
381            )
382
383    @requires_gloo()
384    def test_broadcast_stress(self):
385        inputs = [torch.tensor([i * self.world_size + self.rank]) for i in range(1000)]
386        self._test_broadcast_stress(inputs)
387
388    @skip_if_lt_x_gpu(2)
389    @requires_gloo()
390    def test_broadcast_stress_cuda(self):
391        inputs = [
392            torch.tensor([i * self.world_size + self.rank]).cuda() for i in range(1000)
393        ]
394        self._test_broadcast_stress(inputs)
395
396    @requires_gloo()
397    def test_allreduce_checks(self):
398        store = c10d.FileStore(self.file_name, self.world_size)
399        pg = self._create_process_group_gloo(
400            store, self.rank, self.world_size, self.opts()
401        )
402
403        t1 = torch.zeros([1], dtype=torch.float32)
404        t2 = torch.zeros([1], dtype=torch.float64)
405        t3 = torch.zeros([2], dtype=torch.float32)
406
407        with self.assertRaisesRegex(RuntimeError, "requires non-empty tensor list"):
408            opts = c10d.AllreduceOptions()
409            pg.allreduce([], opts)
410
411        with self.assertRaisesRegex(RuntimeError, "invalid tensor type"):
412            opts = c10d.AllreduceOptions()
413            pg.allreduce([t1, t2], opts)
414
415        with self.assertRaisesRegex(RuntimeError, "invalid tensor size"):
416            opts = c10d.AllreduceOptions()
417            pg.allreduce([t1, t3], opts)
418
419    def _test_allreduce_basics(self, fn):
420        store = c10d.FileStore(self.file_name, self.world_size)
421        pg = self._create_process_group_gloo(
422            store, self.rank, self.world_size, self.opts()
423        )
424
425        # Single input tests
426        tests = simple_reduce_tests(self.rank, self.world_size)
427        for op, input, expected in tests:
428            opts = c10d.AllreduceOptions()
429            opts.reduceOp = op
430            tensor = fn(input)
431            fut = pg.allreduce([tensor], opts).get_future()
432            fut.wait()
433            result = fut.value()
434            self.assertEqual(expected, result[0])
435
436        # Multi input tests
437        tests = simple_multi_input_reduce_tests(self.rank, self.world_size)
438        for op, inputs, output in tests:
439            opts = c10d.AllreduceOptions()
440            opts.reduceOp = op
441            tensors = [fn(input) for input in inputs]
442            fut = pg.allreduce(tensors, opts).get_future()
443            fut.wait()
444            result = fut.value()
445            for tensor in result:
446                self.assertEqual(output, tensor)
447
448        # Test overloaded convenience function (defaults to using sum)
449        x = fn(torch.tensor([self.rank + 1.0]))
450        fut = pg.allreduce(x).get_future()
451        fut.wait()
452        result = fut.value()
453        self.assertEqual(
454            torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]),
455            result[0],
456        )
457
458    @requires_gloo()
459    def test_allreduce_basics(self):
460        self._test_allreduce_basics(lambda t: t.clone())
461
462    @skip_if_lt_x_gpu(2)
463    @requires_gloo()
464    def test_allreduce_basics_cuda(self):
465        self._test_allreduce_basics(lambda t: t.clone().cuda())
466
467    def _test_allreduce_stress(self, inputs):
468        store = c10d.FileStore(self.file_name, self.world_size)
469        pg = self._create_process_group_gloo(
470            store, self.rank, self.world_size, self.opts(threads=8)
471        )
472        future_handles = [
473            pg.allreduce(inputs[i]).get_future() for i in range(len(inputs))
474        ]
475        for i, future_handle in enumerate(future_handles):
476            future_handle.wait()
477            self.assertEqual(
478                torch.tensor(
479                    [
480                        (i * self.world_size)
481                        + (self.world_size * (self.world_size - 1) // 2)
482                    ]
483                ),
484                future_handle.value()[0],
485                msg=("Mismatch in iteration %d" % i),
486            )
487
488    @requires_gloo()
489    def test_allreduce_stress(self):
490        inputs = [torch.tensor([i + self.rank]) for i in range(1000)]
491        self._test_allreduce_stress(inputs)
492
493    @skip_if_lt_x_gpu(2)
494    @requires_gloo()
495    def test_allreduce_stress_cuda(self):
496        inputs = [torch.tensor([i + self.rank]).cuda() for i in range(1000)]
497        self._test_allreduce_stress(inputs)
498
499    @requires_gloo()
500    def test_allreduce_coalesced_checks(self):
501        store = c10d.FileStore(self.file_name, self.world_size)
502        pg = self._create_process_group_gloo(
503            store, self.rank, self.world_size, self.opts()
504        )
505
506        t1 = torch.zeros(1, dtype=torch.float32)
507        t2 = torch.zeros(1, dtype=torch.float64)
508        t3 = torch.sparse_coo_tensor([[0]], [1], size=(1,))
509
510        with self.assertRaisesRegex(RuntimeError, "requires non-empty tensor list"):
511            opts = c10d.AllreduceCoalescedOptions()
512            pg.allreduce_coalesced([], opts)
513
514        with self.assertRaisesRegex(
515            RuntimeError, "tensors must all have the same type"
516        ):
517            opts = c10d.AllreduceCoalescedOptions()
518            pg.allreduce_coalesced([t1, t2], opts)
519
520        with self.assertRaisesRegex(RuntimeError, "invalid tensor layout at index"):
521            opts = c10d.AllreduceCoalescedOptions()
522            pg.allreduce_coalesced([t1, t3], opts)
523
524        with self.assertRaisesRegex(RuntimeError, "unsupported layout"):
525            opts = c10d.AllreduceCoalescedOptions()
526            pg.allreduce_coalesced([t3, t3.clone()], opts)
527
528    @skip_if_lt_x_gpu(1)
529    @requires_gloo()
530    def test_allreduce_coalesced_checks_cuda(self):
531        store = c10d.FileStore(self.file_name, self.world_size)
532        pg = self._create_process_group_gloo(
533            store, self.rank, self.world_size, self.opts()
534        )
535
536        t1 = torch.zeros(1, dtype=torch.float32)
537
538        with self.assertRaisesRegex(RuntimeError, "unsupported device type"):
539            opts = c10d.AllreduceCoalescedOptions()
540            pg.allreduce_coalesced([t1.cuda(), t1.cuda()], opts)
541
542    def _test_allreduce_coalesced_basics(self, fn):
543        store = c10d.FileStore(self.file_name, self.world_size)
544        pg = self._create_process_group_gloo(
545            store, self.rank, self.world_size, self.opts()
546        )
547
548        test_cases = simple_coalesced_reduce_tests(self.rank, self.world_size)
549        for op, inputs, outputs in test_cases:
550            opts = c10d.AllreduceCoalescedOptions()
551            opts.reduceOp = op
552            tensors = [fn(x) for x in inputs]
553            fut = pg.allreduce_coalesced(tensors, opts).get_future()
554            fut.wait()
555            result = fut.value()
556            for result_tensor, expected in zip(result, outputs):
557                self.assertEqual(result_tensor, expected)
558
559    @requires_gloo()
560    def test_allreduce_coalesced_basics(self):
561        self._test_allreduce_coalesced_basics(lambda t: t.clone())
562
563    def _expected_output(self, i):
564        ws = self.world_size
565        return 2 * [torch.tensor([(i * ws) + (ws * (ws - 1) // 2)])]
566
567    def _test_allreduce_coalesced_stress(self, inputs):
568        store = c10d.FileStore(self.file_name, self.world_size)
569        pg = self._create_process_group_gloo(
570            store, self.rank, self.world_size, self.opts(threads=8)
571        )
572        future_handles = [
573            pg.allreduce_coalesced(input).get_future() for input in inputs
574        ]
575        for i, future_handle in enumerate(future_handles):
576            future_handle.wait()
577            result = future_handle.value()
578            self.assertEqual(
579                self._expected_output(i),
580                result,
581                msg=f"Mismatch in iteration {i}",
582            )
583
584    @requires_gloo()
585    def test_allreduce_coalesced_stress(self):
586        inputs = [2 * [torch.tensor([i + self.rank])] for i in range(1000)]
587        self._test_allreduce_coalesced_stress(inputs)
588
589    @requires_gloo()
590    def test_allreduce_coalesced_async(self):
591        store = c10d.FileStore(self.file_name, self.world_size)
592        c10d.init_process_group(
593            backend="gloo", rank=self.rank, world_size=self.world_size, store=store
594        )
595
596        xs = [2 * [torch.tensor([i + self.rank])] for i in range(2)]
597        futs = [c10d.all_reduce_coalesced(x, async_op=True) for x in xs]
598        torch.futures.wait_all(futs)
599        for i, fut in enumerate(futs):
600            self.assertEqual(
601                self._expected_output(i),
602                fut.wait(),
603                msg=f"Mismatch in iteration {i}",
604            )
605
606    @requires_gloo()
607    def test_sparse_allreduce_checks(self):
608        store = c10d.FileStore(self.file_name, self.world_size)
609        pg = self._create_process_group_gloo(
610            store, self.rank, self.world_size, self.opts()
611        )
612
613        t1 = torch.zeros([1])
614        t2 = torch.sparse_coo_tensor([[0]], [1], size=(2,))
615        t3 = torch.sparse_coo_tensor([[0]], [1], size=(4,))
616
617        with self.assertRaisesRegex(RuntimeError, "requires non-empty tensor list"):
618            opts = c10d.AllreduceOptions()
619            pg.allreduce([], opts)
620
621        with self.assertRaisesRegex(RuntimeError, "invalid tensor layout"):
622            opts = c10d.AllreduceOptions()
623            pg.allreduce([t1, t2], opts)
624
625        with self.assertRaisesRegex(RuntimeError, "invalid tensor size"):
626            opts = c10d.AllreduceOptions()
627            pg.allreduce([t2, t3], opts)
628
629        # Sparse allreduce only works with c10d.ReduceOp.SUM.
630        for op in [c10d.ReduceOp.PRODUCT, c10d.ReduceOp.MIN, c10d.ReduceOp.MAX]:
631            with self.assertRaisesRegex(
632                RuntimeError, "unsupported reduction operation"
633            ):
634                opts = c10d.AllreduceOptions()
635                opts.reduceOp = op
636                pg.allreduce([t3], opts)
637
638    def _test_sparse_allreduce_basics(self, fn):
639        store = c10d.FileStore(self.file_name, self.world_size)
640        pg = self._create_process_group_gloo(
641            store, self.rank, self.world_size, self.opts()
642        )
643
644        for num_inputs_per_rank in [1, 2]:
645            tests = simple_sparse_reduce_tests(
646                self.rank, self.world_size, num_inputs=num_inputs_per_rank
647            )
648            for inputs, outputs in tests:
649                tensors = [fn(input) for input in inputs]
650                fut = pg.allreduce(tensors).get_future()
651                fut.wait()
652                result = fut.value()
653                self.assertEqual(tensors, outputs)
654                self.assertEqual(result, outputs)
655
656    @requires_gloo()
657    def test_sparse_allreduce_basics(self):
658        self._test_sparse_allreduce_basics(lambda t: t)
659
660    @skip_if_lt_x_gpu(2)
661    @requires_gloo()
662    def test_sparse_allreduce_basics_cuda(self):
663        self._test_sparse_allreduce_basics(lambda t: t.clone().cuda())
664
665    @skip_if_lt_x_gpu(2)
666    @requires_gloo()
667    def test_sparse_allreduce_cuda_dispatched(self):
668        store = c10d.FileStore(self.file_name, self.world_size)
669        dist.init_process_group(
670            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
671        )
672        tests = simple_sparse_reduce_tests(self.rank, self.world_size, num_inputs=1)
673        for inputs, outputs in tests:
674            tensors = inputs[-1].clone().cuda()
675            work = dist.all_reduce(tensors, async_op=True)
676            work.wait()
677            self.assertEqual([tensors], outputs)
678
679    @requires_gloo()
680    def test_allgather_into_tensor_coalesced(self):
681        store = c10d.FileStore(self.file_name, self.world_size)
682        dist.init_process_group(
683            backend="gloo",
684            store=store,
685            rank=self.rank,
686            world_size=self.world_size,
687        )
688        torch.manual_seed(42)
689        in_shapes = [(5, 5), (10, 10), (15, 15)]
690        out_shapes = [(s[0] * self.world_size,) + s[1:] for s in in_shapes]
691
692        outputs = [torch.empty(s) for s in out_shapes]
693        inputs = [torch.rand(s) for s in in_shapes]
694        work = dist.group.WORLD.allgather_into_tensor_coalesced(outputs, inputs)
695        work.wait()
696
697        for output, input in zip(outputs, inputs):
698            expect = torch.cat([input] * self.world_size)
699            self.assertTrue(torch.allclose(output, expect))
700
701    @requires_gloo()
702    def test_reduce_scatter_tensor(self):
703        store = c10d.FileStore(self.file_name, self.world_size)
704        dist.init_process_group(
705            backend="gloo",
706            store=store,
707            rank=self.rank,
708            world_size=self.world_size,
709        )
710        torch.manual_seed(42)
711        out_shape = (20, 20)
712        in_shape = (out_shape[0] * self.world_size,) + out_shape[1:]
713
714        output = torch.empty(out_shape)
715        input = torch.rand(in_shape)
716        work = dist.reduce_scatter_tensor(output, input, async_op=True)
717        work.wait()
718
719        expect = (
720            input.view(self.world_size, *out_shape).chunk(self.world_size)[self.rank]
721            * self.world_size
722        )
723        self.assertTrue(torch.allclose(output, expect))
724
725    @requires_gloo()
726    def test_reduce_scatter_tensor_coalesced(self):
727        store = c10d.FileStore(self.file_name, self.world_size)
728        dist.init_process_group(
729            backend="gloo",
730            store=store,
731            rank=self.rank,
732            world_size=self.world_size,
733        )
734        torch.manual_seed(42)
735        out_shapes = [(5, 5), (10, 10), (15, 15)]
736        in_shapes = [(s[0] * self.world_size,) + s[1:] for s in out_shapes]
737
738        outputs = [torch.empty(s) for s in out_shapes]
739        inputs = [torch.rand(s) for s in in_shapes]
740        work = dist.group.WORLD.reduce_scatter_tensor_coalesced(outputs, inputs)
741        work.wait()
742
743        for output, input in zip(outputs, inputs):
744            expect = (
745                input.view(self.world_size, *output.shape).chunk(self.world_size)[
746                    self.rank
747                ]
748                * self.world_size
749            )
750            self.assertTrue(torch.allclose(output, expect))
751
752    @requires_gloo()
753    def test_scatter_checks(self):
754        store = c10d.FileStore(self.file_name, self.world_size)
755        pg = self._create_process_group_gloo(
756            store, self.rank, self.world_size, self.opts()
757        )
758
759        t1 = torch.zeros([1], dtype=torch.float32)
760        t2 = torch.zeros([1], dtype=torch.float64)
761        t3 = torch.zeros([2], dtype=torch.float32)
762
763        with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
764            opts = c10d.ScatterOptions()
765            opts.rootRank = -1
766            pg.scatter([t1], [], opts)
767
768        with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
769            opts = c10d.ScatterOptions()
770            opts.rootRank = self.world_size
771            pg.scatter([t1], [], opts)
772
773        with self.assertRaisesRegex(
774            RuntimeError, "requires a single-element output tensor list"
775        ):
776            opts = c10d.ScatterOptions()
777            opts.rootRank = 0
778            pg.scatter([], [], opts)
779
780        with self.assertRaisesRegex(
781            RuntimeError, "requires a single-element output tensor list"
782        ):
783            opts = c10d.ScatterOptions()
784            opts.rootRank = 0
785            pg.scatter([t1, t1], [], opts)
786
787        with self.assertRaisesRegex(
788            RuntimeError, "requires a single-element input list"
789        ):
790            opts = c10d.ScatterOptions()
791            opts.rootRank = self.rank
792            pg.scatter([t1], [], opts)
793
794        with self.assertRaisesRegex(
795            RuntimeError, "requires a single-element input list"
796        ):
797            opts = c10d.ScatterOptions()
798            opts.rootRank = self.rank
799            pg.scatter([t1], [[t1] * self.world_size, [t1] * self.world_size], opts)
800
801        desired_list_size = self.world_size
802        incorrect_list_size = self.world_size - 1
803        err_str = "Incorrect input list size {}. Input list size should be {}"
804        with self.assertRaisesRegex(
805            RuntimeError, err_str.format(incorrect_list_size, desired_list_size)
806        ):
807            opts = c10d.ScatterOptions()
808            opts.rootRank = self.rank
809            pg.scatter([t1], [[t1] * incorrect_list_size], opts)
810
811        incorrect_list_size = self.world_size + 1
812        with self.assertRaisesRegex(
813            RuntimeError, err_str.format(incorrect_list_size, desired_list_size)
814        ):
815            opts = c10d.ScatterOptions()
816            opts.rootRank = self.rank
817            pg.scatter([t1], [[t1] * incorrect_list_size], opts)
818
819        with self.assertRaisesRegex(RuntimeError, "invalid tensor type"):
820            opts = c10d.ScatterOptions()
821            opts.rootRank = self.rank
822            pg.scatter([t1], [[t2] * self.world_size], opts)
823
824        with self.assertRaisesRegex(RuntimeError, "invalid tensor size"):
825            opts = c10d.ScatterOptions()
826            opts.rootRank = self.rank
827            pg.scatter([t1], [[t3] * self.world_size], opts)
828
829        with self.assertRaisesRegex(RuntimeError, "requires empty input on non-root"):
830            opts = c10d.ScatterOptions()
831            opts.rootRank = (self.rank + 1) % self.world_size
832            pg.scatter([t1], [[t1] * self.world_size], opts)
833
834    def _test_scatter_basics(self, fn):
835        store = c10d.FileStore(self.file_name, self.world_size)
836        pg = self._create_process_group_gloo(
837            store, self.rank, self.world_size, self.opts()
838        )
839
840        # Preallocate tensors for input/output
841        input = [fn(torch.tensor([self.rank])) for _ in range(self.world_size)]
842        outputs = [fn(torch.tensor([-1])) for _ in range(self.world_size)]
843
844        # Take turns being the scatter root and accumulate work items
845        futures = []
846        for i in range(self.world_size):
847            opts = c10d.ScatterOptions()
848            opts.rootRank = i
849            if i == self.rank:
850                futures.append(pg.scatter([outputs[i]], [input], opts).get_future())
851            else:
852                futures.append(pg.scatter([outputs[i]], [], opts).get_future())
853
854        # Wait for work to complete
855        for i in range(self.world_size):
856            futures[i].wait()
857            result = futures[i].value()
858            self.assertEqual(torch.tensor([i]), result[0])
859
860    @requires_gloo()
861    def test_scatter_basics(self):
862        self._test_scatter_basics(lambda t: t.clone())
863
864    @skip_if_lt_x_gpu(2)
865    @requires_gloo()
866    def test_scatter_basics_cuda(self):
867        self._test_scatter_basics(lambda t: t.clone().cuda())
868
869    def _test_scatter_stress(self, inputs, fn):
870        store = c10d.FileStore(self.file_name, self.world_size)
871        pg = self._create_process_group_gloo(
872            store, self.rank, self.world_size, self.opts(threads=8)
873        )
874        outputs = [
875            [fn(torch.tensor([-1])) for _ in range(self.world_size)]
876            for _ in range(len(inputs))
877        ]
878        future_handles = []
879        for i in range(len(inputs)):
880            for root in range(self.world_size):
881                opts = c10d.ScatterOptions()
882                opts.rootRank = root
883                if root == self.rank:
884                    fut = pg.scatter(
885                        [outputs[i][root]], [[fn(e) for e in inputs[i]]], opts
886                    ).get_future()
887                else:
888                    fut = pg.scatter([outputs[i][root]], [], opts).get_future()
889                future_handles.append(fut)
890
891        for i, future_handle in enumerate(future_handles):
892            future_handle.wait()
893            iter = i // self.world_size
894            root = i % self.world_size
895            result = future_handle.value()
896
897            self.assertEqual(
898                torch.tensor([iter + root]),
899                result[0],
900                msg=("Mismatch in iteration %d for rank %d" % (iter, root)),
901            )
902
903    @requires_gloo()
904    def test_set_gloo_pg_timeout(self):
905        store = c10d.FileStore(self.file_name, self.world_size)
906        pg = self._create_process_group_gloo(
907            store, self.rank, self.world_size, self.opts()
908        )
909        pg.allreduce(torch.rand(10))
910        self.assertEqual(pg.options._timeout, timedelta(seconds=50))
911        pg._set_default_timeout(timedelta(seconds=23))
912        self.assertEqual(pg.options._timeout, timedelta(seconds=23))
913
914    @requires_gloo()
915    def test_scatter_stress(self):
916        inputs = [
917            [torch.tensor([i + self.rank]) for _ in range(self.world_size)]
918            for i in range(1000)
919        ]
920        self._test_scatter_stress(inputs, lambda t: t.clone())
921
922    @skip_but_pass_in_sandcastle(
923        "Test is flaky, see https://github.com/pytorch/pytorch/issues/15963"
924    )
925    @skip_if_lt_x_gpu(2)
926    @requires_gloo()
927    def test_scatter_stress_cuda(self):
928        inputs = [
929            [torch.tensor([i + self.rank]) for _ in range(self.world_size)]
930            for i in range(1000)
931        ]
932        self._test_scatter_stress(inputs, lambda t: t.clone().cuda())
933
934    @requires_gloo()
935    def test_gather_checks(self):
936        store = c10d.FileStore(self.file_name, self.world_size)
937        pg = self._create_process_group_gloo(
938            store, self.rank, self.world_size, self.opts()
939        )
940
941        t1 = torch.zeros([1], dtype=torch.float32)
942        t2 = torch.zeros([1], dtype=torch.float64)
943        t3 = torch.zeros([2], dtype=torch.float32)
944
945        with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
946            opts = c10d.GatherOptions()
947            opts.rootRank = -1
948            pg.gather([], [t1], opts)
949
950        with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
951            opts = c10d.GatherOptions()
952            opts.rootRank = self.world_size
953            pg.gather([], [t1], opts)
954
955        with self.assertRaisesRegex(
956            RuntimeError, "requires a single-element input tensor list"
957        ):
958            opts = c10d.GatherOptions()
959            opts.rootRank = 0
960            pg.gather([], [], opts)
961
962        with self.assertRaisesRegex(
963            RuntimeError, "requires a single-element input tensor list"
964        ):
965            opts = c10d.GatherOptions()
966            opts.rootRank = 0
967            pg.gather([], [t1, t1], opts)
968
969        with self.assertRaisesRegex(
970            RuntimeError, "requires a single-element output list"
971        ):
972            opts = c10d.GatherOptions()
973            opts.rootRank = self.rank
974            pg.gather([], [t1], opts)
975
976        with self.assertRaisesRegex(
977            RuntimeError, "requires a single-element output list"
978        ):
979            opts = c10d.GatherOptions()
980            opts.rootRank = self.rank
981            pg.gather([[t1] * self.world_size, [t1] * self.world_size], [t1], opts)
982
983        desired_list_size = self.world_size
984        incorrect_list_size = self.world_size - 1
985        err_str = "Incorrect output list size {}. Output list size should be {}"
986        with self.assertRaisesRegex(
987            RuntimeError, err_str.format(incorrect_list_size, desired_list_size)
988        ):
989            opts = c10d.GatherOptions()
990            opts.rootRank = self.rank
991            pg.gather([[t1] * incorrect_list_size], [t1], opts)
992
993        incorrect_list_size = self.world_size + 1
994        with self.assertRaisesRegex(
995            RuntimeError, err_str.format(incorrect_list_size, desired_list_size)
996        ):
997            opts = c10d.GatherOptions()
998            opts.rootRank = self.rank
999            pg.gather([[t1] * incorrect_list_size], [t1], opts)
1000
1001        with self.assertRaisesRegex(RuntimeError, "invalid tensor type"):
1002            opts = c10d.GatherOptions()
1003            opts.rootRank = self.rank
1004            pg.gather([[t2] * self.world_size], [t1], opts)
1005
1006        with self.assertRaisesRegex(RuntimeError, "invalid tensor size"):
1007            opts = c10d.GatherOptions()
1008            opts.rootRank = self.rank
1009            pg.gather([[t3] * self.world_size], [t1], opts)
1010
1011        with self.assertRaisesRegex(RuntimeError, "requires empty output on non-root"):
1012            opts = c10d.GatherOptions()
1013            opts.rootRank = (self.rank + 1) % self.world_size
1014            pg.gather([[t1] * self.world_size], [t1], opts)
1015
1016    def _test_gather_basics(self, fn):
1017        store = c10d.FileStore(self.file_name, self.world_size)
1018        pg = self._create_process_group_gloo(
1019            store, self.rank, self.world_size, self.opts()
1020        )
1021
1022        # Preallocate tensors for input/output
1023        input = [fn(torch.tensor([self.rank]))]
1024        outputs = [fn(torch.tensor([-1])) for _ in range(self.world_size)]
1025
1026        # Take turns being the gather root and accumulate work items
1027        futures = []
1028        for i in range(self.world_size):
1029            opts = c10d.GatherOptions()
1030            opts.rootRank = i
1031            if i == self.rank:
1032                futures.append(pg.gather([outputs], input, opts).get_future())
1033            else:
1034                futures.append(pg.gather([], input, opts).get_future())
1035
1036        # Wait for work to complete
1037        expected = [fn(torch.tensor([rank])) for rank in range(self.world_size)]
1038        for i in range(self.world_size):
1039            futures[i].wait()
1040            result = futures[i].value()
1041            if i == self.rank:
1042                self.assertEqual(expected, result)
1043
1044    @requires_gloo()
1045    def test_gather_basics(self):
1046        self._test_gather_basics(lambda t: t.clone())
1047
1048    @skip_if_lt_x_gpu(2)
1049    @requires_gloo()
1050    def test_gather_basics_cuda(self):
1051        self._test_gather_basics(lambda t: t.clone().cuda())
1052
1053    @requires_gloo()
1054    def test_gather_noncontiguous_input(self):
1055        # Take a column of 2D tensor, such that memory is not dense
1056        self._test_gather_basics(lambda t: t.expand(2, 2).contiguous()[:, 0])
1057
1058    def _test_gather_stress(self, inputs, fn):
1059        store = c10d.FileStore(self.file_name, self.world_size)
1060        pg = self._create_process_group_gloo(
1061            store, self.rank, self.world_size, self.opts(threads=8)
1062        )
1063        future_handles = []
1064        outputs = [
1065            [[fn(torch.tensor([-1])) for _ in range(self.world_size)]]
1066            for _ in range(len(inputs))
1067        ]
1068        expected_outputs = [
1069            [[torch.tensor([i + j]) for j in range(self.world_size)]]
1070            for i in range(len(inputs))
1071        ]
1072        for i in range(len(inputs)):
1073            for root in range(self.world_size):
1074                opts = c10d.GatherOptions()
1075                opts.rootRank = root
1076                if root == self.rank:
1077                    fut = pg.gather(outputs[i], [fn(inputs[i])], opts).get_future()
1078                else:
1079                    fut = pg.gather([], [fn(inputs[i])], opts).get_future()
1080                future_handles.append(fut)
1081
1082        for i, future_handle in enumerate(future_handles):
1083            future_handle.wait()
1084            iter = i // self.world_size
1085            root = i % self.world_size
1086            if root == self.rank:
1087                result = future_handle.value()
1088                self.assertEqual(
1089                    expected_outputs[iter],
1090                    [result],
1091                    msg=("Mismatch in iteration %d for root %d" % (iter, root)),
1092                )
1093
1094    @requires_gloo()
1095    def test_gather_stress(self):
1096        inputs = [torch.tensor([i + self.rank]) for i in range(1000)]
1097        self._test_gather_stress(inputs, lambda t: t.clone())
1098
1099    @skip_if_lt_x_gpu(2)
1100    @requires_gloo()
1101    def test_gather_stress_cuda(self):
1102        inputs = [torch.tensor([i + self.rank]).cuda() for i in range(1000)]
1103        self._test_gather_stress(inputs, lambda t: t.clone().cuda())
1104
1105    @requires_gloo()
1106    def test_allgather_checks(self):
1107        store = c10d.FileStore(self.file_name, self.world_size)
1108        pg = self._create_process_group_gloo(
1109            store, self.rank, self.world_size, self.opts()
1110        )
1111
1112        t1 = torch.zeros([1], dtype=torch.float32)
1113        t2 = torch.zeros([1], dtype=torch.float64)
1114        t3 = torch.zeros([2], dtype=torch.float32)
1115
1116        with self.assertRaisesRegex(
1117            RuntimeError, "requires non-empty input tensor list"
1118        ):
1119            pg.allgather([], [])
1120
1121        with self.assertRaisesRegex(
1122            RuntimeError, "requires input/output tensor lists to have the same length"
1123        ):
1124            pg.allgather([], [t1])
1125
1126        with self.assertRaisesRegex(
1127            RuntimeError, "requires input/output tensor lists to have the same length"
1128        ):
1129            pg.allgather([[t1] * self.world_size, [t1] * self.world_size], [t1])
1130
1131        with self.assertRaisesRegex(RuntimeError, "invalid output tensor list"):
1132            pg.allgather([[t1] * (self.world_size - 1)], [t1])
1133
1134        with self.assertRaisesRegex(RuntimeError, "invalid output tensor list"):
1135            pg.allgather([[t1] * (self.world_size + 1)], [t1])
1136
1137        with self.assertRaisesRegex(RuntimeError, "invalid tensor type"):
1138            pg.allgather(
1139                [[t1, t1] * (self.world_size), [t1, t1] * (self.world_size)], [t1, t2]
1140            )
1141
1142        with self.assertRaisesRegex(RuntimeError, "invalid tensor size"):
1143            pg.allgather(
1144                [[t1, t1] * (self.world_size), [t1, t1] * (self.world_size)], [t1, t3]
1145            )
1146
1147        with self.assertRaisesRegex(RuntimeError, "invalid tensor type"):
1148            pg.allgather([([t1, t2] * (self.world_size))[: self.world_size]], [t1])
1149
1150        with self.assertRaisesRegex(RuntimeError, "invalid tensor size"):
1151            pg.allgather([([t1, t3] * (self.world_size))[: self.world_size]], [t1])
1152
1153    def _test_allgather_basics(self, fn):
1154        store = c10d.FileStore(self.file_name, self.world_size)
1155        pg = self._create_process_group_gloo(
1156            store, self.rank, self.world_size, self.opts()
1157        )
1158
1159        # Run with N input tensor per rank
1160        for n in [1, 2, 3]:
1161            input = [fn(torch.tensor([n * self.rank + i])) for i in range(n)]
1162            output = [
1163                [fn(torch.tensor([-1])) for _ in range(n * self.world_size)]
1164                for _ in range(n)
1165            ]
1166            expected_output = [
1167                [fn(torch.tensor([i])) for i in range(n * self.world_size)]
1168                for _ in range(n)
1169            ]
1170            fut = pg.allgather(output, input).get_future()
1171            fut.wait()
1172            result = fut.value()
1173            if n == 1:
1174                result = [result]
1175            self.assertEqual(expected_output, result)
1176
1177    @requires_gloo()
1178    def test_allgather_basics(self):
1179        self._test_allgather_basics(lambda t: t.clone())
1180
1181    @skip_if_lt_x_gpu(2)
1182    @requires_gloo()
1183    def test_allgather_basics_cuda(self):
1184        self._test_allgather_basics(lambda t: t.clone().cuda())
1185
1186    @requires_gloo()
1187    def test_allgather_noncontiguous_input(self):
1188        # Take a column of 2D tensor, such that memory is not dense
1189        self._test_allgather_basics(lambda t: t.expand(2, 2).contiguous()[:, 0])
1190
1191    def _test_allgather_stress(self, inputs, fn):
1192        store = c10d.FileStore(self.file_name, self.world_size)
1193        pg = self._create_process_group_gloo(
1194            store, self.rank, self.world_size, self.opts(threads=8)
1195        )
1196        future_handles = []
1197        outputs = [
1198            [[fn(torch.tensor([-1])) for _ in range(self.world_size)]]
1199            for _ in range(len(inputs))
1200        ]
1201        expected_outputs = [
1202            [[torch.tensor([i + j]) for j in range(self.world_size)]]
1203            for i in range(len(inputs))
1204        ]
1205        input_holder = {}
1206        for i in range(len(inputs)):
1207            # Note that this works around the data race discussed in
1208            # https://github.com/pytorch/pytorch/issues/75529, but we should
1209            # actually be able to pass the list directly into allgather when
1210            # that race is fixed.
1211            input_holder[i] = [fn(inputs[i])]
1212            fut = pg.allgather(outputs[i], input_holder[i]).get_future()
1213            future_handles.append(fut)
1214
1215        for i, future_handle in enumerate(future_handles):
1216            future_handle.wait()
1217            result = future_handle.value()
1218            self.assertEqual(
1219                expected_outputs[i],
1220                [result],
1221                msg=("Mismatch in iteration %d" % i),
1222            )
1223
1224    @requires_gloo()
1225    def test_allgather_stress(self):
1226        inputs = [torch.tensor([i + self.rank]) for i in range(1000)]
1227        self._test_allgather_stress(inputs, lambda t: t.clone())
1228
1229    @skip_if_lt_x_gpu(2)
1230    @requires_gloo()
1231    def test_allgather_stress_cuda(self):
1232        inputs = [torch.tensor([i + self.rank]).cuda() for i in range(1000)]
1233        self._test_allgather_stress(inputs, lambda t: t.clone().cuda())
1234
1235    @requires_gloo()
1236    def test_allgather_coalesced_checks(self):
1237        store = c10d.FileStore(self.file_name, self.world_size)
1238        pg = self._create_process_group_gloo(
1239            store, self.rank, self.world_size, self.opts()
1240        )
1241        dummy_input = [torch.zeros([1], dtype=torch.float32)]
1242        dummy_output_lists = [
1243            [torch.zeros([1], dtype=torch.float32)] for _ in range(self.world_size)
1244        ]
1245
1246        # One of output tensors does not match input list.
1247        dummy_output_lists[0] = [torch.zeros([0], dtype=torch.float32)]
1248        with self.assertRaisesRegex(
1249            RuntimeError, "invalid size of output tensor at index 0"
1250        ):
1251            c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)
1252
1253        # One of output tensors does not match input list.
1254        dummy_output_lists[0] = [torch.zeros([1], dtype=torch.float64)]
1255        with self.assertRaisesRegex(RuntimeError, "invalid tensor type at index 0"):
1256            c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)
1257
1258        # Output lists have too many elements
1259        dummy_output_lists = [
1260            [torch.zeros([1], dtype=torch.float32)] for _ in range(self.world_size + 1)
1261        ]
1262        with self.assertRaisesRegex(
1263            RuntimeError, "output lists should be equal to world size"
1264        ):
1265            c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)
1266
1267        # Output is not a list of lists.
1268        dummy_output_lists = [torch.zeros([0], dtype=torch.float32)]
1269        with self.assertRaisesRegex(
1270            TypeError, "Invalid function argument.*output_tensor_lists"
1271        ):
1272            c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)
1273
1274    @requires_gloo()
1275    def test_allgather_coalesced_async(self):
1276        store = c10d.FileStore(self.file_name, self.world_size)
1277        c10d.init_process_group(
1278            backend="gloo", rank=self.rank, world_size=self.world_size, store=store
1279        )
1280
1281        xxs = [2 * [torch.tensor([i + self.rank])] for i in range(2)]
1282        yys = [
1283            [[torch.zeros_like(x) for x in xx] for _ in range(self.world_size)]
1284            for xx in xxs
1285        ]
1286        futs = [
1287            c10d.all_gather_coalesced(yy, xx, async_op=True) for xx, yy in zip(xxs, yys)
1288        ]
1289
1290        # expected outputs
1291        zzs = [
1292            [2 * [torch.tensor([i + r])] for r in range(self.world_size)]
1293            for i in range(2)
1294        ]
1295
1296        torch.futures.wait_all(futs)
1297        for yy, zz in zip(yys, zzs):
1298            # one iteration
1299            for y_out, z_out in zip(yy, zz):
1300                # one output tensor list
1301                for y, z in zip(y_out, z_out):
1302                    # one tensor in output tensor list
1303                    self.assertEqual(y, z)
1304
1305        # Added to address https://github.com/pytorch/pytorch/issues/65231
1306        # In the failed tests, all assertEqual are passed on all processes.
1307        # However, one of the processes didn't call ProcessGroupGloo
1308        # destructor before exiting program. This is not surprising as the only
1309        # guarantee that Python makes is that garbage collection MAY happen
1310        # before the program exits. If GC didn't happen, the two threads in
1311        # ProcessGroup might be destructed before joined.
1312        # FIXME: it's still unclear why only this test require explicit
1313        # destroy_process_group()
1314        c10d.destroy_process_group()
1315
1316    @requires_gloo()
1317    def test_reduce_checks(self):
1318        store = c10d.FileStore(self.file_name, self.world_size)
1319        pg = pg = self._create_process_group_gloo(
1320            store, self.rank, self.world_size, self.opts()
1321        )
1322
1323        t1 = torch.zeros([1], dtype=torch.float32)
1324
1325        with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
1326            opts = c10d.ReduceOptions()
1327            opts.rootRank = -1
1328            opts.rootTensor = 0
1329            pg.reduce([t1], opts)
1330
1331        with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
1332            opts = c10d.ReduceOptions()
1333            opts.rootRank = self.world_size
1334            opts.rootTensor = 0
1335            pg.reduce([t1], opts)
1336
1337        with self.assertRaisesRegex(RuntimeError, "invalid root tensor"):
1338            opts = c10d.ReduceOptions()
1339            opts.rootRank = self.rank
1340            opts.rootTensor = 1
1341            pg.reduce([t1], opts)
1342
1343        with self.assertRaisesRegex(
1344            RuntimeError, "requires a single-element tensor list"
1345        ):
1346            opts = c10d.ReduceOptions()
1347            opts.rootRank = self.rank
1348            opts.rootTensor = 0
1349            pg.reduce([t1, t1], opts)
1350
1351    def _test_reduce_basics(self, fn):
1352        store = c10d.FileStore(self.file_name, self.world_size)
1353        pg = self._create_process_group_gloo(
1354            store, self.rank, self.world_size, self.opts()
1355        )
1356        for op, input, output in simple_reduce_tests(self.rank, self.world_size):
1357            for root in range(self.world_size):
1358                opts = c10d.ReduceOptions()
1359                opts.reduceOp = op
1360                opts.rootRank = root
1361                tmp = fn(input)
1362                fut = pg.reduce([tmp], opts).get_future()
1363                fut.wait()
1364                result = fut.value()
1365                if root == self.rank:
1366                    self.assertEqual(output, result[0])
1367
1368    @requires_gloo()
1369    def test_reduce_basics(self):
1370        self._test_reduce_basics(lambda t: t.clone())
1371
1372    @skip_if_lt_x_gpu(2)
1373    @requires_gloo()
1374    def test_reduce_basics_cuda(self):
1375        self._test_reduce_basics(lambda t: t.clone().cuda())
1376
1377    def _test_reduce_stress(self, inputs):
1378        store = c10d.FileStore(self.file_name, self.world_size)
1379        pg = self._create_process_group_gloo(
1380            store, self.rank, self.world_size, self.opts(threads=8)
1381        )
1382        future_handles = []
1383        outputs = []
1384        for i in range(len(inputs)):
1385            for root in range(self.world_size):
1386                opts = c10d.ReduceOptions()
1387                opts.rootRank = root
1388                tmp = inputs[i].clone()
1389                outputs.append(tmp)
1390                fut = pg.reduce([tmp], opts).get_future()
1391                future_handles.append(fut)
1392
1393        for i, future_handle in enumerate(future_handles):
1394            future_handle.wait()
1395            result = future_handle.value()
1396            iter = i // self.world_size
1397            root = i % self.world_size
1398            if root == self.rank:
1399                self.assertEqual(
1400                    torch.tensor(
1401                        [
1402                            (iter * self.world_size)
1403                            + (self.world_size * (self.world_size - 1) // 2)
1404                        ]
1405                    ),
1406                    result[0],
1407                    msg=("Mismatch in iteration %d with root rank %d" % (iter, root)),
1408                )
1409
1410    @requires_gloo()
1411    def test_reduce_stress(self):
1412        inputs = [torch.tensor([i + self.rank]) for i in range(1000)]
1413        self._test_reduce_stress(inputs)
1414
1415    @skip_if_lt_x_gpu(2)
1416    @requires_gloo()
1417    def test_reduce_stress_cuda(self):
1418        inputs = [torch.tensor([i + self.rank]).cuda() for i in range(1000)]
1419        self._test_reduce_stress(inputs)
1420
1421    @requires_gloo()
1422    def test_send_recv_all_to_all(self):
1423        store = c10d.FileStore(self.file_name, self.world_size)
1424        pg = self._create_process_group_gloo(
1425            store, self.rank, self.world_size, self.opts()
1426        )
1427
1428        # Preallocate tensors for input/output
1429        inputs = [torch.tensor([self.rank]) for _ in range(self.world_size)]
1430        outputs = [torch.tensor([-1]) for _ in range(self.world_size)]
1431
1432        # Issue sends
1433        send_work = []
1434        for i in range(self.world_size):
1435            if i == self.rank:
1436                continue
1437            send_work.append(pg.send([inputs[i]], i, 0))
1438
1439        # Issue recvs
1440        recv_work = []
1441        for i in range(self.world_size):
1442            if i == self.rank:
1443                continue
1444            recv_work.append(pg.recv([outputs[i]], i, 0))
1445
1446        # Wait for sends to complete
1447        for work in send_work:
1448            work.wait()
1449            self.assertTrue(work.is_completed())
1450
1451        # Wait for recvs to complete
1452        for work in recv_work:
1453            work.wait()
1454            self.assertTrue(work.is_completed())
1455
1456        # Test that every output other than our own contains the respective rank
1457        for i in range(self.world_size):
1458            if i == self.rank:
1459                continue
1460            self.assertEqual(torch.tensor([i]), outputs[i])
1461
1462    @requires_gloo()
1463    def test_barrier_implies_wait(self):
1464        store = c10d.FileStore(self.file_name, self.world_size)
1465        pg = self._create_process_group_gloo(
1466            store, self.rank, self.world_size, self.opts()
1467        )
1468
1469        # Kick off allreduce operations
1470        size = (100, 100)
1471        num = 16
1472        tensors = [torch.full(size, float(i)) for i in range(num)]
1473        for tensor in tensors:
1474            # Note: leak the returned work handle
1475            pg.allreduce(tensor)
1476
1477        # Barrier should ensure all previous work has completed
1478        pg.barrier().get_future().wait()
1479
1480        for i, tensor in enumerate(tensors):
1481            self.assertEqual(torch.full(size, float(i * self.world_size)), tensor)
1482
1483
1484class DistributedDataParallelTest(
1485    test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
1486):
1487    def setUp(self):
1488        super().setUp()
1489        self._spawn_processes()
1490
1491    def _get_process_group(self):
1492        store = self._get_store()
1493        c10d.init_process_group(
1494            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
1495        )
1496        return c10d.distributed_c10d._get_default_group()
1497
1498    def _test_gloo_backend(
1499        self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
1500    ):
1501        store = c10d.FileStore(self.file_name, self.world_size)
1502        c10d.init_process_group(
1503            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
1504        )
1505        process_group = c10d.distributed_c10d._get_default_group()
1506        device = devices[-1]
1507        backend = process_group._get_backend(device)
1508        backend.create_device(interface=LOOPBACK)
1509        self._test_ddp_with_process_group(
1510            process_group, devices, device_ids, multi_device, gradient_as_bucket_view
1511        )
1512
1513    @requires_gloo()
1514    def test_gloo_backend_cpu_module(self):
1515        self._test_gloo_backend([torch.device("cpu")], None)
1516
1517    @requires_gloo()
1518    def test_gloo_backend_cpu_module_grad_is_view(self):
1519        self._test_gloo_backend(
1520            [torch.device("cpu")], None, gradient_as_bucket_view=True
1521        )
1522
1523    @requires_gloo()
1524    @skip_if_lt_x_gpu(2)
1525    def test_gloo_backend_1gpu_module_device_ids_integer_list(self):
1526        int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
1527        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
1528        self._test_gloo_backend(devices, int_devices)
1529
1530    @requires_gloo()
1531    @skip_if_lt_x_gpu(2)
1532    def test_gloo_backend_1gpu_module_device_ids_torch_device_list(self):
1533        int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
1534        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
1535        self._test_gloo_backend(devices, devices)
1536
1537    @requires_gloo()
1538    @skip_if_lt_x_gpu(4)
1539    def test_gloo_backend_2gpu_module(self):
1540        int_devices = gpus_for_rank(self.world_size)[self.rank][:2]
1541        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
1542        self._test_gloo_backend(devices, None, multi_device=True)
1543
1544    @requires_gloo()
1545    @skip_if_lt_x_gpu(8)
1546    def test_gloo_backend_4gpu_module(self):
1547        int_devices = gpus_for_rank(self.world_size)[self.rank][:4]
1548        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
1549        self._test_gloo_backend(devices, None, multi_device=True)
1550
1551    def _test_global_local_unused_params_grad(
1552        self, gradient_as_bucket_view=False, static_graph=False
1553    ):
1554        """
1555        By simulating a multi-task training, this test is to make sure:
1556        1) DDP does not touch the grad of globally unused parameters.
1557        2) DDP does update the grad of locally unused parameters.
1558        """
1559
1560        class GlobalLocalUnusedParamModule(nn.Module):
1561            def __init__(self) -> None:
1562                super().__init__()
1563                self.t0 = Task()
1564                self.t1 = Task()
1565                self.task_unused = Task()
1566
1567            def task_parameters(self):
1568                return (self.t0.p, self.t1.p, self.task_unused.p)
1569
1570            def forward(self, x, rank):
1571                return self.t0(x) if rank == 0 else self.t1(x)
1572
1573        def run_and_verify_grad(model):
1574            # Run forward
1575            output = model(8, self.rank)
1576
1577            # The grads of all parameters should be None at this point.
1578            t0_p, t1_p, task_unused_p = model.module.task_parameters()
1579            self.assertIsNone(t0_p.grad)
1580            self.assertIsNone(t1_p.grad)
1581            self.assertIsNone(task_unused_p.grad)
1582
1583            # Run backward
1584            output.mean().backward()
1585
1586            # Now locally unused parameter should have grad updated on all ranks.
1587            # However the globally unused parameter should still have None grad.
1588            self.assertIsNotNone(t0_p.grad)
1589            self.assertIsNotNone(t1_p.grad)
1590            self.assertIsNone(task_unused_p.grad)
1591
1592        process_group = self._get_process_group()
1593
1594        # Test on CPU
1595        cpu_model = DistributedDataParallel(
1596            GlobalLocalUnusedParamModule().cpu(),
1597            process_group=process_group,
1598            find_unused_parameters=True,
1599            gradient_as_bucket_view=gradient_as_bucket_view,
1600            static_graph=static_graph,
1601        )
1602        run_and_verify_grad(cpu_model)
1603
1604        # Test on GPU
1605        device_id = gpus_for_rank(self.world_size)[self.rank][0]
1606        gpu_model = DistributedDataParallel(
1607            GlobalLocalUnusedParamModule().to(device_id),
1608            device_ids=[device_id],
1609            process_group=process_group,
1610            find_unused_parameters=True,
1611            gradient_as_bucket_view=gradient_as_bucket_view,
1612            static_graph=static_graph,
1613        )
1614        run_and_verify_grad(gpu_model)
1615
1616    @requires_gloo()
1617    @skip_if_lt_x_gpu(2)
1618    def test_global_local_unused_params_grad(self):
1619        self._test_global_local_unused_params_grad()
1620
1621    @requires_gloo()
1622    @skip_if_lt_x_gpu(2)
1623    def test_global_local_unused_params_grad_with_grad_is_view(self):
1624        self._test_global_local_unused_params_grad(gradient_as_bucket_view=True)
1625
1626    @requires_gloo()
1627    @skip_if_lt_x_gpu(2)
1628    def test_global_local_unused_params_grad_with_static_graph(self):
1629        self._test_global_local_unused_params_grad(static_graph=True)
1630
1631    @requires_gloo()
1632    @skip_if_lt_x_gpu(2)
1633    def test_find_unused_parameters_when_unused_parameters_empty(self):
1634        """
1635        An empty unused_parameters array does not imply find_unused_parameters =
1636        false. This test makes sure that DDP allreduces unused parameters
1637        accordingly where the forward pass in some process uses all parameters.
1638        This unit test creates a module that uses all parameters in rank = 0, and
1639        has unused parameters in other ranks.
1640        """
1641
1642        class FindUnusedParamModule(nn.Module):
1643            def __init__(self) -> None:
1644                super().__init__()
1645                self.t0 = Task()
1646                self.t1 = Task()
1647
1648            def task_parameters(self):
1649                return (self.t0.p, self.t1.p)
1650
1651            def forward(self, x, rank):
1652                return self.t1(self.t0(x)) if rank == 0 else self.t1(x)
1653
1654        def run_and_verify_grad(model):
1655            # Run forward
1656            output = model(8, self.rank)
1657
1658            # The grads of all parameters should be None at this point.
1659            [self.assertIsNone(t_p.grad) for t_p in model.module.task_parameters()]
1660
1661            # Run backward
1662            output.mean().backward()
1663
1664            # Now locally unused parameter should have grad updated on all ranks.
1665            [self.assertIsNotNone(t_p.grad) for t_p in model.module.task_parameters()]
1666
1667        process_group = self._get_process_group()
1668
1669        # Test on CPU
1670        cpu_model = DistributedDataParallel(
1671            FindUnusedParamModule().cpu(),
1672            process_group=process_group,
1673            find_unused_parameters=True,
1674        )
1675        run_and_verify_grad(cpu_model)
1676
1677        # Test on GPU
1678        device_id = gpus_for_rank(self.world_size)[self.rank][0]
1679        gpu_model = DistributedDataParallel(
1680            FindUnusedParamModule().to(device_id),
1681            device_ids=[device_id],
1682            process_group=process_group,
1683            find_unused_parameters=True,
1684        )
1685        run_and_verify_grad(gpu_model)
1686
1687    @requires_gloo()
1688    def test_ignored_output(self):
1689        """
1690        Test that the output of a model can be ignored and that there is no
1691        implicit requirement that `backward` gets called.
1692        """
1693        process_group = self._get_process_group()
1694
1695        class IgnoredOutput(nn.Module):
1696            def __init__(self) -> None:
1697                super().__init__()
1698                self.fc1 = nn.Linear(2, 10, bias=False)
1699                self.fc2 = nn.Linear(10, 4, bias=False)
1700                self.relu = nn.ReLU()
1701
1702            def forward(self, x):
1703                x = self.relu(self.fc1(x))
1704                x = self.relu(self.fc2(x))
1705                return F.softmax(x, dim=1)
1706
1707        model = DistributedDataParallel(
1708            IgnoredOutput().float(),
1709            process_group=process_group,
1710        )
1711
1712        batch_size = 4
1713        criterion = nn.CrossEntropyLoss()
1714        input = torch.rand([batch_size, 2], dtype=torch.float)
1715        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])
1716
1717        # Run a few iterations where we ignore the output.
1718        for _ in range(4):
1719            output = model(input)
1720            del output
1721
1722        # Run a few iterations where we use the output.
1723        for _ in range(4):
1724            output = model(input)
1725            loss = criterion(output, target)
1726            loss.backward()
1727
1728    @requires_gloo()
1729    def test_ignored_output_with_unused_parameters(self):
1730        """
1731        Test that the output of a model can be ignored and that there is no
1732        implicit requirement that `backward` gets called, if not all model
1733        parameters participated in computing the model output.
1734        """
1735        process_group = self._get_process_group()
1736
1737        class IgnoredOutputWithUnusedParameters(nn.Module):
1738            def __init__(self) -> None:
1739                super().__init__()
1740                self.fc1 = nn.Linear(2, 10, bias=False)
1741                self.fc2 = nn.Linear(10, 4, bias=False)
1742                self.fc3 = nn.Linear(4, 4, bias=False)
1743                self.relu = nn.ReLU()
1744
1745            def forward(self, x):
1746                x = self.relu(self.fc1(x))
1747                x = self.relu(self.fc2(x))
1748                return F.softmax(x, dim=1)
1749
1750        model = DistributedDataParallel(
1751            IgnoredOutputWithUnusedParameters().float(),
1752            process_group=process_group,
1753            find_unused_parameters=True,
1754        )
1755
1756        batch_size = 4
1757        criterion = nn.CrossEntropyLoss()
1758        input = torch.rand([batch_size, 2], dtype=torch.float)
1759        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])
1760
1761        # Run a few iterations where we ignore the output.
1762        for _ in range(4):
1763            output = model(input)
1764            del output
1765
1766        # Run a few iterations where we use the output.
1767        for _ in range(4):
1768            output = model(input)
1769            loss = criterion(output, target)
1770            loss.backward()
1771
1772    @requires_gloo()
1773    @skip_if_lt_x_gpu(2)
1774    def test_ignored_sharded_tensor(self):
1775        class MyModule(nn.Module):
1776            def __init__(self, shard_tensor: ShardedTensor) -> None:
1777                super().__init__()
1778                self.fc1 = nn.Linear(2, 10, bias=False)
1779                self.st = nn.Parameter(shard_tensor)
1780                self.relu = nn.ReLU()
1781
1782            def forward(self, x):
1783                x = self.relu(self.fc1(x))
1784                return F.softmax(x, dim=1)
1785
1786        pg = dist.init_process_group(
1787            "gloo",
1788            init_method=f"file://{self.file_name}",
1789            world_size=self.world_size,
1790            rank=self.rank,
1791        )
1792        device = torch.device(f"cuda:{self.rank}")
1793        local_shard_metadata = ShardMetadata(
1794            shard_offsets=[(self.rank % 2) * 5, 0],
1795            shard_sizes=[5, 10],
1796            placement=f"rank:{self.rank}/cuda:{self.rank}",
1797        )
1798        local_shards = [Shard(torch.randn(5, 10, device=device), local_shard_metadata)]
1799        st = init_from_local_shards(local_shards, [10, 10])
1800        m = MyModule(st)
1801        DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
1802            module=m, params_and_buffers_to_ignore={"st"}
1803        )
1804        # test to make DDP constructor will not fail when module includes a ShardedTensor when ignored
1805        DistributedDataParallel(
1806            m,
1807            device_ids=[device] if device.type == "gpu" else None,
1808            process_group=pg,
1809            gradient_as_bucket_view=True,
1810            broadcast_buffers=False,
1811            static_graph=True,
1812        )
1813
1814    def _run_and_verify_sparse_gradients(self, vanilla_model, ddp_model):
1815        mult = 2
1816        batch_size = mult * self.world_size
1817        criterion = nn.CrossEntropyLoss()
1818        input = torch.randint(0, 10, [batch_size, 2])
1819        target = torch.randint(0, 10, [batch_size])
1820
1821        # Run with entire batch against single process version
1822        criterion(vanilla_model(input), target).backward()
1823
1824        # Run with partial batch against multi process version
1825        partial_input = input.split(mult)[self.rank]
1826        partial_target = target.split(mult)[self.rank]
1827        criterion(ddp_model(partial_input), partial_target).backward()
1828
1829        # Check that the gradients are sparse and identical
1830        vanilla_parameter = next(vanilla_model.parameters())
1831        ddp_parameter = next(ddp_model.parameters())
1832        self.assertEqual(
1833            vanilla_parameter.grad.coalesce(), ddp_parameter.grad.coalesce()
1834        )
1835
1836    @requires_gloo()
1837    @skip_if_lt_x_gpu(2)
1838    def test_save_load_checkpoint(self):
1839        dist.init_process_group(
1840            "gloo",
1841            init_method=f"file://{self.file_name}",
1842            world_size=self.world_size,
1843            rank=self.rank,
1844        )
1845
1846        class TestModel(nn.Module):
1847            def __init__(self) -> None:
1848                super().__init__()
1849                self.fc1 = nn.Linear(2, 10, bias=False)
1850                self.fc2 = nn.Linear(10, 4, bias=False)
1851                self.relu = nn.ReLU()
1852
1853            def forward(self, x):
1854                x = self.relu(self.fc1(x))
1855                x = self.relu(self.fc2(x))
1856                return F.softmax(x, dim=1)
1857
1858        def train_loop(model, optimizer, iterations):
1859            for _ in range(iterations):
1860                optimizer.zero_grad()
1861                output = model(input)
1862                loss = criterion(output, target)
1863                loss.backward()
1864                optimizer.step()
1865
1866        device_id = gpus_for_rank(self.world_size)[self.rank][0]
1867
1868        model_withload = TestModel().float().to(device_id)
1869        model_withoutload = TestModel().float().to(device_id)
1870
1871        ddp_withload = DistributedDataParallel(
1872            model_withload,
1873            device_ids=[device_id],
1874        )
1875        ddp_withoutload = DistributedDataParallel(
1876            model_withoutload,
1877            device_ids=[device_id],
1878        )
1879
1880        # ensure that all the three models start with the same set of parameters. By default they are randomized on construction
1881        for p in ddp_withload.parameters():
1882            with torch.no_grad():
1883                p.zero_()
1884        for p in model_withload.parameters():
1885            with torch.no_grad():
1886                p.zero_()
1887        for p in ddp_withoutload.parameters():
1888            with torch.no_grad():
1889                p.zero_()
1890
1891        batch_size = 4
1892        criterion = nn.CrossEntropyLoss()
1893
1894        optimizer_withload = torch.optim.SGD(ddp_withload.parameters(), lr=0.001)
1895        optimizer_non_ddp_withload = torch.optim.SGD(
1896            model_withload.parameters(), lr=0.001
1897        )
1898        optimizer_withoutload = torch.optim.SGD(ddp_withoutload.parameters(), lr=0.001)
1899
1900        input = torch.rand([batch_size, 2], dtype=torch.float).to(device_id)
1901        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(
1902            device_id
1903        )
1904
1905        # run the model for 6 iterations, with a checkpoint in the middle
1906        train_loop(ddp_withload, optimizer_withload, 3)
1907
1908        # zero out parameters of both DDP and non-DDP models and reload them from the DDP state dict
1909        checkpoint_path = tempfile.gettempdir() + "/model.checkpoint"
1910        if self.rank == 0:
1911            torch.save(ddp_withload.state_dict(), checkpoint_path)
1912
1913        dist.barrier()
1914        map_location = {"cuda:%d" % 0: "cuda:%d" % self.rank}
1915        ddp_state_dict = torch.load(checkpoint_path, map_location=map_location)
1916
1917        for model in [ddp_withload, model_withload]:
1918            for p in ddp_withload.parameters():
1919                with torch.no_grad():
1920                    p.zero_()
1921        ddp_withload.load_state_dict(ddp_state_dict)
1922        # the non-DDP model needs to first remove the prefix of "module." from the DDP state dict
1923        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
1924            ddp_state_dict, "module."
1925        )
1926        model_withload.load_state_dict(ddp_state_dict)
1927
1928        train_loop(ddp_withload, optimizer_withload, 3)
1929        train_loop(model_withload, optimizer_non_ddp_withload, 3)
1930
1931        # re-run the model with the same inputs for 6 iterations with no checkpoint
1932        train_loop(ddp_withoutload, optimizer_withoutload, 6)
1933
1934        for p_withload, p_withoutload, p_non_ddp_withload in zip(
1935            ddp_withload.parameters(),
1936            ddp_withoutload.parameters(),
1937            model_withload.parameters(),
1938        ):
1939            self.assertEqual(p_withload, p_withoutload)
1940            self.assertEqual(p_non_ddp_withload, p_withoutload)
1941
1942    def _test_sparse_gradients(self, gradient_as_bucket_view=False):
1943        process_group = self._get_process_group()
1944
1945        # Ensure initialized weights and inputs are identical across processes
1946        torch.manual_seed(1337)
1947
1948        vanilla_model = SparseGradientModule()
1949        ddp_model = DistributedDataParallel(
1950            copy.deepcopy(vanilla_model),
1951            process_group=process_group,
1952            gradient_as_bucket_view=gradient_as_bucket_view,
1953        )
1954
1955        self._run_and_verify_sparse_gradients(vanilla_model, ddp_model)
1956
1957    @requires_gloo()
1958    def test_sparse_gradients(self):
1959        self._test_sparse_gradients()
1960
1961    @requires_gloo()
1962    def test_sparse_gradients_grad_is_view(self):
1963        self._test_sparse_gradients(gradient_as_bucket_view=True)
1964
1965    @requires_gloo()
1966    def test_ddp_comm_hook_future_passing_cpu(self):
1967        """
1968        This unit test verifies whether the Future object is passed properly.
1969        The callback function creates a Future object and sets a value to it.
1970        """
1971        store = c10d.FileStore(self.file_name, self.world_size)
1972        process_group = self._get_process_group()
1973
1974        # Test on CPU
1975        cpu_model = DistributedDataParallel(
1976            ModuleForDdpCommHook().cpu(), process_group=process_group
1977        )
1978
1979        # Register DDP Communication Hook
1980        cpu_model.register_comm_hook(None, self._simple_hook)
1981
1982        # check whether the grads are equal to what then callback returns.
1983        # without the comm_hook, result would be 0.25 * torch.ones(2, 2).
1984        self._run_and_verify_hook(cpu_model, 8, 2 * torch.ones(2, 2))
1985
1986    def _gpu_model_with_ddp_comm_hook(
1987        self, process_group, hook=None, gradient_as_bucket_view=False, state=None
1988    ):
1989        device_id = gpus_for_rank(self.world_size)[self.rank][0]
1990        gpu_model = DistributedDataParallel(
1991            ModuleForDdpCommHook().to(device_id),
1992            device_ids=[device_id],
1993            process_group=process_group,
1994            gradient_as_bucket_view=gradient_as_bucket_view,
1995        )
1996
1997        # Register a DDP communication hook if any.
1998        if hook is not None:
1999            gpu_model.register_comm_hook(state, hook)
2000
2001        return gpu_model
2002
2003    @requires_gloo()
2004    @skip_if_lt_x_gpu(2)
2005    def test_ddp_comm_hook_future_passing_gpu_gloo(self):
2006        """
2007        This unit test verifies whether the Future object is passed properly using gloo backend.
2008        The hook callback function creates a Future object and sets a value to it.
2009        """
2010        process_group = self._get_process_group()
2011
2012        # Get GPU model with simple_hook registered.
2013        gpu_model = self._gpu_model_with_ddp_comm_hook(process_group, self._simple_hook)
2014
2015        # check whether the grads are equal to what simple_hook's then callback returns.
2016        # without the comm_hook, result would be 0.25 * torch.ones(2, 2).
2017        self._run_and_verify_hook(gpu_model, 8, 2 * torch.ones(2, 2))
2018
2019    @requires_gloo()
2020    def test_ddp_invalid_comm_hook_init(self):
2021        """
2022        This unit test makes sure that register_comm_hook properly checks the format
2023        of hook defined by user. The Python hook must be callable. This test also
2024        checks whether bucket annotation checked properly if defined.
2025        """
2026        process_group = self._get_process_group()
2027
2028        model = DistributedDataParallel(
2029            ModuleForDdpCommHook(), process_group=process_group
2030        )
2031
2032        with self.assertRaisesRegex(TypeError, "Communication hook must be callable."):
2033            model.register_comm_hook(state=None, hook=1)
2034
2035        with self.assertRaisesRegex(
2036            ValueError, "bucket annotation should be dist.GradBucket."
2037        ):
2038
2039            def comm_hook(
2040                state: object, bucket: int
2041            ) -> torch.futures.Future[torch.Tensor]:
2042                return torch.futures.Future()
2043
2044            model.register_comm_hook(state=None, hook=comm_hook)
2045
2046    @requires_gloo()
2047    def test_ddp_invalid_comm_hook_return_type(self):
2048        """
2049        This test checks whether return annotation checked properly if defined. It also
2050        checks whether an internal error is thrown if return type is incorrect and user
2051        hasn't specified any return type annotation.
2052        """
2053        process_group = self._get_process_group()
2054
2055        model = DistributedDataParallel(
2056            ModuleForDdpCommHook(), process_group=process_group
2057        )
2058
2059        expected_err = (
2060            "Communication hook: return annotation should be torch.futures.Future"
2061        )
2062        with self.assertRaisesRegex(
2063            ValueError,
2064            expected_err,
2065        ):
2066
2067            def comm_hook(state: object, bucket: dist.GradBucket) -> int:
2068                return torch.futures.Future()
2069
2070            model.register_comm_hook(state=None, hook=comm_hook)
2071
2072        verify_ddp_error_logged(model, expected_err)
2073
2074        with self.assertRaisesRegex(
2075            RuntimeError,
2076            "callback must return a torch.futures.Future object, but got",
2077        ):
2078
2079            def comm_hook(state: object, bucket: dist.GradBucket):
2080                return 1
2081
2082            model.register_comm_hook(state=None, hook=comm_hook)
2083
2084            # Run forward
2085            output = model(8, self.rank)
2086
2087            # Run backward
2088            output.mean().backward()
2089
2090    @requires_gloo()
2091    def test_ddp_comm_hook_register_just_once(self):
2092        """
2093        DDP communication hook can only be registered once. This test validates whether
2094        the error is thrown properly when register_comm_hook is called more than once.
2095        """
2096        process_group = self._get_process_group()
2097
2098        model = DistributedDataParallel(
2099            ModuleForDdpCommHook(), process_group=process_group
2100        )
2101
2102        def dummy_hook(state, bucket):
2103            fut = torch.futures.Future()
2104            fut.set_result([bucket.buffer()])
2105            return fut
2106
2107        model.register_comm_hook(None, dummy_hook)
2108
2109        with self.assertRaisesRegex(
2110            RuntimeError,
2111            "register_comm_hook or register_builtin_comm_hook can only be called once.",
2112        ):
2113            model.register_comm_hook(None, dummy_hook)
2114
2115    @requires_gloo()
2116    def test_ddp_comm_hook_sparse_gradients(self):
2117        """
2118        Runs "test_sparse_gradients" unit test with DDP communication hook. We define a
2119        simple hook that does allreduce and works with gloo backend for this test.
2120        """
2121        process_group = self._get_process_group()
2122
2123        # Ensure initialized weights and inputs are identical across processes
2124        torch.manual_seed(1337)
2125
2126        vanilla_model = SparseGradientModule()
2127        ddp_model = DistributedDataParallel(
2128            copy.deepcopy(vanilla_model),
2129            process_group=process_group,
2130        )
2131
2132        def allreduce_hook_gloo(
2133            state: object, bucket: dist.GradBucket
2134        ) -> torch.futures.Future[torch.Tensor]:
2135            def div_by_world_size(fut):
2136                # Divide the result by 2 * world_size.
2137                return fut.wait()[0] / self.world_size
2138
2139            # Prepare allreduced grad bucket tensors by running an async work.
2140            fut = process_group.allreduce([bucket.buffer()]).get_future()
2141            return fut.then(div_by_world_size)
2142
2143        ddp_model.register_comm_hook(None, allreduce_hook_gloo)
2144
2145        self._run_and_verify_sparse_gradients(vanilla_model, ddp_model)
2146
2147
2148class ReducerModule(nn.Module):
2149    def __init__(self) -> None:
2150        super().__init__()
2151        self.fc1 = nn.Linear(2, 10, bias=False)
2152        self.fc2 = nn.Linear(10, 4, bias=False)
2153        self.fc3 = nn.Linear(4, 4, bias=False)
2154        self.relu = nn.ReLU()
2155
2156    def forward(self, x, use_fc3=True):
2157        x = self.relu(self.fc1(x)).float()
2158        x = self.relu(self.fc2(x)).float()
2159        if use_fc3:
2160            x = self.fc3(x).float()
2161        return F.softmax(x, dim=1)
2162
2163
2164class ReducerTest(TestCase):
2165    def setUp(self):
2166        self.file = tempfile.NamedTemporaryFile(delete=False)
2167        world_size = 1
2168        self.store = c10d.FileStore(self.file.name, world_size)
2169        c10d.init_process_group(
2170            backend="gloo", store=self.store, rank=0, world_size=world_size
2171        )
2172        self.process_group = c10d.distributed_c10d._get_default_group()
2173
2174    def tearDown(self):
2175        c10d.destroy_process_group()
2176        try:
2177            os.remove(self.file.name)
2178        except OSError as e:
2179            print(str(e))
2180
2181    @requires_gloo()
2182    def test_single_dtype_single_bucket(self):
2183        model = ReducerModule()
2184        parameters = list(model.parameters())
2185        buckets = [list(range(len(parameters)))]
2186        dist.Reducer(
2187            parameters, buckets, [dist._DEFAULT_FIRST_BUCKET_BYTES], self.process_group
2188        )
2189
2190    def _create_mixed_precision_model(self):
2191        model = ReducerModule()
2192        model.float()
2193        model.fc1.double()
2194        return model
2195
2196    @requires_gloo()
2197    def test_multi_dtype_single_bucket(self):
2198        model = self._create_mixed_precision_model()
2199
2200        # Raise if there are multiple types per bucket.
2201        # In this case we create one bucket for all parameters.
2202        with self.assertRaises(RuntimeError):
2203            parameters = list(model.parameters())
2204            buckets = [list(range(len(parameters)))]
2205            dist.Reducer(
2206                parameters,
2207                buckets,
2208                [dist._DEFAULT_FIRST_BUCKET_BYTES],
2209                self.process_group,
2210            )
2211
2212    @requires_gloo()
2213    def test_multi_dtype_multi_bucket(self):
2214        model = self._create_mixed_precision_model()
2215        parameters = list(model.parameters())
2216        group_by_dtype = groupby(
2217            range(len(parameters)), key=lambda i: parameters[i].dtype
2218        )
2219        buckets = [list(indices) for _, indices in group_by_dtype]
2220        dist.Reducer(
2221            parameters,
2222            buckets,
2223            [dist._DEFAULT_FIRST_BUCKET_BYTES for _ in buckets],
2224            self.process_group,
2225        )
2226
2227    def _create_reducer_for_models(self, models, find_unused_parameters=False):
2228        self.assertEqual(len(models), 1)
2229        parameters = list(models[0].parameters())
2230        group_by_dtype = groupby(
2231            range(len(parameters)), key=lambda i: parameters[i].dtype
2232        )
2233        buckets = [list(indices) for _, indices in group_by_dtype]
2234        return dist.Reducer(
2235            parameters,
2236            buckets,
2237            [dist._DEFAULT_FIRST_BUCKET_BYTES for _ in range(len(buckets))],
2238            self.process_group,
2239            find_unused_parameters=find_unused_parameters,
2240        )
2241
2242    @requires_gloo()
2243    def test_forward_backward(self):
2244        batch_size = 10
2245        model = self._create_mixed_precision_model()
2246        reducer = self._create_reducer_for_models([model])
2247        reducer.prepare_for_forward()
2248        loss = nn.CrossEntropyLoss()
2249        input = torch.rand([batch_size, 2], dtype=torch.double)
2250        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])
2251        output = loss(model(input), target)
2252        reducer.prepare_for_backward(output)
2253        output.backward()
2254
2255    @requires_gloo()
2256    def test_forward_backward_unused_parameters(self):
2257        batch_size = 10
2258        model = self._create_mixed_precision_model()
2259        reducer = self._create_reducer_for_models([model], find_unused_parameters=True)
2260        reducer.prepare_for_forward()
2261        loss = nn.CrossEntropyLoss()
2262        input = torch.rand([batch_size, 2], dtype=torch.double)
2263        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])
2264        output = loss(model(input, use_fc3=False), target)
2265
2266        # Check that the grad of fc3 is not set.
2267        self.assertEqual(None, model.fc3.weight.grad)
2268
2269        # Compute and accumulate gradients.
2270        reducer.prepare_for_backward(output)
2271        output.backward()
2272
2273        # The reducer will have marked the grad of fc3 as ready, because
2274        # it doesn't show up in the autograd graph of `output`. Since fc3.weight
2275        # is considered being globally unused, it will be kept untouched as None.
2276        self.assertEqual(None, model.fc3.weight.grad)
2277
2278    @requires_gloo()
2279    def test_forward_backward_optimizer(self):
2280        batch_size = 10
2281        model = self._create_mixed_precision_model()
2282        reducer = self._create_reducer_for_models([model], find_unused_parameters=True)
2283        reducer.prepare_for_forward()
2284        loss = nn.CrossEntropyLoss()
2285        optimizer = torch.optim.Adam(model.parameters())
2286        for i in range(3):
2287            input = torch.rand([batch_size, 2], dtype=torch.double)
2288            target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])
2289
2290            # The `zero_grad` function calls `detach_` and `zero_` on the grad
2291            # tensors of model parameters. If we tried to set the grad tensors
2292            # to a view of the reducer's bucket tensors, this would blow up.
2293            optimizer.zero_grad()
2294
2295            # Unused parameter only in the first iteration.
2296            output = loss(model(input, use_fc3=(i > 0)), target)
2297            reducer.prepare_for_backward(output)
2298            output.backward()
2299            optimizer.step()
2300
2301
2302class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
2303    @property
2304    def device(self):
2305        return "cpu"
2306
2307    def setUp(self):
2308        super().setUp()
2309        self._spawn_processes()
2310
2311    def tearDown(self):
2312        super().tearDown()
2313        try:
2314            os.remove(self.file_name)
2315        except OSError:
2316            pass
2317
2318    def _test_broadcast_coalesced(self, process_group, device, root_rank):
2319        half = torch.float16
2320
2321        # No support for float16 for CPU tensors
2322        if device == torch.device("cpu"):
2323            half = torch.float32
2324
2325        target = torch.arange(60, dtype=half, device=device).chunk(5)
2326        target += torch.arange(60, dtype=torch.float32, device=device).chunk(5)
2327        target += torch.arange(60, dtype=half, device=device).chunk(5)
2328        target += torch.arange(60, dtype=torch.float64, device=device).chunk(5)
2329        target += torch.arange(60, dtype=half, device=device).chunk(5)
2330        target += torch.arange(60, dtype=torch.float32, device=device).chunk(5)
2331
2332        # The tensors to pass to broadcast are identical to the target
2333        # only on the process that is the root of the broadcast.
2334        if self.rank == root_rank:
2335            tensors = [tensor.clone() for tensor in target]
2336        else:
2337            tensors = [torch.zeros_like(tensor) for tensor in target]
2338
2339        if self.rank != root_rank:
2340            self.assertNotEqual(tensors, target)
2341
2342        c10d._broadcast_coalesced(
2343            process_group, tensors, buffer_size=256, src=root_rank
2344        )
2345
2346        if self.rank != root_rank:
2347            self.assertEqual(tensors, target)
2348
2349    @requires_gloo()
2350    @skip_if_lt_x_gpu(2)
2351    def test_broadcast_coalesced_gloo_cuda(self):
2352        store = c10d.FileStore(self.file_name, self.world_size)
2353        c10d.init_process_group(
2354            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
2355        )
2356        process_group = c10d.distributed_c10d._get_default_group()
2357        device = torch.device("cuda:%d" % self.rank)
2358        backend = process_group._get_backend(device)
2359        backend.create_device(interface=LOOPBACK)
2360        ranks = list(range(self.world_size))
2361        for root_rank in ranks:
2362            self._test_broadcast_coalesced(process_group, device, root_rank)
2363
2364    @requires_gloo()
2365    def test_broadcast_coalesced_gloo_cpu(self):
2366        store = c10d.FileStore(self.file_name, self.world_size)
2367        c10d.init_process_group(
2368            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
2369        )
2370        process_group = c10d.distributed_c10d._get_default_group()
2371        device = torch.device("cpu")
2372        backend = process_group._get_backend(device)
2373        backend.create_device(interface=LOOPBACK)
2374        ranks = list(range(self.world_size))
2375        for root_rank in ranks:
2376            self._test_broadcast_coalesced(process_group, device, root_rank)
2377
2378    @requires_gloo()
2379    @skip_if_lt_x_gpu(2)
2380    def test_sequence_num_set_default_pg_gloo(self):
2381        self._test_sequence_num_set_default_pg(backend="gloo")
2382
2383    @requires_gloo()
2384    @skip_if_lt_x_gpu(2)
2385    def test_sequence_num_set_gloo_new_group(self):
2386        self._test_sequence_num_set_new_group(backend="gloo")
2387
2388    @skip_if_lt_x_gpu(2)
2389    @requires_gloo()
2390    def test_sequence_num_incremented_gloo_default(self):
2391        self._test_sequence_num_incremented_default_group("gloo")
2392
2393    @skip_if_lt_x_gpu(4)
2394    @requires_gloo()
2395    def test_sequence_num_incremented_gloo_subgroup(self):
2396        if self.world_size < 4:
2397            return skip_but_pass_in_sandcastle("Test requires world_size of at least 4")
2398        self._test_sequence_num_incremented_subgroup("gloo")
2399
2400    @skip_if_lt_x_gpu(2)
2401    @requires_gloo()
2402    def test_gloo_warn_not_in_group(self):
2403        self._test_warn_not_in_group(backend="gloo")
2404
2405    @skip_if_lt_x_gpu(2)
2406    @requires_gloo()
2407    def test_gloo_rank_membership(self):
2408        self._test_rank_membership(backend="gloo")
2409
2410    @skip_if_lt_x_gpu(2)
2411    @requires_gloo()
2412    def test_tensor_dtype_mismatch(self):
2413        self._test_tensor_dtype_mismatch(backend="gloo")
2414
2415    @skip_if_lt_x_gpu(2)
2416    @requires_gloo()
2417    def test_tensor_dtype_complex(self):
2418        self._test_tensor_dtype_complex(backend="gloo")
2419
2420    @requires_gloo()
2421    def test_bool_tensors(self):
2422        self._test_bool_tensors(backend="gloo")
2423
2424
2425class GlooProcessGroupWithDispatchedCollectivesTests(
2426    test_c10d_common.ProcessGroupWithDispatchedCollectivesTests
2427):
2428    @requires_gloo()
2429    def test_collectives(self):
2430        self._test_collectives(backend="gloo")
2431
2432    @requires_gloo()
2433    def test_allreduce_coalesced(self):
2434        self._test_allreduce_coalesced(backend="gloo")
2435
2436    @requires_gloo()
2437    def test_all_to_all_single(self):
2438        self._test_all_to_all_single(backend="gloo")
2439
2440    @requires_gloo()
2441    def test_allgather_coalesced(self):
2442        store = dist.FileStore(self.file_name, self.world_size)
2443        dist.init_process_group(
2444            "gloo",
2445            world_size=self.world_size,
2446            rank=self.rank,
2447            store=store,
2448        )
2449        input_tensor = torch.ones(10, 10, dtype=torch.float32)
2450        output_tensor_list = [torch.zeros_like(input_tensor)]
2451        dist.all_gather_coalesced([output_tensor_list], [input_tensor])
2452        self.assertEqual(output_tensor_list, [input_tensor])
2453
2454    @requires_gloo()
2455    def test_monitored_barrier(self):
2456        store = dist.FileStore(self.file_name, self.world_size)
2457        dist.init_process_group(
2458            "gloo",
2459            world_size=self.world_size,
2460            rank=self.rank,
2461            store=store,
2462        )
2463        dist.monitored_barrier()
2464
2465
2466class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase):
2467    def setUp(self):
2468        super().setUp()
2469        self._spawn_processes()
2470
2471    def tearDown(self):
2472        super().tearDown()
2473        try:
2474            os.remove(self.file_name)
2475        except OSError:
2476            pass
2477
2478    @property
2479    def device(self):
2480        return torch.device("cpu")
2481
2482    @requires_gloo()
2483    def test_new_group_local_sync(self):
2484        self._test_new_group_local_sync(backend="gloo")
2485
2486    @requires_gloo()
2487    def test_new_group_local_sync_sanity_check(self):
2488        self._test_new_group_local_sync_sanity_check(backend="gloo")
2489
2490    @requires_gloo()
2491    def test_new_group_local_sync_duplicate_pg(self):
2492        self._test_new_group_local_sync_duplicate_pg(backend="gloo")
2493
2494
2495if __name__ == "__main__":
2496    assert (
2497        not torch.cuda._initialized
2498    ), "test_distributed must not have initialized CUDA context on main process"
2499
2500    run_tests()
2501