1# Owner(s): ["oncall: distributed"]
2
3import copy
4import os
5import pickle
6import sys
7import tempfile
8import threading
9import time
10from contextlib import nullcontext
11from dataclasses import dataclass
12from datetime import timedelta
13from itertools import product
14from sys import platform
15from typing import Dict, Optional
16
17import torch
18import torch.distributed as dist
19
20
21if not dist.is_available():
22    print("distributed package not available, skipping tests", file=sys.stderr)
23    sys.exit(0)
24
25import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
26import torch.distributed.distributed_c10d as c10d
27import torch.nn.functional as F
28import torch.testing._internal.common_utils as common
29from torch import nn
30from torch.nn.parallel import DistributedDataParallel
31from torch.testing._internal.common_distributed import (
32    MultiProcessTestCase,
33    skip_if_lt_x_gpu,
34)
35from torch.testing._internal.common_utils import (
36    instantiate_parametrized_tests,
37    load_tests,
38    parametrize,
39    retry_on_connect_failures,
40    run_tests,
41    TEST_WITH_DEV_DBG_ASAN,
42    TestCase,
43)
44from torch.utils.checkpoint import checkpoint
45
46
47if TEST_WITH_DEV_DBG_ASAN:
48    print("Multiprocessing spawn is not compatible with dev/dbg asan", file=sys.stderr)
49    sys.exit(0)
50
51# load_tests from common_utils is used to automatically filter tests for
52# sharding on sandcastle. This line silences flake warnings
53load_tests = load_tests
54
55if platform == "darwin":
56    LOOPBACK = "lo0"
57else:
58    LOOPBACK = "lo"
59
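# TF32 is disabled here, presumably so that the element-wise comparisons between
# locally trained models and their DDP counterparts below stay within the tight
# rtol/atol tolerances used in these tests (TF32 matmuls trade precision for speed).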
60torch.backends.cuda.matmul.allow_tf32 = False
61
62
63def gpus_for_rank(world_size):
64    """Multi-GPU tests are designed to simulate multiple nodes with multiple
65    GPUs on each node. The NCCL backend requires an equal number of GPUs in
66    each process. On a single node, all visible GPUs are evenly divided into
67    subsets, and each process uses only its own subset.
68    """
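    # Example: with 4 visible GPUs and world_size=2, this returns
    # [[0, 1], [2, 3]], i.e. rank 0 uses devices 0 and 1 and rank 1 uses 2 and 3.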
69    visible_devices = list(range(torch.cuda.device_count()))
70    gpus_per_process = torch.cuda.device_count() // world_size
71    gpus_for_rank = []
72    for rank in range(world_size):
73        gpus_for_rank.append(
74            visible_devices[rank * gpus_per_process : (rank + 1) * gpus_per_process]
75        )
76    return gpus_for_rank
77
78
79class AbstractTimeoutTest:
80    def _test_store_timeout(self, backend, init_method, c2p):
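        # `c2p` (presumably short for "child to parent") is a plain list used as a
        # results channel: this worker thread appends either the measured wait time
        # (a float) or the RuntimeError it hit, and the spawning test inspects the
        # list after join().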
81        try:
82            dist.init_process_group(
83                backend=backend,
84                init_method=init_method,
85                world_size=1,
86                rank=0,
87                timeout=timedelta(seconds=1),
88            )
89            default_store = c10d._get_default_store()
90            tik = time.time()
91            with self.assertRaisesRegex(RuntimeError, "(?i)timeout"):
92                default_store.get("nonexistent key")
93            tok = time.time()
94            dist.destroy_process_group()
95            c2p.append(float(tok - tik))
96        except RuntimeError as e:
97            # catch "Address already in use" error and report it to the main
98            # thread
99            c2p.append(e)
100
101    def _init_methods(self):
102        f = tempfile.NamedTemporaryFile(delete=False)
103        if sys.platform == "win32":
104            yield "file:///{}".format(f.name.replace("\\", "/"))
105            f.close()
106        else:
107            yield f"file://{f.name}"
108            f.close()
109            yield "tcp://127.0.0.1:%d" % common.find_free_port()
110
111    def _test_default_store_timeout(self, backend):
112        for init_method in self._init_methods():
113            c2p = []
114            t = threading.Thread(
115                target=self._test_store_timeout, args=(backend, init_method, c2p)
116            )
117            t.daemon = True
118            t.start()
119            t.join(5)
120
121            self.assertEqual(1, len(c2p))
122            if isinstance(c2p[0], float):
123                # the wait should be about 1s; allow up to 3s to rule out false alarms
124                self.assertGreater(3, c2p[0])
125            elif isinstance(c2p[0], RuntimeError):
126                # let @retry_on_connect_failures handle the error
127                raise c2p[0]
128            else:
129                raise RuntimeError(f"Unexpected type {type(c2p[0])}")
130
131
132class TimeoutTest(TestCase):
133    @retry_on_connect_failures
134    def test_store_based_barrier(self):
135        f = tempfile.NamedTemporaryFile(delete=False)
136        port = common.find_free_port()
137
138        def thread_work(timeout, init_type, world_size, rank, error_list):
139            # we need to create a separate store just for the store barrier test
140            if init_type == "file":
141                barrier_store = dist.FileStore(f.name)
142            elif init_type == "tcp":
143                barrier_store = dist.TCPStore(
144                    "localhost",
145                    port,
146                    world_size,
147                    is_master=rank == 0,
148                    wait_for_workers=False,
149                )
150            elif init_type == "hash":
151                barrier_store = dist.HashStore()
152            try:
153                # one missing worker will cause it to time out
154                if rank != world_size - 1:
155                    c10d._store_based_barrier(
156                        rank=rank,
157                        store=barrier_store,
158                        group_name="_",
159                        rendezvous_count=world_size,
160                        timeout=timeout,
161                        logging_interval=timeout / 2,
162                    )
163            except torch.distributed.DistStoreError as e:
164                self.assertTrue(isinstance(e, torch.distributed.DistError))
165                error_list.append(e)
166
167        world_size = 4
168        error_list = []
169        threads = []
170        for init_type in ["file", "tcp", "hash"]:
171            for rank in range(world_size):
172                t = threading.Thread(
173                    target=thread_work,
174                    args=(
175                        timedelta(seconds=3),
176                        init_type,
177                        world_size,
178                        rank,
179                        error_list,
180                    ),
181                )
182                threads.append(t)
183                t.start()
184
185            for i, thread in enumerate(threads):
186                thread.join()
187
188            # we expect the world_size-1 threads to have failed
189            self.assertEqual(len(error_list), world_size - 1)
190            for error in error_list:
191                self.assertTrue(
192                    "Timed out initializing process group in store based barrier"
193                    in error.args[0]
194                )
195            error_list = []
196            threads = []
197
198
199class Net(nn.Module):
200    def __init__(self) -> None:
201        super().__init__()
202        self.fc1 = nn.Linear(2, 10, bias=False)
203        self.fc2 = nn.Linear(10, 50, bias=False)
204        self.fc3 = nn.Linear(50, 4, bias=False)
205        self.relu = nn.ReLU()
206
207    def forward(self, x):
208        x = self.relu(self.fc1(x))
209        x = self.relu(self.fc2(x))
210        x = self.fc3(x)
211        return F.softmax(x, dim=1)
212
213
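# DoubleGpuNet and QuadraGpuNet intentionally spread their layers across two and
# four devices respectively; they back the multi-device DDP tests, and both move
# their final output back to the first device so the loss can be computed on a
# single device.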
214class DoubleGpuNet(nn.Module):
215    def __init__(self, gpus):
216        super().__init__()
217        self.fc1 = nn.Linear(2, 10, bias=False).to(gpus[0])
218        self.fc2 = nn.Linear(10, 50, bias=False).to(gpus[1])
219        self.fc3 = nn.Linear(50, 4, bias=False).to(gpus[1])
220        self.relu = nn.ReLU()
221        self.no_grad_param = nn.Parameter(
222            torch.tensor([2, 2]).long(), requires_grad=False
223        ).to(gpus[0])
224
225    def forward(self, x):
226        dev0 = self.fc1.weight.device
227        dev1 = self.fc2.weight.device
228        x = self.relu(self.fc1(x.to(dev0)))
229        x = self.relu(self.fc2(x.to(dev1)))
230        x = self.fc3(x)
231        return F.softmax(x, dim=1).to(dev0)
232
233
234class QuadraGpuNet(nn.Module):
235    def __init__(self, gpus):
236        super().__init__()
237        self.fc1 = nn.Linear(2, 10, bias=False).to(gpus[0])
238        self.fc2 = nn.Linear(10, 50, bias=False).to(gpus[1])
239        self.fc3 = nn.Linear(50, 4, bias=False).to(gpus[2])
240        self.fc4 = nn.Linear(4, 4, bias=False).to(gpus[3])
241        self.relu = nn.ReLU()
242        self.no_grad_param = nn.Parameter(
243            torch.tensor([2, 2]).long(), requires_grad=False
244        ).to(gpus[0])
245
246    def forward(self, x):
247        dev0 = self.fc1.weight.device
248        dev1 = self.fc2.weight.device
249        dev2 = self.fc3.weight.device
250        dev3 = self.fc4.weight.device
251        x = self.relu(self.fc1(x.to(dev0)))
252        x = self.relu(self.fc2(x.to(dev1)))
253        x = self.relu(self.fc3(x.to(dev2)))
254        x = self.fc4(x.to(dev3))
255        return F.softmax(x, dim=1).to(dev0)
256
257
258class ConvNet(nn.Module):
259    def __init__(self, gpus, layouts, dtypes):
260        super().__init__()
261        self.dtypes = dtypes
262        if isinstance(gpus, list):
263            self.layer_gpus = gpus
264        else:
265            gpus = [gpus] * 4
266        self.conv0 = torch.nn.Conv2d(8, 16, (2, 2)).to(
267            device=gpus[0], memory_format=layouts[0], dtype=dtypes[0]
268        )
269        self.conv1 = torch.nn.Conv2d(16, 32, (2, 2)).to(
270            device=gpus[1], memory_format=layouts[1], dtype=dtypes[1]
271        )
272        self.conv2 = torch.nn.Conv2d(32, 16, (2, 2)).to(
273            device=gpus[2], memory_format=layouts[2], dtype=dtypes[2]
274        )
275        self.conv3 = torch.nn.Conv2d(16, 8, (2, 2)).to(
276            device=gpus[3], memory_format=layouts[3], dtype=dtypes[3]
277        )
278
279    def forward(self, x):
280        x = x.to(self.dtypes[0])
281        # We could write
282        # x = self.conv0(x).to(device=self.conv1.weight.device, dtype=self.dtypes[1])
283        # etc., but we deliberately avoid reading the weights' devices directly, because part of this
284        # test's purpose is to verify that the weights end up where expected if the model gets replicated.
285        gpus = self.layer_gpus if hasattr(self, "layer_gpus") else [x.device] * 4
286        x = self.conv0(x).to(device=gpus[1], dtype=self.dtypes[1])
287        x = self.conv1(x).to(device=gpus[2], dtype=self.dtypes[2])
288        x = self.conv2(x).to(device=gpus[3], dtype=self.dtypes[3])
289        return self.conv3(x)
290
291
292class Task(nn.Module):
293    def __init__(self) -> None:
294        super().__init__()
295        self.p = nn.Parameter(torch.ones(2, 2))
296
297    def forward(self, x):
298        return self.p + x
299
300
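# Task and ModuleForDdpCommHook form a deliberately tiny model (a single 2x2
# parameter) used by the DDP communication hook tests, where the gradient produced
# by a hook is compared against a known expected value.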
301class ModuleForDdpCommHook(nn.Module):
302    def __init__(self) -> None:
303        super().__init__()
304        self.t0 = Task()
305
306    def forward(self, x, rank):
307        return self.t0(x + rank)
308
309
310class SparseGradientModule(nn.Module):
311    def __init__(self) -> None:
312        super().__init__()
313        self.embedding = nn.EmbeddingBag(10, 10, sparse=True)
314
315    def forward(self, x):
316        return F.softmax(self.embedding(x), dim=1)
317
318
319class CommonDistributedDataParallelTest:
320    def tearDown(self):
321        # The DistributedDataParallel tests don't seem to call the FileStore destructor.
322        # TODO: investigate; this test is known to have issues.
323        # Use this workaround to remove the file created for that test.
324        try:
325            os.remove(self.file_name)
326        except OSError:
327            pass
328
329    @property
330    def world_size(self):
331        return 2
332
333    def _prepare_single_device_module(
334        self,
335        process_group,
336        devices,
337        device_ids,
338        global_batch_size,
339        gradient_as_bucket_view=False,
340    ):
341        model = Net()
342        device = devices[0] if devices else torch.device("cuda:%d" % self.rank)
343        ddp_model = DistributedDataParallel(
344            copy.deepcopy(model).to(device),
345            device_ids=device_ids,
346            process_group=process_group,
347            bucket_cap_mb=0.001,
348            gradient_as_bucket_view=gradient_as_bucket_view,
349        )
350
351        model.to(device)
352
353        input = torch.randn(global_batch_size, 2).to(device)
354        target = torch.randn(global_batch_size, 4).to(device)
355
356        return model, ddp_model, input, target
357
358    def _prepare_multi_device_module(
359        self,
360        process_group,
361        devices,
362        device_ids,
363        global_batch_size,
364        gradient_as_bucket_view=False,
365    ):
366        self.assertTrue(
367            len(devices) == 2 or len(devices) == 4,
368            f"unexpected devices for ddp tests {devices}",
369        )
370        if len(devices) == 2:
371            model = DoubleGpuNet(devices)
372        elif len(devices) == 4:
373            model = QuadraGpuNet(devices)
374
375        ddp_model = DistributedDataParallel(
376            copy.deepcopy(model),
377            device_ids=device_ids,
378            process_group=process_group,
379            bucket_cap_mb=0.001,
380            gradient_as_bucket_view=gradient_as_bucket_view,
381        )
382
383        input = torch.randn(global_batch_size, 2).cuda(devices[0])
384        target = torch.randn(global_batch_size, 4)
385
386        return model, ddp_model, input, target
387
388    def _get_store(self):
389        return dist.FileStore(self.file_name, self.world_size)
390
391    def _get_process_group(self):
392        raise NotImplementedError("To be implemented by child class")
393
394    def _train_model(
395        self, model, input_var, target, loss, run_checkpoint=False, use_reentrant=True
396    ):
397        model.train()
398        if run_checkpoint:
399            output = checkpoint(model, input_var, use_reentrant=use_reentrant)
400        else:
401            output = model(input_var)
402        l = loss(output, target)
403        l.backward()
404
405    def _test_ddp_checkpointing(
406        self,
407        input_model,
408        process_group,
409        use_bucket_view,
410        find_unused_parameters=False,
411        static_graph=False,
412        run_checkpoint=False,
413        use_reentrant=True,
414        allow_none_grads=False,
415    ):
416        # to reproduce the same training results
417        torch.cuda.set_device(self.rank)
418        torch.manual_seed(31415)
419        model = copy.deepcopy(input_model).cuda()
420        ddp_model = copy.deepcopy(input_model).cuda()
421        ddp_model = nn.parallel.DistributedDataParallel(
422            ddp_model,
423            bucket_cap_mb=1,
424            gradient_as_bucket_view=use_bucket_view,
425            device_ids=[self.rank],
426            process_group=process_group,
427            find_unused_parameters=find_unused_parameters,
428            static_graph=static_graph,
429        )
430        self.assertEqual(
431            ddp_model._get_ddp_logging_data().get("static_graph", 0), static_graph
432        )
433        input, ddp_input, target, ddp_target = self._prepare_dummy_data()
434        loss = nn.MSELoss()
435        n_iters = 5
436        for i in range(n_iters):
437            model.zero_grad(set_to_none=False)
438            ddp_model.zero_grad(set_to_none=False)
439            self._train_model(
440                model,
441                input,
442                target,
443                loss,
444                run_checkpoint=run_checkpoint,
445                use_reentrant=use_reentrant,
446            )
447            self._train_model(
448                ddp_model,
449                ddp_input,
450                ddp_target,
451                loss,
452                run_checkpoint=run_checkpoint,
453                use_reentrant=use_reentrant,
454            )
455            for i, j in zip(model.parameters(), ddp_model.parameters()):
456                if not allow_none_grads:
457                    self.assertTrue(i.grad is not None)
458                    self.assertTrue(j.grad is not None)
459                self.assertEqual(i.grad, j.grad, rtol=1.3e-06, atol=5e-5)
460
461    # A list of tests for DDP with activation checkpointing
462    # when gradient_as_bucket_view is True or False.
463    # Most of the tests are based on
464    # https://github.com/facebookresearch/fairscale/blob/main/tests/nn/pipe/test_checkpoint_ddp.py
465    class CheckpointOnceModule(nn.Module):
466        """
467        Runs checkpoint for a single layer in the model.
468        """
469
470        def __init__(self, use_reentrant=True):
471            super().__init__()
472            self.l1 = nn.Linear(20, 20)
473            self.l2 = nn.Linear(20, 20)
474            self.use_reentrant = use_reentrant
475
476        def forward(self, inp):
477            x = self.l1(inp)
478            x = checkpoint(self.l2, x, use_reentrant=self.use_reentrant)
479            return x
480
481    class CheckpointTwiceModule(CheckpointOnceModule):
482        """
483        Runs checkpoint for the same layer twice in a model. This simulates use
484        cases such as pipeline parallelism, where the same layer can be checkpointed
485        more than once.
486        """
487
488        def __init__(self, use_reentrant=True):
489            super().__init__(use_reentrant=use_reentrant)
490
491        def forward(self, inp):
492            x = self.l1(inp)
493            x = checkpoint(self.l2, x, use_reentrant=self.use_reentrant)
494            x = checkpoint(self.l2, x, use_reentrant=self.use_reentrant)
495            return x
496
497    class CheckpointTwiceModuleWeightSharing(CheckpointTwiceModule):
498        """
499        Similar to CheckpointTwiceModule but the weights are shared.
500        """
501
502        def __init__(self, use_reentrant=True):
503            super().__init__(use_reentrant=use_reentrant)
504            # Share weights
505            self.l1.weight = self.l2.weight
506
507        def forward(self, inp):
508            x = self.l1(inp)
509            x = checkpoint(self.l2, x, use_reentrant=self.use_reentrant)
510            x = checkpoint(self.l2, x, use_reentrant=self.use_reentrant)
511            return x
512
513    class DynamicCheckpointTwiceModule(CheckpointTwiceModule):
514        def __init__(self, use_reentrant=True):
515            super().__init__(use_reentrant=use_reentrant)
516            self.count = 0
517
518        def forward(self, inp):
519            if self.count % 2:
520                x = checkpoint(self.l1, inp, use_reentrant=self.use_reentrant)
521            else:
522                x = checkpoint(self.l2, inp, use_reentrant=self.use_reentrant)
523
524            self.count += 1
525            return x
526
527    class DynamicCheckpointTwiceModuleWeightSharing(DynamicCheckpointTwiceModule):
528        def __init__(self, use_reentrant=True):
529            super().__init__(use_reentrant=use_reentrant)
530            # Share weights
531            self.l1.weight = self.l2.weight
532
533    def _prepare_dummy_data(self):
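        # The local (non-DDP) model trains on the full global batch of size bs,
        # while each DDP replica gets its rank's contiguous shard of size ddp_bs,
        # so in aggregate both setups see the same data.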
534        ddp_bs = 16
535        bs = ddp_bs * self.world_size
536        input = torch.rand((bs, 20), device="cuda", requires_grad=True)
537        target = torch.randn((bs, 20), device="cuda")
538        offset = self.rank * ddp_bs
539        ddp_input = input[offset : offset + ddp_bs]
540        ddp_target = target[offset : offset + ddp_bs]
541        return input, ddp_input, target, ddp_target
542
543    @skip_if_lt_x_gpu(2)
544    @parametrize("use_reentrant", [True, False])
545    def test_ddp_checkpointing_once(self, use_reentrant):
546        """
547        DDP works as expected when layer is checkpointed only once.
548        """
549        process_group = self._get_process_group()
550        for use_bucket_view, static_graph in product((False, True), (False, True)):
551            self._test_ddp_checkpointing(
552                self.CheckpointOnceModule(use_reentrant=use_reentrant),
553                process_group=process_group,
554                use_bucket_view=use_bucket_view,
555                static_graph=static_graph,
556            )
557            if static_graph:
558                # find_unused_parameters does not make a difference, since it is
559                # ignored for static graph.
560                self._test_ddp_checkpointing(
561                    self.CheckpointOnceModule(),
562                    process_group=process_group,
563                    use_bucket_view=use_bucket_view,
564                    static_graph=static_graph,
565                    find_unused_parameters=True,
566                )
567
568    @skip_if_lt_x_gpu(2)
569    @parametrize("use_reentrant", [True, False])
570    def test_ddp_checkpointing_unused_params(self, use_reentrant):
571        """
572        With the reentrant autograd checkpointing implementation, DDP will fail when
573        there are unused params in the model and static graph training is not enabled.
574        With the non-reentrant checkpointing implementation, this works as expected.
575        """
576        process_group = self._get_process_group()
577        for use_bucket_view in (True, False):
578            err_ctx = (
579                nullcontext()
580                if not use_reentrant
581                else self.assertRaisesRegex(
582                    RuntimeError, "Expected to mark a variable ready only once."
583                )
584            )
585            with err_ctx:
586                model = self._test_ddp_checkpointing(
587                    self.CheckpointOnceModule(use_reentrant=use_reentrant),
588                    process_group=process_group,
589                    use_bucket_view=use_bucket_view,
590                    find_unused_parameters=True,
591                )
592            # test passes when static_graph is true
593            model = self._test_ddp_checkpointing(
594                self.CheckpointOnceModule(use_reentrant=use_reentrant),
595                process_group=process_group,
596                use_bucket_view=use_bucket_view,
597                find_unused_parameters=True,
598                static_graph=True,
599            )
600
601    @skip_if_lt_x_gpu(2)
602    @parametrize("use_reentrant", [True, False])
603    def test_ddp_checkpointing_twice(self, use_reentrant):
604        """
605        Checkpointing twice fails for non-static graph with reentrant checkpoint
606        implementation, succeeds with non-reentrant checkpoint implementation.
607        """
608        process_group = self._get_process_group()
609        for use_bucket_view in (True, False):
610            err_ctx = (
611                nullcontext()
612                if not use_reentrant
613                else self.assertRaisesRegex(
614                    RuntimeError, "Expected to mark a variable ready only once."
615                )
616            )
617            with err_ctx:
618                model = self._test_ddp_checkpointing(
619                    self.CheckpointTwiceModule(use_reentrant=use_reentrant),
620                    process_group=process_group,
621                    use_bucket_view=use_bucket_view,
622                    static_graph=False,
623                )
624
625            with err_ctx:
626                model = self._test_ddp_checkpointing(
627                    self.CheckpointTwiceModule(use_reentrant=use_reentrant),
628                    process_group=process_group,
629                    use_bucket_view=use_bucket_view,
630                    static_graph=False,
631                    find_unused_parameters=True,
632                )
633
634    @skip_if_lt_x_gpu(2)
635    @parametrize("use_reentrant", [True, False])
636    def test_ddp_checkpointing_twice_static_graph(self, use_reentrant):
637        """
638        Regardless of reentrant or non-reentrant checkpointing impl,
639        checkpointing twice works with static graph enabled.
640        """
641        process_group = self._get_process_group()
642        for use_bucket_view in (True, False):
643            # Test passes when static_graph=True.
644            model = self._test_ddp_checkpointing(
645                self.CheckpointTwiceModule(use_reentrant=use_reentrant),
646                process_group=process_group,
647                use_bucket_view=use_bucket_view,
648                static_graph=True,
649            )
650
651    @skip_if_lt_x_gpu(2)
652    def test_ddp_checkpointing_dynamic_module(self):
653        """
654        A dynamic module can be checkpointed multiple times with the non-reentrant
655        checkpointing implementation.
656        """
657        process_group = self._get_process_group()
658        for use_bucket_view in (True, False):
659            model = self._test_ddp_checkpointing(
660                self.DynamicCheckpointTwiceModule(use_reentrant=False),
661                process_group=process_group,
662                use_bucket_view=use_bucket_view,
663                static_graph=False,
664                find_unused_parameters=True,
665                # Grads can be None sometimes because the dynamic module does
666                # not use all params.
667                allow_none_grads=True,
668            )
669
670    @skip_if_lt_x_gpu(2)
671    def test_ddp_checkpointing_dynamic_weight_sharing(self):
672        """
673        A dynamic module with weight sharing can be checkpointed multiple times
674        using the non-reentrant checkpointing implementation.
675        """
676        process_group = self._get_process_group()
677        for use_bucket_view in (True, False):
678            model = self._test_ddp_checkpointing(
679                self.DynamicCheckpointTwiceModuleWeightSharing(use_reentrant=False),
680                process_group=process_group,
681                use_bucket_view=use_bucket_view,
682                static_graph=False,
683                find_unused_parameters=True,
684                # Grads can be None sometimes because the dynamic module does
685                # not use all params.
686                allow_none_grads=True,
687            )
688
689    # DDP works as expected if there is weight sharing among layers
690    @skip_if_lt_x_gpu(2)
691    @parametrize("use_reentrant", [True, False])
692    def test_ddp_checkpointing_weight_sharing(self, use_reentrant):
693        """
694        Test that checkpointing with weight sharing works.
695        """
696        process_group = self._get_process_group()
697        torch.cuda.set_device(self.rank)
698        for use_bucket_view, static_graph in product((False, True), (False, True)):
699            torch.manual_seed(31415)
700            l1 = nn.Linear(20, 20)
701            l2 = nn.Linear(20, 20)
702            l1.weight = l2.weight
703            model = nn.Sequential(l1, l2)
704            self._test_ddp_checkpointing(
705                model,
706                process_group=process_group,
707                use_bucket_view=use_bucket_view,
708                static_graph=static_graph,
709                run_checkpoint=True,
710                use_reentrant=use_reentrant,
711            )
712
713    @skip_if_lt_x_gpu(2)
714    def test_ddp_checkpointing_twice_weight_sharing(self):
715        """
716        Checkpointing should work with static graph when the same layer is
717        checkpointed twice and weights are shared across layers.
718        """
719        process_group = self._get_process_group()
720        torch.cuda.set_device(self.rank)
721        for use_bucket_view in (True, False):
722            model = self._test_ddp_checkpointing(
723                self.CheckpointTwiceModuleWeightSharing(),
724                process_group=process_group,
725                use_bucket_view=use_bucket_view,
726                static_graph=True,
727            )
728
729    def test_invalid_powerSGD_state(self):
730        for start_powerSGD_iter, use_error_feedback, warm_start in product(
731            [0, 1], [True, False], [True, False]
732        ):
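            # start_powerSGD_iter values of 0 or 1 are only rejected when error
            # feedback or warm start is enabled, so the (False, False) combination
            # is skipped.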
733            if not use_error_feedback and not warm_start:
734                continue
735            with self.assertRaisesRegex(
736                ValueError,
737                "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
738                "because PowerSGD can only be applied after the first two iterations in DDP.",
739            ):
740                state = powerSGD.PowerSGDState(
741                    process_group=None,
742                    matrix_approximation_rank=1,
743                    start_powerSGD_iter=start_powerSGD_iter,
744                    use_error_feedback=use_error_feedback,
745                    warm_start=warm_start,
746                )
747
748    def _test_ddp_with_process_group(
749        self,
750        process_group,
751        devices,
752        device_ids,
753        multi_device=False,
754        gradient_as_bucket_view=False,
755    ):
756        """
757        Note: we pass `device_ids` all the way down to DistributedDataParallel
758        as part of the test. The tests below use either a list of integers, a
759        list of `torch.device` instances, or an empty list.
760        The `devices` argument is used to control placement of the model and
761        must always be specified as a list of `torch.device` instances.
762        """
763        local_batch_size = 1 if devices is None else len(devices)
764        global_batch_size = self.world_size * local_batch_size
765
766        if multi_device:
767            model, ddp_model, input, target = self._prepare_multi_device_module(
768                process_group,
769                devices,
770                device_ids,
771                global_batch_size,
772                gradient_as_bucket_view,
773            )
774            ddp_logging_data = ddp_model._get_ddp_logging_data()
775            self.assertTrue(ddp_logging_data.get("is_multi_device_module"))
776        else:
777            model, ddp_model, input, target = self._prepare_single_device_module(
778                process_group,
779                devices,
780                device_ids,
781                global_batch_size,
782                gradient_as_bucket_view,
783            )
784            ddp_logging_data = ddp_model._get_ddp_logging_data()
785            self.assertFalse(ddp_logging_data.get("is_multi_device_module"))
786
787        def step_model(model, input, target):
788            model.train()
789            output = model(input)
790            loss = F.mse_loss(output, target.to(output.device))
791            loss.backward()
792
793        def update_parameters(model):
794            for param in model.parameters():
795                with torch.no_grad():
796                    param -= param.grad
797                param.grad = None
798
799        # compare the two models' parameters over 2 iterations
800        for iteration in range(2):
801            # single cpu/gpu training
802            step_model(model, input, target)
803
804            # DDP training: each rank feeds its own shard of the input to its replica
805            step_model(
806                ddp_model,
807                input[
808                    self.rank * local_batch_size : (self.rank + 1) * local_batch_size
809                ],
810                target[
811                    self.rank * local_batch_size : (self.rank + 1) * local_batch_size
812                ],
813            )
814
815            # Update weights and run a second iteration to shake out errors
816            update_parameters(model)
817            update_parameters(ddp_model)
818            self.assertEqual(
819                len(list(model.parameters())), len(list(ddp_model.parameters()))
820            )
821            for i, j in zip(model.parameters(), ddp_model.parameters()):
822                self.assertEqual(i, j, rtol=1.3e-06, atol=5e-5)
823
824            # Shuffle the input so that DDP input is different
825            torch.manual_seed(1337 + iteration)
826            input = input[torch.randperm(global_batch_size)]
827
828    def _gpu_model_with_ddp_comm_hook(
829        self, process_group, hook=None, gradient_as_bucket_view=False, state=None
830    ):
831        device_id = gpus_for_rank(self.world_size)[self.rank][0]
832        gpu_model = DistributedDataParallel(
833            ModuleForDdpCommHook().to(device_id),
834            device_ids=[device_id],
835            process_group=process_group,
836            gradient_as_bucket_view=gradient_as_bucket_view,
837        )
838
839        # Register a DDP communication hook if one is provided.
840        if hook is not None:
841            gpu_model.register_comm_hook(state, hook)
842
843        return gpu_model
844
845    def _gpu_model_with_builtin_ddp_comm_hook(
846        self, process_group, hook=None, gradient_as_bucket_view=False
847    ):
848        device_id = gpus_for_rank(self.world_size)[self.rank][0]
849        gpu_model = DistributedDataParallel(
850            ModuleForDdpCommHook().to(device_id),
851            device_ids=[device_id],
852            process_group=process_group,
853            gradient_as_bucket_view=gradient_as_bucket_view,
854        )
855
856        # Register a built-in DDP communication hook if defined
857        if hook is not None:
858            gpu_model._register_builtin_comm_hook(hook)
859
860        return gpu_model
861
862    def _run_and_verify_hook(self, model, input, expected_grad):
863        # Run forward
864        output = model(input, self.rank)
865
866        # Run backward
867        output.mean().backward()
868
869        [self.assertEqual(p.grad, expected_grad) for p in model.parameters()]
870
871    def _simple_hook(
872        self, state: object, bucket: dist.GradBucket
873    ) -> torch.futures.Future[torch.Tensor]:
874        fut = torch.futures.Future()
875        fut.set_result(torch.ones_like(bucket.buffer()))
876
877        def fut_then(fut):
878            # Add ones to fut's result.
879            t = fut.value()
880            return t + torch.ones_like(t)
881
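        # Chaining fut_then means the bucket is first replaced with ones and then
        # incremented once more, so each parameter's gradient is expected to come
        # back as a tensor of 2s, which callers compare via _run_and_verify_hook.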
882        return fut.then(fut_then)
883
884    def _test_not_nan(self, model, x):
885        y = model(x)
886        self.assertFalse(y.isnan().any().item())
887        y.sum().backward()
888        for p in model.parameters():
889            self.assertFalse(p.grad.isnan().any().item())
890
891    @skip_if_lt_x_gpu(2)
892    def test_sync_batch_norm_only_empty_input(self):
893        pg = self._get_process_group()
894
895        model = torch.nn.Sequential(
896            nn.BatchNorm2d(2),
897        ).to(device=self.rank)
898        model = DistributedDataParallel(
899            model,
900            device_ids=[self.rank],
901            process_group=pg,
902        )
903        model = nn.SyncBatchNorm.convert_sync_batchnorm(
904            model,
905            process_group=pg,
906        )
907
908        model.train()
909
910        # only rank 0 receives empty inputs
911        x = torch.zeros(
912            (1 if self.rank != 0 else 0, 2, 11, 13),
913            dtype=torch.float32,
914            device=self.rank,
915        )
916
917        # input requires grad, this will trigger the collective communication
918        # in the backward pass
919        x.requires_grad = True
920        self._test_not_nan(model, x)
921
922        # the input does not require grad
923        x.requires_grad = False
924        self._test_not_nan(model, x)
925
926        # all ranks receive empty inputs
927        x = torch.zeros((0, 2, 11, 13), dtype=torch.float32, device=self.rank)
928
929        # input requires grad, this will trigger the collective communication
930        # in the backward pass
931        x.requires_grad = True
932        self._test_not_nan(model, x)
933
934        # the input does not require grad
935        x.requires_grad = False
936        self._test_not_nan(model, x)
937
938    @skip_if_lt_x_gpu(2)
939    def test_sync_batch_norm_empty_input(self):
940        pg = self._get_process_group()
941
942        model = torch.nn.Sequential(
943            nn.Conv2d(2, 2, 3),
944            nn.BatchNorm2d(2),
945            nn.Linear(28, 2),
946        ).to(device=self.rank)
947        model = DistributedDataParallel(
948            model,
949            device_ids=[self.rank],
950            process_group=pg,
951        )
952        model = nn.SyncBatchNorm.convert_sync_batchnorm(
953            model,
954            process_group=pg,
955        )
956
957        model.train()
958        # only rank 0 receives empty inputs
959        x = torch.zeros(
960            (3 if self.rank != 0 else 0, 2, 30, 30),
961            dtype=torch.float32,
962            device=self.rank,
963        )
964
965        self._test_not_nan(model, x)
966
967        # all ranks receive empty inputs
968        x = torch.zeros((0, 2, 30, 30), dtype=torch.float32, device=self.rank)
969
970        self._test_not_nan(model, x)
971
972    @dataclass
973    class CustomOutput:
974        o1: Optional[torch.Tensor]
975        o2: Dict[str, torch.Tensor]
976
977    class DataclassOutputModule(nn.Module):
978        def __init__(self, skip_o1):
979            super().__init__()
980            self.seq1 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(3)])
981            self.relu = nn.ReLU()
982            self.seq2 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(3)])
983            self.skip_o1 = skip_o1
984
985        def forward(self, x):
986            o1 = None if self.skip_o1 else self.relu(self.seq1(x))
987            o2 = {"a": self.seq2(x), "b": self.relu(self.seq2(x))}
988            return CommonDistributedDataParallelTest.CustomOutput(o1=o1, o2=o2)
989
990    def _test_dataclass_output(self, skip_o1):
991        net_x = torch.cat([torch.ones(4, 10) * i for i in range(self.world_size)]).to(
992            self.rank
993        )
994        ddp_x = torch.ones(4, 10, device=self.rank) * self.rank
995
996        # use manual_seed to make sure local models start with the same values
997        torch.manual_seed(0)
998        net = self.DataclassOutputModule(skip_o1=skip_o1).to(self.rank)
999        ddp = DistributedDataParallel(
1000            copy.deepcopy(net),
1001            device_ids=[self.rank],
1002            find_unused_parameters=True,
1003            static_graph=False,
1004            process_group=self._get_process_group(),
1005        )
1006
1007        net_out = net(net_x)
1008        ddp_out = ddp(ddp_x)
1009
1010        net_loss = F.mse_loss(
1011            net_out.o1 + net_out.o2["a"] + net_out.o2["b"]
1012            if not skip_o1
1013            else net_out.o2["a"] + net_out.o2["b"],
1014            torch.ones_like(net_out.o2["a"], device=self.rank),
1015        )
1016        ddp_loss = F.mse_loss(
1017            ddp_out.o1 + ddp_out.o2["a"] + ddp_out.o2["b"]
1018            if not skip_o1
1019            else ddp_out.o2["a"] + ddp_out.o2["b"],
1020            torch.ones_like(ddp_out.o2["a"], device=self.rank),
1021        )
1022
1023        net_loss.backward()
1024        ddp_loss.backward()
1025
1026        for p1, p2 in zip(net.parameters(), ddp.parameters()):
1027            if torch.is_tensor(p1.grad):
1028                self.assertTrue(p1.grad.allclose(p2.grad))
1029            else:
1030                self.assertEqual(p1.grad, p2.grad)
1031
1032    @skip_if_lt_x_gpu(2)
1033    def test_dataclass_output(self):
1034        self._test_dataclass_output(skip_o1=False)
1035
1036    @skip_if_lt_x_gpu(2)
1037    def test_dataclass_output_unused_param(self):
1038        self._test_dataclass_output(skip_o1=True)
1039
1040
1041class ComputeBucketAssignmentTest(TestCase):
1042    def test_single_limit_single_dtype(self):
1043        tensors = [
1044            torch.empty([100], dtype=torch.float),
1045            torch.empty([200], dtype=torch.float),
1046            torch.empty([100], dtype=torch.float),
1047            torch.empty([50], dtype=torch.float),
1048        ]
1049        result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
1050            tensors, [400]
1051        )
1052        self.assertTrue(all(size_lim == 400 for size_lim in per_bucket_size_limits))
1053        self.assertEqual([[0], [1], [2], [3]], result)
1054
1055    def test_single_limit_multi_dtype(self):
1056        tensors = [
1057            torch.empty([50], dtype=torch.float),
1058            torch.empty([25], dtype=torch.double),
1059            torch.empty([50], dtype=torch.float),
1060            torch.empty([25], dtype=torch.double),
1061            torch.empty([50], dtype=torch.float),
1062            torch.empty([25], dtype=torch.double),
1063        ]
1064        result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
1065            tensors, [400]
1066        )
1067        self.assertTrue(all(size_lim == 400 for size_lim in per_bucket_size_limits))
1068        self.assertEqual([[0, 2], [1, 3], [4], [5]], result)
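        # Buckets never mix dtypes: the float32 tensors (indices 0, 2, 4) and the
        # float64 tensors (1, 3, 5) are bucketed separately. Two 50-element floats
        # (200 bytes each) or two 25-element doubles exactly fill the 400-byte cap,
        # and the trailing tensor of each dtype lands in its own bucket.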
1069
1070    def test_multi_limit_single_dtype(self):
1071        tensors = [
1072            torch.empty([10], dtype=torch.float),
1073            torch.empty([10], dtype=torch.float),
1074            torch.empty([10], dtype=torch.float),
1075            torch.empty([10], dtype=torch.float),
1076        ]
1077        result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
1078            tensors, [40, 80]
1079        )
1080        self.assertEqual(per_bucket_size_limits, [40, 80, 80])
1081        self.assertEqual([[0], [1, 2], [3]], result)
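        # With several limits, the first bucket is capped by the first limit
        # (40 bytes -> a single 10-float tensor) and subsequent buckets fall back
        # to the last limit (80 bytes -> up to two tensors), as the returned
        # per_bucket_size_limits shows.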
1082
1083    def test_multi_limit_multi_dtype(self):
1084        tensors = [
1085            torch.empty([50], dtype=torch.float),
1086            torch.empty([25], dtype=torch.double),
1087            torch.empty([50], dtype=torch.float),
1088            torch.empty([25], dtype=torch.double),
1089            torch.empty([50], dtype=torch.float),
1090            torch.empty([25], dtype=torch.double),
1091        ]
1092        result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
1093            tensors, [200, 400]
1094        )
1095        self.assertEqual([[0], [1], [2, 4], [3, 5]], result)
1096        self.assertEqual(per_bucket_size_limits, [200, 200, 400, 400])
1097
1098
1099class AbstractCommTest:
1100    @property
1101    def op_timeout_sec(self):
1102        return 1
1103
1104    @property
1105    def world_size(self):
1106        return 2
1107
1108    @property
1109    def device(self):
1110        self.fail("test subclass didn't override device")
1111
1112    def _verify_sequence_number_across_pg(self, pg, verify_pg):
1113        seq_num = pg._get_sequence_number_for_group()
1114        obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
1115        # We use a separate pg to verify the sequence numbers, otherwise these
1116        # collectives will themselves increment the sequence number.
1117        dist.all_gather_object(obj_list, seq_num, group=verify_pg)
1118        self.assertEqual(len(set(obj_list)), 1)
1119        return obj_list[0]
1120
1121    def _test_sequence_num_incremented(self, process_group, ranks):
1122        # verify initial sequence numbers. Use a distinct process group for
1123        # verification to keep counts as expected with respect to process_group.
1124        verify_pg = dist.new_group(
1125            ranks=ranks,
1126            backend="gloo",
1127        )
1128        assert dist.get_world_size(process_group) == dist.get_world_size(verify_pg)
1129
1130        initial_num = (
1131            self._verify_sequence_number_across_pg(
1132                pg=process_group, verify_pg=verify_pg
1133            )
1134            if not c10d._rank_not_in_group(process_group)
1135            else -1
1136        )
1137
1138        # Verify sequence numbers are appropriately incremented
1139        for i in range(10):
1140            t = torch.ones(1, device=torch.cuda.current_device())
1141            dist.all_reduce(t, group=process_group)
1142            if not c10d._rank_not_in_group(process_group):
1143                seq_num = self._verify_sequence_number_across_pg(
1144                    pg=process_group,
1145                    verify_pg=verify_pg,
1146                )
1147                self.assertEqual(initial_num + i + 1, seq_num)
1148
1149        if dist.get_world_size(process_group) > 2:
1150            # Test when certain ranks don't call collectives
1151            if dist.get_rank(process_group) not in [0, 2]:
1152                dist.all_reduce(t, group=process_group, async_op=True)
1153            # Now ranks 0 and 2 should be lagging by 1.
1154            if not c10d._rank_not_in_group(process_group):
1155                seq_num = process_group._get_sequence_number_for_group()
1156                rank = dist.get_rank(process_group)
1157                obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
1158                dist.all_gather_object(obj_list, (rank, seq_num), group=verify_pg)
1159                rank_to_seq_num = dict(obj_list)
1160                self.assertEqual(len(set(rank_to_seq_num.values())), 2)
1161                self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[2])
1162                expected_same = {
1163                    rank_to_seq_num[i]
1164                    for i in rank_to_seq_num.keys()
1165                    if i not in [0, 2]
1166                }
1167                self.assertEqual(len(expected_same), 1)
1168                self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1])
1169
1170    def _test_sequence_num_incremented_default_group(self, backend_name):
1171        torch.cuda.set_device(self.rank)
1172        store = dist.FileStore(self.file_name, self.world_size)
1173        dist.init_process_group(
1174            backend_name,
1175            world_size=self.world_size,
1176            rank=self.rank,
1177            store=store,
1178        )
1179        self._test_sequence_num_incremented(
1180            c10d._get_default_group(),
1181            ranks=list(range(dist.get_world_size())),
1182        )
1183
1184    def _test_sequence_num_incremented_subgroup(self, backend_name):
1185        torch.cuda.set_device(self.rank)
1186        store = dist.FileStore(self.file_name, self.world_size)
1187        dist.init_process_group(
1188            backend_name,
1189            world_size=self.world_size,
1190            rank=self.rank,
1191            store=store,
1192        )
1193        subgroup_ranks = [0, 1, 2]
1194        subgroup = dist.new_group(subgroup_ranks)
1195        self._test_sequence_num_incremented(subgroup, subgroup_ranks)
1196
1197    def _test_sequence_num_set_default_pg(self, backend):
1198        store = dist.FileStore(self.file_name, self.world_size)
1199        dist.init_process_group(
1200            backend,
1201            world_size=self.world_size,
1202            rank=self.rank,
1203            store=store,
1204        )
1205
1206        default_pg = c10d._get_default_group()
1207        seq_num = default_pg._get_sequence_number_for_group()
1208        obj_list = [None for _ in range(dist.get_world_size())]
1209        dist.all_gather_object(obj_list, seq_num)
1210        self.assertEqual(len(set(obj_list)), 1)
1211
1212    def _test_sequence_num_set_new_group(self, backend):
1213        store = dist.FileStore(self.file_name, self.world_size)
1214        dist.init_process_group(
1215            backend,
1216            world_size=self.world_size,
1217            rank=self.rank,
1218            store=store,
1219        )
1220
1221        subgroup = dist.new_group([0, 1])
1222
1223        if not c10d._rank_not_in_group(subgroup):
1224            subgroup_seq = subgroup._get_sequence_number_for_group()
1225            obj_list = [None for _ in range(dist.get_world_size(subgroup))]
1226            dist.all_gather_object(obj_list, subgroup_seq, group=subgroup)
1227            self.assertEqual(len(set(obj_list)), 1)
1228
1229    def _test_warn_not_in_group(self, backend):
1230        store = dist.FileStore(self.file_name, self.world_size)
1231        dist.init_process_group(
1232            backend,
1233            world_size=self.world_size,
1234            rank=self.rank,
1235            store=store,
1236        )
1237        in_group_ranks = list(filter(lambda x: x % 2 == 0, range(self.world_size)))
1238        group = dist.new_group(in_group_ranks)
1239
1240        x = torch.zeros(2, 2).cuda(self.rank)
1241        xs = [torch.zeros(2, 2).cuda(self.rank) for _ in range(len(in_group_ranks))]
1242        if self.rank not in in_group_ranks:
1243            msg = ".*{}.*does not belong to.*"
1244            with self.assertWarnsOnceRegex(UserWarning, msg.format("all_gather")):
1245                dist.all_gather(xs, x, group=group)
1246            with self.assertWarnsOnceRegex(UserWarning, msg.format("all_reduce")):
1247                dist.all_reduce(x, group=group)
1248            with self.assertWarnsOnceRegex(UserWarning, msg.format("barrier")):
1249                dist.barrier(group=group)
1250            with self.assertWarnsOnceRegex(UserWarning, msg.format("broadcast")):
1251                dist.broadcast(x, src=0, group=group)
1252        else:
1253            dist.all_gather(xs, x, group=group)
1254            dist.all_reduce(x, group=group)
1255            dist.barrier(group=group)
1256            dist.broadcast(x, src=0, group=group)
1257
1258    def _test_rank_membership(self, backend):
1259        store = dist.FileStore(self.file_name, self.world_size)
1260        dist.init_process_group(
1261            backend,
1262            world_size=self.world_size,
1263            rank=self.rank,
1264            store=store,
1265        )
1266        self.assertTrue(self.world_size > 1)
1267
1268        group = dist.new_group(ranks=[1])
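        # Global rank 1 is the only member of this subgroup, so it maps to group
        # rank 0; the assertions below exercise both directions of that mapping
        # plus the error paths for non-members and unregistered groups.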
1269        self.assertEqual(dist.get_group_rank(group, 1), 0)
1270        with self.assertRaisesRegex(ValueError, "not part of group"):
1271            dist.get_group_rank(group, 0)
1272        with self.assertRaisesRegex(ValueError, "not registered"):
1273            dist.get_group_rank(DummyProcessGroup(self.rank, self.world_size), 0)
1274
1275        self.assertEqual(dist.get_global_rank(group, 0), 1)
1276        with self.assertRaisesRegex(ValueError, "not part of group"):
1277            dist.get_global_rank(group, 1)
1278        with self.assertRaisesRegex(ValueError, "not registered"):
1279            dist.get_global_rank(DummyProcessGroup(self.rank, self.world_size), 0)
1280
1281        self.assertEqual(dist.get_process_group_ranks(group), [1])
1282
1283    def _test_tensor_dtype_mismatch(self, backend):
1284        store = dist.FileStore(self.file_name, self.world_size)
1285        dist.init_process_group(
1286            backend,
1287            world_size=self.world_size,
1288            rank=self.rank,
1289            store=store,
1290        )
1291
1292        tensor = torch.ones(2, 2, device=self.device) * 7
1293        tensor_h = tensor.half()
1294        tensor_list = [
1295            torch.zeros(2, 2, device=self.device) for _ in range(self.world_size)
1296        ]
1297        tensor_list_h = list(tensor_list)
1298        tensor_list_h[1] = tensor_list_h[1].half()
1299
1300        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
1301            dist.all_gather(tensor_list_h, tensor)
1302
1303        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
1304            dist.all_gather(tensor_list, tensor_h)
1305
1306        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
1307            dist.all_gather_coalesced([tensor_list_h], tensor_list)
        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
            dist.all_gather_coalesced([tensor_list], tensor_list_h)
1309
1310        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
1311            dist.all_reduce_coalesced(tensor_list_h)
1312
1313        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
1314            dist.reduce_scatter(tensor, tensor_list_h)
1315
1316        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
1317            dist.reduce_scatter(tensor_h, tensor_list)
1318
1319        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
1320            dist.all_to_all_single(tensor_h, tensor)
1321
1322        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
1323            dist.all_to_all(tensor_list_h, tensor_list)
1324
1325        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
1326            dist.all_to_all(tensor_list, tensor_list_h)
1327
1328        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
1329            dist.scatter(tensor, tensor_list_h)
1330
1331        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
1332            dist.gather(tensor_h, tensor_list)
1333
1334        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
1335            dist.gather(tensor, tensor_list_h)
1336
1337        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
1338            dist.scatter(tensor_h, tensor_list)
1339
1340    def _test_tensor_dtype_complex(self, backend):
1341        store = dist.FileStore(self.file_name, self.world_size)
1342        dist.init_process_group(
1343            backend,
1344            world_size=self.world_size,
1345            rank=self.rank,
1346            store=store,
1347        )
1348
1349        tensor = torch.rand(2, device=self.device)
1350        tensor_c = torch.view_as_complex(tensor)
1351        tensor_list = [
1352            torch.rand(2, device=self.device) for _ in range(self.world_size)
1353        ]
1354        tensor_list_c = list(tensor_list)
1355        tensor_list_c[1] = torch.view_as_complex(tensor_list_c[1])
1356
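        # Complex tensors are communicated via their underlying real views, so
        # mixing real tensors and their complex views in the same collective is
        # expected to work without errors.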
1357        dist.all_gather(tensor_list, tensor)
1358        dist.all_gather(tensor_list, tensor_c)
1359        dist.all_gather(tensor_list_c, tensor)
1360        dist.all_gather(tensor_list_c, tensor_c)
1361
1362    def _test_bool_tensors(self, backend):
1363        store = dist.FileStore(self.file_name, self.world_size)
1364        dist.init_process_group(
1365            backend,
1366            world_size=self.world_size,
1367            rank=self.rank,
1368            store=store,
1369        )
1370        device = "cuda" if backend == "nccl" else "cpu"
1371        # test broadcast with bool tensors
1372        tensor = torch.tensor([1, 0, 0, 1], dtype=torch.bool, device=device)
1373        zeros = torch.tensor([0, 0, 0, 0], dtype=torch.bool, device=device)
1374        outensor = zeros if self.rank > 0 else tensor
1375        dist.broadcast(outensor, src=0)
1376        self.assertEqual(outensor, tensor)
1377
1378
1379# Variant of AbstractCommTest that expects world size of 4
1380class AbstractLargeCommTest:
1381    @property
1382    def op_timeout_sec(self):
1383        return 1
1384
1385    @property
1386    def world_size(self):
1387        return 4
1388
1389    @property
1390    def device(self):
1391        raise RuntimeError("Implement me")
1392
1393    def _test_new_group_local_sync(self, backend):
1394        store = dist.FileStore(self.file_name, self.world_size)
1395        dist.init_process_group(
1396            backend,
1397            world_size=self.world_size,
1398            rank=self.rank,
1399            store=store,
1400        )
1401        rank = dist.get_rank()
1402        ranks_in = [rank, (rank + 2) % self.world_size]
1403        ranks_out = [i for i in range(self.world_size) if i not in ranks_in]
1404        self.assertIn(rank, ranks_in)
1405        self.assertNotIn(rank, ranks_out)
1406
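        # With use_local_synchronization=True only the group's members synchronize;
        # a calling rank that is not in `ranks` gets None back instead of a group
        # handle, which is what the assertion below checks.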
1407        self.assertIsNone(
1408            dist.new_group(ranks=ranks_out, use_local_synchronization=True)
1409        )
1410
1411        new_pg = dist.new_group(ranks=ranks_in, use_local_synchronization=True)
1412        self.assertIsInstance(new_pg, dist.ProcessGroup)
1413
1414        # PTD sorts ranks before creating the PG, so [3, 1] actually gets assigned ranks [1, 0]
1415        ranks_in.sort()
1416        self.assertEqual(dist.get_group_rank(new_pg, rank), ranks_in.index(rank))
1417        self.assertEqual(
1418            ranks_in,
1419            dist.get_process_group_ranks(new_pg),
1420            f"expecting {ranks_in} but got {dist.get_process_group_ranks(new_pg)}",
1421        )
1422
1423    def _test_new_group_local_sync_sanity_check(self, backend):
1424        store = dist.FileStore(self.file_name, self.world_size)
1425        dist.init_process_group(
1426            backend,
1427            world_size=self.world_size,
1428            rank=self.rank,
1429            store=store,
1430        )
1431        rank = dist.get_rank()
1432
1433        # split the world into 2 PGs
1435        pg_idx = rank // 2
1436        ranks_in = [pg_idx * 2, pg_idx * 2 + 1]
1437        new_pg = dist.new_group(ranks=ranks_in, use_local_synchronization=True)
1438
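        # all_gather on the subgroup should only collect contributions from its
        # own two members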
1439        input_tensor = torch.tensor([pg_idx, rank], device=self.device)
1440        output_tensor_list = [
1441            torch.tensor(
1442                [-1, -1],
1443                device=self.device,
1444            )
1445            for _ in range(new_pg.size())
1446        ]
1447        dist.all_gather(output_tensor_list, input_tensor, group=new_pg)
1448
1449        expected = [
1450            torch.tensor([pg_idx, ranks_in[0]], device=self.device),
1451            torch.tensor([pg_idx, ranks_in[1]], device=self.device),
1452        ]
1453        self.assertEqual(output_tensor_list, expected)
1454
1455    def _test_new_group_local_sync_duplicate_pg(self, backend):
1456        """
1457        Users should be able to create multiple PGs with the same set of
1458        members without any conflict in group names.
1459        """
1460        store = dist.FileStore(self.file_name, self.world_size)
1461        dist.init_process_group(
1462            backend,
1463            world_size=self.world_size,
1464            rank=self.rank,
1465            store=store,
1466        )
1467        rank = dist.get_rank()
1468
1469        # split the world into 2 PGs
1471        pg_idx = rank // 2
1472        ranks_in = [pg_idx * 2, pg_idx * 2 + 1]
1473        new_pgs = []
1474        for _ in range(2):
1475            new_pgs.append(
1476                dist.new_group(ranks=ranks_in, use_local_synchronization=True)
1477            )
1478
1479        input_tensor = torch.tensor([pg_idx, rank], device=self.device)
1480        for new_pg in new_pgs:
1481            output_tensor_list = [
1482                torch.tensor(
1483                    [-1, -1],
1484                    device=self.device,
1485                )
1486                for _ in range(new_pg.size())
1487            ]
1488            dist.all_gather(output_tensor_list, input_tensor, group=new_pg)
1489
1490            expected = [
1491                torch.tensor([pg_idx, ranks_in[0]], device=self.device),
1492                torch.tensor([pg_idx, ranks_in[1]], device=self.device),
1493            ]
1494            self.assertEqual(output_tensor_list, expected)
1495
1496
1497class CommTest(AbstractCommTest, MultiProcessTestCase):
1498    def setUp(self):
1499        super().setUp()
1500        self._spawn_processes()
1501
1502    def tearDown(self):
1503        super().tearDown()
1504        try:
1505            os.remove(self.file_name)
1506        except OSError:
1507            pass
1508
1509    def test_debug_level(self):
1510        try:
1511            del os.environ["TORCH_DISTRIBUTED_DEBUG"]
1512        except KeyError:
1513            pass
1514
1515        dist.set_debug_level_from_env()
1516        # Default should be off
1517        default_debug_mode = dist.get_debug_level()
1518        self.assertEqual(default_debug_mode, dist.DebugLevel.OFF)
1519        mapping = {
1520            "OFF": dist.DebugLevel.OFF,
1521            "off": dist.DebugLevel.OFF,
1522            "oFf": dist.DebugLevel.OFF,
1523            "INFO": dist.DebugLevel.INFO,
1524            "info": dist.DebugLevel.INFO,
1525            "INfO": dist.DebugLevel.INFO,
1526            "DETAIL": dist.DebugLevel.DETAIL,
1527            "detail": dist.DebugLevel.DETAIL,
1528            "DeTaIl": dist.DebugLevel.DETAIL,
1529        }
1530        invalid_debug_modes = ["foo", 0, 1, -1]
1531
1532        for mode in mapping.keys():
1533            os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
1534            dist.set_debug_level_from_env()
1535            set_debug_mode = dist.get_debug_level()
1536            self.assertEqual(
1537                set_debug_mode,
1538                mapping[mode],
1539                f"Expected {mode} to map to {mapping[mode]} but got {set_debug_mode}",
1540            )
1541
1542        for mode in invalid_debug_modes:
1543            os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
1544            with self.assertRaisesRegex(
1545                ValueError, "The value of TORCH_DISTRIBUTED_DEBUG must"
1546            ):
1547                dist.set_debug_level_from_env()
1548
1549
1550class DummyWork(dist._Work):
1551    def wait(self, timeout=5.0):
1552        if torch.cuda.is_available():
1553            torch.cuda.current_stream().synchronize()
1554        return True
1555
1556
1557class DummyProcessGroup(dist.ProcessGroup):
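    # A pure-Python backend used to exercise the ProcessGroup extension API
    # without any real communication: collectives mutate tensors with fixed
    # arithmetic (broadcast/send add 1, allreduce/recv add 2, allgather and
    # reduce_scatter just copy), so tests can verify that calls were
    # dispatched to this backend.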
1558    def getBackendName(self):
1559        return "Dummy"
1560
1561    def allgather(self, output_tensor_lists, input_tensor_list, opts=None):
1562        for output_tensor_list, input_tensor in zip(
1563            output_tensor_lists, input_tensor_list
1564        ):
1565            for output_tensor in output_tensor_list:
1566                output_tensor.copy_(input_tensor)
1567
1568        return DummyWork()
1569
1570    def allreduce(self, tensor_list, opts=None):
1571        for tensor in tensor_list:
1572            tensor.add_(2)
1573
1574        return DummyWork()
1575
1576    def barrier(self, opts=None):
1577        store = c10d._get_default_store()
1578        key = "TEST:DummyProcessGroup:barrier"
1579        if self.rank() == 0:
1580            worker_count = 0
1581            # By default, the TCPStore server lives on rank 0, so rank 0
1582            # must not exit before the other ranks have finished using the
1583            # store.
1584            # Note that _store_based_barrier does not solve this problem: all
1585            # ranks need to run at least one store.add(key, 0) before exiting,
1586            # but there is no guarantee that rank 0 is still alive at that
1587            # point.
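            # store.add returns the updated counter, so rank 0 spins here until
            # every other rank has checked in via its store.add(key, 1) below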
1588            while worker_count < self.size() - 1:
1589                worker_count = store.add(key, 0)
1590        else:
1591            store.add(key, 1)
1592
1593        return DummyWork()
1594
1595    def broadcast(self, tensor_list, opts=None):
1596        for tensor in tensor_list:
1597            tensor.add_(1)
1598
1599        return DummyWork()
1600
1601    def reduce_scatter(self, output_tensor_list, input_tensor_lists, opts=None):
1602        for output_tensor, input_tensor_list in zip(
1603            output_tensor_list, input_tensor_lists
1604        ):
1605            output_tensor.copy_(input_tensor_list[self.rank()])
1606
1607        return DummyWork()
1608
1609    def send(self, tensor_list, dst, tag=0):
1610        for tensor in tensor_list:
1611            tensor.add_(1)
1612
1613        return DummyWork()
1614
1615    def recv(self, tensor_list, src, tag=0):
1616        for tensor in tensor_list:
1617            tensor.add_(2)
1618
1619        return DummyWork()
1620
1621
1622class PythonProcessGroupExtensionTest(MultiProcessTestCase):
1623    def setUp(self):
1624        super().setUp()
1625        self._spawn_processes()
1626
1627    def tearDown(self):
1628        super().tearDown()
1629        try:
1630            os.remove(self.file_name)
1631        except OSError:
1632            pass
1633
1634    def test_get_backend_name(self):
1635        dpg = DummyProcessGroup(0, 1)
1636        self.assertEqual("Dummy", dpg.name())
1637
1638    def test_backend_class_attr(self):
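        # registering a backend exposes it as an attribute on dist.Backend
        # (dist.Backend.DUMMY == "dummy") and records the creator function in
        # dist.Backend._plugins under the upper-cased name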
1639        dist.Backend.register_backend(
1640            "dummy", PythonProcessGroupExtensionTest.create_dummy
1641        )
1642        self.assertEqual(dist.Backend.DUMMY, "dummy")
1643        self.assertEqual(
1644            dist.Backend._plugins["DUMMY"].creator_fn,
1645            PythonProcessGroupExtensionTest.create_dummy,
1646        )
1647
1648    def test_is_backend_available(self):
1649        self.assertEqual(dist.is_ucc_available(), dist.is_backend_available("ucc"))
1650        self.assertFalse(dist.is_backend_available("dummy"))
1651        dist.Backend.register_backend(
1652            "dummy", PythonProcessGroupExtensionTest.create_dummy
1653        )
1654        self.assertTrue(dist.is_backend_available("dummy"))
1655
1656    def test_backend_config(self):
1657        dist.Backend.register_backend(
1658            "dummy", PythonProcessGroupExtensionTest.create_dummy
1659        )
1660
1661        # Ensure backend config can be created with the following arguments
1662        backend_config_strings_and_expected_values = [
1663            (dist.Backend.GLOO, "cpu:gloo,cuda:gloo"),
1664            (dist.Backend.NCCL, "cuda:nccl"),
1665            (dist.Backend.MPI, "cpu:mpi,cuda:mpi"),
1666            (dist.Backend.UCC, "cpu:ucc,cuda:ucc"),
1667            (dist.Backend.DUMMY, "cpu:dummy,cuda:dummy"),
1668            ("DUMMY", "cpu:dummy,cuda:dummy"),
1669            ("dummy", "cpu:dummy,cuda:dummy"),
1670            ("cpu:dummy,cuda:dummy", "cpu:dummy,cuda:dummy"),
1671            ("cpu:dummy,cuda:nccl", "cpu:dummy,cuda:nccl"),
1672            ("cpu:gloo,cuda:dummy", "cpu:gloo,cuda:dummy"),
1673            ("cpu:gloo,cuda:nccl", "cpu:gloo,cuda:nccl"),
1674        ]
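        # a bare backend name expands to an explicit device->backend mapping,
        # e.g. "dummy" becomes "cpu:dummy,cuda:dummy"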
1675
1676        for config_str, expected_value in backend_config_strings_and_expected_values:
1677            with self.subTest(config_str):
1678                # ensures these config strings are valid and no ValueError is raised
1679                config = dist.BackendConfig(config_str)
1680                self.assertEqual(str(config), expected_value)
1681
1682        # Ensure backend config will raise ValueError with the following arguments
1683        invalid_backend_config_strings = [
1684            "cpu:gloo,cuda:nccl,",  # trailing comma
1685            "cpu:gloo,cuda:nccl,cpu:dummy",  # duplicate device
1686        ]
1687        for config_str in invalid_backend_config_strings:
1688            with self.subTest(config_str):
1689                with self.assertRaises(ValueError):
1690                    dist.BackendConfig(config_str)
1691
1692    def test_init_process_group_with_multiple_backends(self):
1693        dist.Backend.register_backend(
1694            "dummy", PythonProcessGroupExtensionTest.create_dummy
1695        )
1696
1697        os.environ["MASTER_ADDR"] = "localhost"
1698        os.environ["MASTER_PORT"] = "6789"
1699        dist.init_process_group(
1700            "cpu:dummy,cuda:dummy", rank=self.rank, world_size=self.world_size
1701        )
1702
1703        # test all_gather
1704        input_tensor = torch.ones(2, 2) * 7
1705        output_tensor_list = [torch.zeros(2, 2) for _ in range(self.world_size)]
1706        dist.all_gather(output_tensor_list, input_tensor)
1707
1708        dist.barrier()
1709        dist.destroy_process_group()
1710
1711    class Options:
1712        def __init__(self) -> None:
1713            pass
1714
1715        def create(self):
1716            pass
1717
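    # Creator callable registered via dist.Backend.register_backend in the
    # tests of this class; it receives the store, group rank, group size and
    # timeout, and returns the process group instance.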
1718    @staticmethod
1719    def create_dummy(store, group_rank, group_size, timeout):
1720        return DummyProcessGroup(group_rank, group_size)
1721
1722    def test_collectives(self):
1723        dist.Backend.register_backend(
1724            "dummy", PythonProcessGroupExtensionTest.create_dummy
1725        )
1726
1727        os.environ["MASTER_ADDR"] = "localhost"
1728        os.environ["MASTER_PORT"] = "6789"
1729        dist.init_process_group("dummy", rank=self.rank, world_size=self.world_size)
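        # the expected values below follow DummyProcessGroup's arithmetic:
        # allgather copies the input, allreduce adds 2, broadcast adds 1 and
        # reduce_scatter copies the rank-th input tensor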
1730
1731        # test all_gather
1732        input_tensor = torch.ones(2, 2) * 7
1733        output_tensor_list = [torch.zeros(2, 2) for _ in range(self.world_size)]
1734        dist.all_gather(output_tensor_list, input_tensor)
1735
1736        for tensor in output_tensor_list:
1737            self.assertEqual(tensor, input_tensor)
1738
1739        # test all_reduce
1740        input_tensor = torch.ones(2, 2) * 7
1741        dist.all_reduce(input_tensor)
1742        self.assertEqual(input_tensor, torch.ones(2, 2) * 7 + 2)
1743
1744        # test broadcast
1745        input_tensor = torch.zeros(2, 2)
1746        dist.broadcast(input_tensor, 0, async_op=True).wait()
1747        self.assertEqual(torch.ones(2, 2), input_tensor)
1748
1749        # test reduce_scatter
1750        output_tensor = torch.zeros(2, 2)
1751        input_tensor_list = [torch.ones(2, 2) for _ in range(self.world_size)]
1752        dist.reduce_scatter(output_tensor, input_tensor_list)
1753        self.assertEqual(output_tensor, torch.zeros(2, 2) + 1)
1754
1755        dist.barrier()
1756        dist.destroy_process_group()
1757
1758    def test_send_recv(self):
1759        dist.Backend.register_backend(
1760            "dummy", PythonProcessGroupExtensionTest.create_dummy
1761        )
1762
1763        os.environ["MASTER_ADDR"] = "localhost"
1764        os.environ["MASTER_PORT"] = "6789"
1765        dist.init_process_group("dummy", rank=self.rank, world_size=self.world_size)
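        # DummyProcessGroup's send adds 1 and recv adds 2 in place, which is
        # what the assertions below verify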
1766
1767        # test send
1768        input_tensor = torch.zeros(2, 2)
1769        dist.send(input_tensor, (self.rank + 1) % self.world_size)
1770        self.assertEqual(input_tensor, torch.zeros(2, 2) + 1)
1771
1772        with self.assertRaises(ValueError):
1773            dist.send(input_tensor, dist.get_rank())
1774
1775        # test recv
1776        input_tensor = torch.zeros(2, 2)
1777        dist.recv(input_tensor, (self.rank + 1) % self.world_size)
1778        self.assertEqual(input_tensor, torch.zeros(2, 2) + 2)
1779
1780        dist.barrier()
1781        # intentionally not calling into `destroy_process_group`, as not all
1782        # user applications would do that explicitly.
1783
1784
1785instantiate_parametrized_tests(CommonDistributedDataParallelTest)
1786
1787
1788class ProcessGroupWithDispatchedCollectivesTests(MultiProcessTestCase):
1789    @property
1790    def world_size(self):
1791        return 1
1792
1793    def setUp(self):
1794        super().setUp()
1795        self._spawn_processes()
1796
1797    def tearDown(self):
1798        super().tearDown()
1799        try:
1800            os.remove(self.file_name)
1801        except OSError:
1802            pass
1803
1804    def test_init_process_group_optional_backend(self):
1805        with tempfile.NamedTemporaryFile(delete=False) as f:
1806            store = dist.FileStore(f.name, self.world_size)
1807            # with no backend specified, init_process_group creates both the gloo and nccl backends
1808            if dist.is_gloo_available() and dist.is_nccl_available():
1809                dist.init_process_group(
1810                    store=store,
1811                    rank=self.rank,
1812                    world_size=self.world_size,
1813                )
1814                dist.destroy_process_group()
1815
1816    def test_init_process_group_for_all_backends(self):
1817        for backend in dist.Backend.backend_list:
1818            # skip if the backend is not available on the system
1819            if backend == dist.Backend.UNDEFINED:
1820                continue
1821            elif backend == dist.Backend.MPI:
1822                if not dist.is_mpi_available():
1823                    continue
1824            elif backend == dist.Backend.NCCL:
1825                if not dist.is_nccl_available() or not torch.cuda.is_available():
1826                    continue
1827            elif backend == dist.Backend.GLOO:
1828                if not dist.is_gloo_available():
1829                    continue
1830            elif backend == dist.Backend.UCC:
1831                if not dist.is_ucc_available():
1832                    continue
1833
1834            with tempfile.NamedTemporaryFile(delete=False) as f:
1835                store = dist.FileStore(f.name, self.world_size)
1836                dist.init_process_group(
1837                    backend=backend,
1838                    rank=self.rank,
1839                    world_size=self.world_size,
1840                    store=store,
1841                )
1842                pg = c10d._get_default_group()
1843                self.assertEqual(pg.rank(), self.rank)
1844                self.assertEqual(pg.size(), self.world_size)
1845                self.assertEqual(pg.name(), str(backend))
1846
1847                dist.destroy_process_group()
1848
1849    def _call_collective_with_varying_tensors(self, backend, collective, *args):
1850        # call collective with varying tensors to ensure that the tensors are
1851        # correctly dispatched
1852
1853        # TODO: this will be updated in the future to not be backend specific
1854        device = "cuda" if backend == "nccl" else "cpu"
1855        # ensure supported devices (cpu, cuda) succeeds during dispatch call
1856        tensor = torch.zeros(2, 2, device=torch.device(device))
1857        # multi tensor collectives
1858        if collective == dist.barrier:
1859            collective()
1860        elif collective in (dist.all_gather, dist.gather):
1861            collective([tensor], tensor, *args)
1862        elif collective == dist.scatter:
1863            collective(tensor, [tensor], *args)
1864        elif collective in (dist.reduce_scatter, dist.all_to_all):
1865            # gloo does not support reduce_scatter or all_to_all
1866            if backend != "gloo":
1867                if collective == dist.reduce_scatter:
1868                    collective(tensor, [tensor], *args)
1869                else:
1870                    collective([tensor], [tensor], *args)
1871        else:
1872            collective(tensor, *args)
1873
1874    # TODO: backend will be replaced with a non specified backend
1875    def _test_collectives(self, backend):
1876        store = dist.FileStore(self.file_name, self.world_size)
1877        dist.init_process_group(
1878            backend,
1879            world_size=self.world_size,
1880            rank=self.rank,
1881            store=store,
1882        )
1883        collectives_and_args = [
1884            (dist.reduce, self.rank),
1885            (dist.broadcast, self.rank),
1886            (dist.all_reduce,),
1887            (dist.all_gather,),
1888            (dist.reduce_scatter,),
1889            (dist.barrier,),
1890            (dist.all_to_all,),
1891            (dist.scatter,),
1892        ]
1893        for collective, *args in collectives_and_args:
1894            with self.subTest(collective=collective, args=args):
1895                self._call_collective_with_varying_tensors(backend, collective, *args)
1896
1897    def _test_allreduce_coalesced(self, backend):
1898        store = dist.FileStore(self.file_name, self.world_size)
1899        dist.init_process_group(
1900            backend,
1901            world_size=self.world_size,
1902            rank=self.rank,
1903            store=store,
1904        )
1905        # TODO: this will be updated in the future to not be backend specific
1906        device = "cuda" if backend == "nccl" else "cpu"
1907        tensors = [torch.ones(10, 10, device=torch.device(device))]
1908        dist.all_reduce_coalesced(tensors, dist.ReduceOp.SUM)
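        # with ReduceOp.SUM every element should equal world_size after the
        # coalesced allreduce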
1909        for tensor in tensors:
1910            self.assertEqual(tensor, torch.ones(10, 10) * self.world_size)
1911
1912    def _test_all_to_all_single(self, backend):
1913        store = dist.FileStore(self.file_name, self.world_size)
1914        dist.init_process_group(
1915            backend,
1916            world_size=self.world_size,
1917            rank=self.rank,
1918            store=store,
1919        )
1920        device = "cuda" if backend == "nccl" else "cpu"
1921        # test all_to_all_single, which dispatches to alltoall_base
1922        input_tensor = torch.ones(2, 2, device=torch.device(device))
1923        output_tensor = torch.zeros(2, 2, device=torch.device(device))
1924        dist.all_to_all_single(output_tensor, input_tensor)
1925
1926
1927class ReduceOpTest(TestCase):
1928    # Ref: https://github.com/pytorch/pytorch/issues/87191
1929    def test_op_isinstance_of_reduceop(self):
1930        for reduce_op in (
1931            c10d.ReduceOp.SUM,
1932            c10d.ReduceOp.AVG,
1933            c10d.ReduceOp.PRODUCT,
1934            c10d.ReduceOp.MIN,
1935            c10d.ReduceOp.MAX,
1936            c10d.ReduceOp.BAND,
1937            c10d.ReduceOp.BOR,
1938            c10d.ReduceOp.BXOR,
1939        ):
1940            self.assertTrue(isinstance(reduce_op, c10d.ReduceOp))
1941        for scale in (torch.tensor(1.0), 2.0):
1942            self.assertTrue(
1943                isinstance(dist._make_nccl_premul_sum(scale), c10d.ReduceOp)
1944            )
1945
1946    # Ref: https://github.com/pytorch/pytorch/pull/87303#discussion_r1002879700
1947    def test_reduceop_copyable(self):
1948        for reduce_op in (
1949            c10d.ReduceOp.SUM,
1950            c10d.ReduceOp.AVG,
1951            c10d.ReduceOp.PRODUCT,
1952            c10d.ReduceOp.MIN,
1953            c10d.ReduceOp.MAX,
1954            c10d.ReduceOp.BAND,
1955            c10d.ReduceOp.BOR,
1956            c10d.ReduceOp.BXOR,
1957        ):
1958            self.assertEqual(copy.copy(reduce_op), reduce_op)
1959            self.assertEqual(copy.deepcopy(reduce_op), reduce_op)
1960            self.assertEqual(copy.copy(c10d.ReduceOp(reduce_op)), reduce_op)
1961            self.assertEqual(copy.deepcopy(c10d.ReduceOp(reduce_op)), reduce_op)
1962
1963        for scale in (torch.tensor(1.0), 2.0):
1964            reduce_op = dist._make_nccl_premul_sum(scale)
1965            self.assertEqual(copy.copy(reduce_op), reduce_op)
1966            self.assertEqual(copy.deepcopy(reduce_op), reduce_op)
1967
1968    def test_reduceop_pickle(self):
1969        for reduce_op in (
1970            c10d.ReduceOp.SUM,
1971            c10d.ReduceOp.AVG,
1972            c10d.ReduceOp.PRODUCT,
1973            c10d.ReduceOp.MIN,
1974            c10d.ReduceOp.MAX,
1975            c10d.ReduceOp.BAND,
1976            c10d.ReduceOp.BOR,
1977            c10d.ReduceOp.BXOR,
1978        ):
1979            pickle.loads(pickle.dumps(reduce_op))
1980            orig = c10d.ReduceOp(reduce_op)
1981            self.assertEqual(pickle.loads(pickle.dumps(orig)), orig)
1982        for scale in (torch.tensor(1.0), 2.0):
1983            reduce_op = dist._make_nccl_premul_sum(scale)
1984            self.assertEqual(pickle.loads(pickle.dumps(reduce_op)), reduce_op)
1985
1986    # Ref: https://github.com/pytorch/pytorch/issues/90072
1987    def test_reduceop_equal(self):
1988        not_reduceop = "abc"
1989        for reduce_op in (
1990            c10d.ReduceOp.SUM,
1991            c10d.ReduceOp.AVG,
1992            c10d.ReduceOp.PRODUCT,
1993            c10d.ReduceOp.MIN,
1994            c10d.ReduceOp.MAX,
1995            c10d.ReduceOp.BAND,
1996            c10d.ReduceOp.BOR,
1997            c10d.ReduceOp.BXOR,
1998        ):
1999            reduce_op_obj = c10d.ReduceOp(reduce_op)
2000            # this calls `ReduceOp.__eq__(self, other)`
2001            self.assertEqual(reduce_op_obj, reduce_op_obj)
2002            self.assertEqual(reduce_op_obj, reduce_op)
2003            self.assertNotEqual(reduce_op_obj, not_reduceop)
2004            self.assertNotEqual(reduce_op, not_reduceop)
2005            # TODO(crcrpar): This needs to be `assertEqual` for symmetry, even though
2006            # the comparison of `RedOpType` and `ReduceOp` sounds less likely to happen
2007            # compared to that of `ReduceOp` and `RedOpType`.
2008            # this calls `RedOpType.__eq__(self, other)`
2009            self.assertNotEqual(reduce_op, reduce_op_obj)
2010
2011            self.assertFalse(None in (reduce_op, reduce_op_obj))
2012            self.assertFalse(not_reduceop in (reduce_op, reduce_op_obj))
2013
2014
2015class LocalRankTest(MultiProcessTestCase):
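    # dist.get_node_local_rank() reads the LOCAL_RANK environment variable
    # (typically set by launchers such as torchrun); when it is unset, it
    # returns fallback_rank if one was provided and raises otherwise.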
2016    @property
2017    def world_size(self):
2018        return 4
2019
2020    def setUp(self):
2021        super().setUp()
2022        self._spawn_processes()
2023
2024    def tearDown(self):
2025        super().tearDown()
2026        try:
2027            os.remove(self.file_name)
2028        except OSError:
2029            pass
2030
2031    def testWithoutEnv(self):
2032        with self.assertRaisesRegex(RuntimeError, "LOCAL_RANK"):
2033            dist.get_node_local_rank()
2034
2035    def testWithoutEnvWithFallback(self):
2036        self.assertEqual(dist.get_node_local_rank(fallback_rank=2), 2)
2037
2038    def testNodeLocalRankOverridesFallback(self):
2039        os.environ["LOCAL_RANK"] = str(self.rank)
2040        self.assertEqual(dist.get_node_local_rank(fallback_rank=123), self.rank)
2041
2042    def testNodeLocalRank(self):
2043        os.environ["LOCAL_RANK"] = str(self.rank)
2044        self.assertEqual(dist.get_node_local_rank(), self.rank)
2045
2046
2047if __name__ == "__main__":
2048    assert (
2049        not torch.cuda._initialized
2050    ), "test_distributed must not have initialized CUDA context on main process"
2051
2052    run_tests()
2053