# mypy: allow-untyped-defs

import copy
import json
import itertools
import math
import os
import random
import sys
import tempfile
import time
from collections import namedtuple, OrderedDict, defaultdict
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from functools import reduce
from typing import Union, NamedTuple, Callable, Any
import unittest
import numpy as np
import torch
import torch.cuda
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils
import torch.nn as nn
import torch.nn.functional as F
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
from torch.utils._python_dispatch import TorchDispatchMode
from torch.autograd import DeviceType
from torch.cuda.amp import GradScaler, autocast

from torch.distributed.algorithms.ddp_comm_hooks import (
    post_localSGD_hook as post_localSGD,
    powerSGD_hook as powerSGD,
    default_hooks as default,
    quantization as quantization_hooks,
)
from torch.distributed.optim import _apply_optimizer_in_backward

from torch.distributed.distributed_c10d import (
    get_world_size,
    _get_default_group,
    _get_pg_config,
)
from torch.distributed.utils import (
    _verify_param_shape_across_processes,
    _sync_module_states,
)
from torch.profiler import (
    ExecutionTraceObserver,
    ProfilerActivity,
)

from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars, _MixedPrecision
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    TEST_SKIPS,
    init_multigpu_helper,
    initialize_temp_directories,
    cleanup_temp_dir,
    simple_sparse_reduce_tests,
    skip_if_rocm_multiprocess,
    skip_if_small_worldsize,
    skip_if_odd_worldsize,
    skip_if_lt_x_gpu,
    nccl_skip_if_lt_x_gpu,
    skip_if_no_gpu,
    require_n_gpus_for_nccl_backend,
    requires_nccl_version,
    captured_output,
    with_nccl_blocking_wait,
    with_dist_debug_levels,
    verify_ddp_error_logged,
    DistTestCases,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_MACOS,
    IS_WINDOWS,
    FILE_SCHEMA,
    IS_FBCODE,
    NO_MULTIPROCESSING_SPAWN,
    IS_SANDCASTLE,
    skip_but_pass_in_sandcastle,
    skip_but_pass_in_sandcastle_if,
)

import torch.distributed.optim.post_localSGD_optimizer as post_localSGD_optimizer

from torch.utils.data.distributed import DistributedSampler
import operator

try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False

if sys.platform == "win32":
    import msvcrt
else:
    import fcntl


class NetWithBuffers(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = nn.Linear(10, 10, bias=False)
        self.b = nn.Linear(10, 1, bias=False)
        self.register_buffer("buffer", torch.randn(1, 2))

    def forward(self, x):
        self.buffer.add_(1)
        return self.b(self.a(x))


class Foo:
    def __init__(self, x):
        # Can be tensor or int
        self.x = x

    def __eq__(self, other):
        def eq(value, other):
            if isinstance(value, torch.Tensor):
                return torch.equal(value, other)
            return value == other

        for attr, value in self.__dict__.items():
            other_value = other.__dict__[attr]
            if not eq(value, other_value):
                return False
        return True


f = Foo(10)
f.bar = 1

foo_cpu_tensor = Foo(torch.randn(3, 3))


COLLECTIVES_OBJECT_TEST_LIST = [
    {"key1": 3, "key2": 4, "key3": {"nested": True}},
    f,
    foo_cpu_tensor,
    "foo",
    [1, 2, True, "string", [4, 5, "nested"]],
]

# Allowlist of distributed backends where profiling collectives is supported.
PROFILING_SUPPORTED_BACKENDS = [
    dist.Backend.NCCL,
    dist.Backend.GLOO,
    dist.Backend.MPI,
    dist.Backend.UCC,
]

# Allowlist of distributed backends where profiling is supported with use_cuda=True
CUDA_PROFILING_SUPPORTED_BACKENDS = [
    dist.Backend.GLOO,
    dist.Backend.MPI,
    dist.Backend.NCCL,
    dist.Backend.UCC,
]

# Allowlist of distributed backends where profiling is supported for p2p ops
SEND_RECV_PROFILING_SUPPORTED_BACKENDS = [
    dist.Backend.MPI,
    dist.Backend.GLOO,
    dist.Backend.NCCL,
    dist.Backend.UCC,
]

# Dummy NamedTuple data structures to test DDP support for NamedTuple types.
EXPECTED_FIELDS = ("a", "b")
TestNamedTupleInput_0 = namedtuple("NamedTuple", EXPECTED_FIELDS)


class TestNamedTupleInput_1(NamedTuple):
    a: torch.Tensor
    b: torch.Tensor


skipIfNoTorchVision = skip_but_pass_in_sandcastle_if(
    not HAS_TORCHVISION, "no torchvision"
)

BACKEND = os.environ["BACKEND"]
INIT_METHOD = os.getenv("INIT_METHOD", "env://")

DEFAULT_TIMEOUT = 300
CUSTOMIZED_TIMEOUT = {"test_DistributedDataParallel": 500}


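# Return profiler events whose name starts or ends with `event_name`. With
# dedup_gpu_user_annotation=True, CUDA-side user annotations are filtered out
# so each collective is only counted once.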
def get_profiling_event(event_name, profiler, dedup_gpu_user_annotation=False):
    event_list = (
        profiler.events()
        if isinstance(profiler, torch.profiler.profile)
        else profiler.function_events
    )
    return [
        event for event in event_list
        if (
            (event.name.endswith(event_name) or event.name.startswith(event_name))
            and (not dedup_gpu_user_annotation or event.device_type != DeviceType.CUDA)
        )
    ]

def get_profiler_nccl_meta(prof):
    """The torch profiler records NCCL metadata in an inserted operator called
    "record_param_comms". We test the metadata obtained from the profiler here."""
    tf = tempfile.NamedTemporaryFile(
        mode="w+t", suffix=".json", delete=False
    )
    tf.close()
    trace_file = tf.name

    prof.export_chrome_trace(trace_file)
    with open(trace_file) as f:
        events = json.load(f)["traceEvents"]
    print(f"Trace saved to {trace_file}")

    # Comment out the next line to keep the trace file around for debugging.
    os.remove(trace_file)

    return [e for e in events if e.get("name") == "record_param_comms"]

# Base error message substring for unfinished reductions.
ddp_prev_reduction_unfinished_str = (
    "Expected to have finished reduction in the prior iteration"
)
# Error message substring when find_unused_parameters=True has not been passed
ddp_recommend_find_unused_params_str = (
    "passing the keyword argument `find_unused_parameters=True`"
)
# Error message substring when find_unused_parameters=True is enabled
ddp_find_unused_params_enabled_str = "Since `find_unused_parameters=True` is enabled"
# Error message substring for possibility of not all model outputs being used
# in loss computation
ddp_outputs_not_used_in_loss_str = (
    "`forward` function outputs participate in calculating loss"
)
# Error message substring suggesting to use TORCH_DISTRIBUTED_DEBUG
ddp_suggest_debug_mode_str = (
    "set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL"
)


class DDPUnevenTestInput(NamedTuple):
    name: str
    model: nn.Module
    inp: Union[torch.Tensor, tuple]
    sync_interval: int
    throw_on_early_termination: bool = False
    hook: Callable = None
    state: Any = None


class _FC2(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(10, 50, bias=True)
        self.fc.bias.requires_grad = False

    def forward(self, x):
        x = self.fc(x)
        return x


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(2, 10, bias=False)
        self.fc2 = _FC2()
        self.fc3 = nn.Linear(50, 4, bias=False)
        self.relu = nn.ReLU()
        self.no_grad_param = nn.Parameter(
            torch.tensor([2, 2]).long(), requires_grad=False
        )

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return F.softmax(x, dim=1)


class LargeNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(1000, 2000, bias=False)
        self.fc2 = nn.Linear(2000, 500, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x


class Task(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.p = nn.Parameter(torch.ones(2, 2))

    def forward(self, x):
        return self.p + x


class BatchNormNet(nn.Module):
    def __init__(self, affine=True):
        super().__init__()
        self.fc1 = nn.Linear(2, 40, bias=False)
        self.bn = nn.BatchNorm1d(4, affine=affine)
        self.fc2 = nn.Linear(40, 4, bias=False)

    def forward(self, x):
        x = torch.reshape(self.fc1(x), (-1, 4, 10))
        x = self.bn(x)
        x = torch.reshape(x, (-1, 40))
        x = self.fc2(x)
        return F.softmax(x, dim=1)


class UnusedParamTwoLinLayerNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = nn.Linear(10, 10, bias=False)
        self.b = nn.Linear(10, 10, bias=False)
        self.c = nn.Linear(5, 5, bias=False)

    def forward(self, x):
        a = self.a(x)
        b = self.b(x)
        return (a, b)


class DictOutputModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.module = UnusedParamTwoLinLayerNet()

    def forward(self, x):
        predictions = self.module(x)
        loss = (predictions[0] + predictions[1]).sum()
        return {
            "predictions": predictions,
            "loss": loss,
        }


class TwoLinLayerNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = nn.Linear(10, 10, bias=False)
        self.b = nn.Linear(10, 1, bias=False)

    def forward(self, x):
        a = self.a(x)
        b = self.b(x)
        return (a, b)


class EmbeddingNetDifferentParams(nn.Module):
    """
    A module containing an embedding with different dimension or different # of
    parameters depending on the rank.
    """

    def __init__(self, rank, diff_num_params=False):
        super().__init__()
        embedding_dim = 500 if diff_num_params or rank == 0 else 50
        self.embedding = nn.Embedding(num_embeddings=10, embedding_dim=embedding_dim)
        self.lin = nn.Linear(embedding_dim, 1)
        if diff_num_params:
            self.lin2 = nn.Linear(1, 1, bias=False)

    def forward(self, x):
        x = self.embedding(x)
        return self.lin(x)


class ControlFlowToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin1 = nn.Linear(10, 10, bias=False)
        self.lin2 = nn.Linear(10, 10, bias=False)

    def forward(self, x):
        # Whether the second layer is used depends on the input x.
        use_second_layer = torch.equal(x, torch.ones(20, 10, device=x.device))
        if use_second_layer:
            return self.lin2(F.relu(self.lin1(x)))
        else:
            return F.relu(self.lin1(x))


DDP_NET = Net()
BN_NET = BatchNormNet()
BN_NET_NO_AFFINE = BatchNormNet(affine=False)
ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)


def get_timeout(test_id):
    test_name = test_id.split(".")[-1]
    if test_name in CUSTOMIZED_TIMEOUT:
        return CUSTOMIZED_TIMEOUT[test_name]
    else:
        return DEFAULT_TIMEOUT


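# Default process group timeout (in seconds) for tests that have no entry in
# CUSTOM_PG_TIMEOUT below.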
default_pg_timeout = 60

CUSTOM_PG_TIMEOUT = {
    # This test runs slowly and needs additional time to complete; otherwise it
    # can be taken down by TORCH_NCCL_ASYNC_ERROR_HANDLING.
    "test_ddp_uneven_inputs": 300,
    # This test has a short timeout since it tests being taken down by
    # TORCH_NCCL_ASYNC_ERROR_HANDLING which we want to happen quickly.
    "test_ddp_model_diff_across_ranks": 5,
    # This test has a short timeout since it tests being taken down by
    # TORCH_NCCL_ASYNC_ERROR_HANDLING which we want to happen quickly.
    "test_ddp_has_finalized": 5,
}

def require_backend_is_available(backends):
    def check(backend):
        if backend == dist.Backend.GLOO:
            return dist.is_gloo_available()
        if backend == dist.Backend.NCCL:
            return dist.is_nccl_available()
        if backend == dist.Backend.MPI:
            return dist.is_mpi_available()
        if backend == dist.Backend.UCC:
            return dist.is_ucc_available()
        if backend in DistTestCases.backend_feature["plugin"]:
            return True
        return False

    if BACKEND not in backends:
        return skip_but_pass_in_sandcastle(
            f"Test requires backend {BACKEND} to be one of {backends}"
        )

    if not check(dist.Backend(BACKEND)):
        return skip_but_pass_in_sandcastle(
            f"Test requires backend {BACKEND} to be available"
        )
    return lambda func: func


def require_world_size(world_size):
    if int(os.environ["WORLD_SIZE"]) < world_size:
        return skip_but_pass_in_sandcastle(
            "Test requires world size of %d" % world_size
        )
    return lambda func: func


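# Cross-process exclusive lock on a shared lockfile (msvcrt on Windows, fcntl
# elsewhere); used to serialize access to the barrier directory below.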
@contextmanager
def _lock():
    TEMP_DIR = os.environ["TEMP_DIR"]
    lockfile = os.path.join(TEMP_DIR, "lockfile")
    with open(lockfile, "w") as lf:
        try:
            if sys.platform == "win32":
                msvcrt.locking(lf.fileno(), msvcrt.LK_RLCK, 1)
                yield
            else:
                fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
                yield
        finally:
            if sys.platform == "win32":
                msvcrt.locking(lf.fileno(), msvcrt.LK_UNLCK, 1)
            else:
                fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
            lf.close()


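# Rank 0 creates a temporary file and broadcasts its name to all ranks; the
# file is removed by rank 0 when the context exits.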
@contextmanager
def _rank_temp_file():
    if dist.get_rank() == 0:
        fd, name = tempfile.mkstemp()
        os.close(fd)
    else:
        name = None
    object_list = [name]
    dist.broadcast_object_list(object_list)
    name = object_list[0]
    try:
        yield name
    finally:
        if dist.get_rank() == 0:
            os.remove(name)


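# Build a size x size x size tensor filled with `value` (defaults to `size`),
# optionally placed on the given CUDA device.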
def _build_tensor(size, value=None, dtype=torch.float, device_id=None):
    if value is None:
        value = size
    if device_id is None:
        return torch.empty(size, size, size, dtype=dtype).fill_(value)
    else:
        return torch.empty(size, size, size, dtype=dtype).fill_(value).cuda(device_id)


def _build_multidim_tensor(dim, dim_size, value=None, dtype=torch.float):
    if value is None:
        value = dim
    return torch.empty(size=[dim_size for _ in range(dim)], dtype=dtype).fill_(value)


def _create_autograd_profiler():
    return torch.autograd.profiler.profile(record_shapes=True)


def _create_torch_profiler():
    return torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
        ],
        record_shapes=True,
    )


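# File-based barrier shared across the test processes. Unlike dist.barrier(),
# it keeps working even after the default process group has been destroyed.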
class Barrier:
    barrier_id = 0

    @classmethod
    def init(cls):
        cls.barrier_id = 0
        barrier_dir = os.path.join(os.environ["TEMP_DIR"], "barrier")
        for f_name in os.listdir(barrier_dir):
            os.unlink(os.path.join(barrier_dir, f_name))

    @classmethod
    def sync(cls, wait_for=None, timeout=10):
        if wait_for is None:
            wait_for = dist.get_world_size()
        cls.barrier_id += 1
        barrier_dir = os.path.join(os.environ["TEMP_DIR"], "barrier")
        pid = str(os.getpid())
        barrier_file = os.path.join(barrier_dir, pid)
        with _lock():
            with open(barrier_file, "w") as f:
                f.write(str(cls.barrier_id))

        start_time = time.time()
        while True:
            arrived = 0
            with _lock():
                for f_name in os.listdir(barrier_dir):
                    with open(os.path.join(barrier_dir, f_name)) as f:
                        data = f.read()
                        if int(data) >= cls.barrier_id:
                            arrived += 1
            if arrived == wait_for:
                break

            if time.time() - start_time > timeout:
                raise RuntimeError("barrier timeout")
            time.sleep(0.1)


class TestDistBackend(MultiProcessTestCase):
    @classmethod
    def setUpClass(cls):
        os.environ["MASTER_ADDR"] = str(MASTER_ADDR)
        # MASTER_PORT is not set here; a random free port is used instead.
        super().setUpClass()

    def setUp(self):
        super().setUp()
        # initialize temp directories
        initialize_temp_directories()
        # initialize Barrier
        Barrier.init()
        # Skip return code checking for the following tests, as they are expected
        # to crash a process due to TORCH_NCCL_ASYNC_ERROR_HANDLING.
        self.skip_return_code_checks = [self.test_ddp_has_finalized.__wrapped__]

    def tearDown(self):
        cleanup_temp_dir()
        super().tearDown()

    @property
    def init_method(self):
        return f"{FILE_SCHEMA}{self.file_name}"

    @classmethod
    def _run(cls, rank, test_name, file_name, pipe, **kwargs):
        if BACKEND == "nccl" and not torch.cuda.is_available():
            sys.exit(TEST_SKIPS["no_cuda"].exit_code)
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name

        if torch.cuda.is_available() and torch.cuda.device_count() < int(
            self.world_size
        ):
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        try:
            pg_timeout_seconds = CUSTOM_PG_TIMEOUT.get(test_name, default_pg_timeout)
            timeout = timedelta(seconds=pg_timeout_seconds)
            dist.init_process_group(
                init_method=self.init_method,
                backend=BACKEND,
                world_size=int(self.world_size),
                rank=self.rank,
                timeout=timeout,
            )
        except RuntimeError as e:
            if "recompile" in e.args[0]:
                sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)

            raise

        # Execute a barrier prior to running the test to ensure that every process
        # has finished initialization and that a test exiting immediately because
        # of a skip doesn't cause flakiness.
        self._barrier()

        self.run_test(test_name, pipe)
        self._barrier()
        dist.destroy_process_group()
        sys.exit(0)

    # Needed since MultiProcessTestCase assumes a world_size of 4, but we
    # run these tests under various other world sizes.
    @property
    def world_size(self):
        return os.environ["WORLD_SIZE"]


class DistributedTest:
    class _DistTestBase:
        def _barrier(self, *args, **kwargs):
            Barrier.sync(*args, **kwargs)

        def _init_group_test(self, **kwargs):
            group = [1, 2]
            group_id = dist.new_group(group, **kwargs)
            rank = dist.get_rank()
            if rank not in group:
                return ([], None, rank)

            return (group, group_id, rank)

        def _init_full_group_test(self, **kwargs):
            group = list(range(0, dist.get_world_size()))
            group_id = dist.new_group(**kwargs)
            rank = dist.get_rank()
            return (group, group_id, rank)

        def _init_global_test(self):
            group = list(range(0, dist.get_world_size()))
            group_id = dist.group.WORLD
            rank = dist.get_rank()
            return (group, group_id, rank)

        def _verify_buffers_equal(self, m1, m2):
            # verify buffers across models
            m1_buf_dict = dict(m1.module.named_buffers())
            for name, buf in m2.module.named_buffers():
                self.assertEqual(buf, m1_buf_dict[name])

            # Verify buffers across ranks.
            m1_buffers = list(m1.buffers())
            m2_buffers = list(m2.buffers())
            for (buf1, buf2) in zip(m1_buffers, m2_buffers):
                gathered_bufs = [
                    torch.empty_like(buf1) for _ in range(dist.get_world_size())
                ]
                dist.all_gather(gathered_bufs, buf1)
                gathered_bufs_m2 = [
                    torch.empty_like(buf2) for _ in range(dist.get_world_size())
                ]
                for b in gathered_bufs:
                    self.assertEqual(b, buf1)
                dist.all_gather(gathered_bufs_m2, buf2)
                for b in gathered_bufs_m2:
                    self.assertEqual(b, buf2)

        def _sanity_check_profiler_nccl_meta(self, nccl_meta_events):
            """The torch profiler records NCCL metadata in an inserted operator called
            "record_param_comms". We test the basic fields of this profiler event that
            correspond to the NCCL communication collectives."""
            per_coll_meta = defaultdict(list)
            for e in nccl_meta_events:
                args = e.get("args", {})
                collname = args.get("Collective name", "")
                self.assertNotEqual(collname, "")
                self.assertNotEqual(args.get("dtype", ""), "")

                per_coll_meta[collname].append(args)
                if collname in {"wait"}:
                    continue

                self.assertEqual(args["Process Group Description"], "default_pg")
                self.assertNotEqual(args["Process Group Ranks"], "")

                self.assertGreaterEqual(args.get("In msg nelems", -1), 0)
                self.assertGreaterEqual(args.get("Out msg nelems", -1), 0)
                self.assertGreaterEqual(args.get("Group size", -1), 0)
                self.assertGreaterEqual(args.get("Global rank start", -1), 0)
                self.assertGreaterEqual(args.get("Global rank stride", -1), 0)

            # print(per_coll_meta)
            return per_coll_meta

        def test_dump_DDP_relevant_env_vars(self):
            with captured_output() as (out, _):
                _dump_DDP_relevant_env_vars()
                lines = out.getvalue().splitlines()

            def format_line(var):
                return f"env:{var}={os.environ[var] if var in os.environ else 'N/A'}"

            # Check relevant env vars
            vars = [
                "MASTER_ADDR",
                "MASTER_PORT",
                "WORLD_SIZE",
                "NCCL_TOPO_DUMP_FILE",  # N/A
                "TORCH_NCCL_ASYNC_ERROR_HANDLING",
            ]
            for var in vars:
                line = format_line(var)
                self.assertIn(line, lines)
            # Check irrelevant env vars
            vars = [
                "xxx",
                "yyy",
                "zzz",
            ]
            for var in vars:
                line = format_line(var)
                self.assertNotIn(line, lines)

        # GET RANK
        def test_get_rank(self):
            test_dir = os.path.join(os.environ["TEMP_DIR"], "test_dir")
            pid = str(os.getpid())
            num_processes = dist.get_world_size()
            with open(os.path.join(test_dir, pid), "w") as f:
                f.write(str(dist.get_rank()))

            self._barrier()

            all_ranks = set()
            for f_name in os.listdir(test_dir):
                with open(os.path.join(test_dir, f_name)) as f:
                    all_ranks.add(int(f.read()))
            self.assertEqual(len(all_ranks), num_processes)

            self._barrier()

            if dist.get_rank() == 0:
                for f_name in os.listdir(test_dir):
                    os.unlink(os.path.join(test_dir, f_name))

            self._barrier()

        def test_get_backend(self):
            if dist.get_world_size() > 2:
                group = [1, 2]
            else:
                group = [0, 1]
            group_id = dist.new_group(group)
            backend_str = BACKEND.lower()
            self.assertEqual(dist.get_backend(), backend_str)
            if dist.get_rank() in group:
                self.assertEqual(dist.get_backend(group_id), backend_str)
            else:
                with self.assertRaisesRegex(
                    ValueError, "Invalid process group specified"
                ):
                    dist.get_backend(group_id)

        def test_Backend_enum_class(self):
            # test parsing
            backend = BACKEND.lower()
            self.assertEqual(dist.Backend(BACKEND.upper()), backend)
            self.assertEqual(dist.Backend(BACKEND), backend)
            with self.assertRaises(ValueError):
                dist.Backend(None)
            with self.assertRaises(ValueError):
                dist.Backend(3)
            with self.assertRaises(ValueError):
                dist.Backend(["gloo"])

        # Test destroy
        def test_destroy_group(self):
            if dist.get_world_size() > 2:
                group = [1, 2]
            else:
                group = [0, 1]
            group_id = dist.new_group(group)
            self._barrier()
            dist.destroy_process_group(group_id)

        # Test get rank and size of group
        def test_get_rank_size_group(self):
            if dist.get_world_size() > 2:
                group = [1, 2]
            else:
                group = [0, 1]
            group_id = dist.new_group(group)
            if dist.get_rank() in group:
                self.assertEqual(dist.get_world_size(group_id), 2)
                self.assertTrue(dist.get_rank(group_id) in list(range(2)))
            else:
                self.assertEqual(dist.get_world_size(group_id), -1)
                self.assertEqual(dist.get_rank(group_id), -1)

        # Test destroy full groups
        def test_destroy_full_group(self):
            _, group_id, _ = self._init_full_group_test()
            self._barrier()
            dist.destroy_process_group(group_id)

        # Test get rank and size of full group
        def test_get_rank_size_full_group(self):
            _, group_id, _ = self._init_full_group_test()
            self.assertEqual(dist.get_world_size(group_id), dist.get_world_size())
            self.assertEqual(dist.get_rank(group_id), dist.get_rank())

        def _test_barrier_timeout(self, group_id, timeout):
            local_rank = dist.get_rank(group_id)

            # Only execute the barrier on rank 0, causing it to time out.
            if local_rank == 0:
                expected_time = time.time() + timeout.total_seconds()
                # In debug mode, we execute a monitored_barrier before the
                # collective, so assert on that.
                if dist.get_debug_level() == dist.DebugLevel.DETAIL:
                    exception_ctx = self.assertRaisesRegex(
                        Exception, "failed to pass monitoredBarrier"
                    )
                else:
                    exception_ctx = self.assertRaisesRegex(
                        Exception, " (Timed out|closed|timeout) "
                    )
                with exception_ctx:
                    dist.barrier(group_id)
                self.assertGreaterAlmostEqual(time.time(), expected_time, delta=0.1)
            else:
                pass

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "gloo", "Only gloo backend supports timeouts"
        )
        @skip_but_pass_in_sandcastle_if(
            not INIT_METHOD.startswith("file://"),
            "Requires file:// initialization method. "
            + "Both tcp:// and env:// rely on the TCP store for which "
            "reinitialization has proven racy.",
        )
        def test_barrier_timeout_global(self):
            dist.destroy_process_group()

            # Explicitly pass world size to the barrier because we've
            # just destroyed any state in torch.distributed.
            self._barrier(wait_for=int(os.environ["WORLD_SIZE"]))

            # Reinitialize global process group
            timeout = timedelta(seconds=1)
            dist.init_process_group(
                init_method=INIT_METHOD,
                backend=BACKEND,
                world_size=int(os.environ["WORLD_SIZE"]),
                rank=self.rank,
                timeout=timeout,
            )
            self._test_barrier_timeout(dist.group.WORLD, timeout)

        @skip_if_small_worldsize
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "gloo", "Only gloo backend supports timeouts"
        )
        def test_barrier_timeout_group(self):
            timeout = timedelta(seconds=5)
            _, group_id, _ = self._init_group_test(timeout=timeout)
            if group_id is not None:
                self._test_barrier_timeout(group_id, timeout)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "gloo", "Only gloo backend supports timeouts"
        )
        def test_barrier_timeout_full_group(self):
            timeout = timedelta(seconds=1)
            _, group_id, _ = self._init_full_group_test(timeout=timeout)
            if group_id is not None:
                self._test_barrier_timeout(group_id, timeout)

        # This test helper can only be used when using the Gloo or NCCL backend
        # **and** both the Gloo and NCCL backends are available.
        # See the @skip annotations below.
        def _test_group_override_backend(self, initializer):
            if BACKEND == "gloo":
                new_backend = "nccl"
            elif BACKEND == "nccl":
                new_backend = "gloo"
            elif BACKEND in DistTestCases.backend_feature["plugin"]:
                new_backend = "gloo"

            group, group_id, rank = initializer(backend=new_backend)
            if group_id is None:
                return

            if new_backend == "gloo":
                self.assertEqual(group_id._get_backend_name(), "gloo")
            if new_backend == "nccl":
                self.assertEqual(group_id._get_backend_name(), "nccl")

            self.assertEqual(rank, group[dist.get_rank(group_id)])
            self.assertEqual(len(group), dist.get_world_size(group_id))

            # Pin device (so we avoid NCCL race conditions/deadlocks).
            group_rank = dist.get_rank(group_id)
            torch.cuda.set_device(group_rank)

            # Run broadcast of CUDA tensor (so it works for both Gloo and NCCL).
            tensor = _build_tensor(2, value=group_rank).cuda()
            dist.broadcast(tensor, src=group[0], group=group_id)
            self.assertEqual(_build_tensor(2, value=0), tensor.to("cpu"))

        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
        @require_world_size(3)
        @skip_if_lt_x_gpu(2)
        def test_backend_group(self):
            self._test_group_override_backend(self._init_group_test)

        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
        @skip_if_lt_x_gpu(2)
        @unittest.skipIf(BACKEND == "ucc", "broken, see https://github.com/pytorch/pytorch/pull/113620")
        def test_backend_full_group(self):
            self._test_group_override_backend(self._init_full_group_test)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @require_world_size(4)
        @skip_if_lt_x_gpu(2)
        def test_new_subgroups(self):
            subgroup_size = 2
            cur_subgroup, subgroups = dist.new_subgroups(subgroup_size)

            world_size = dist.get_world_size()
            self.assertEqual(cur_subgroup.size(), subgroup_size)
            self.assertEqual(len(subgroups), world_size / subgroup_size)
            self.assertFalse(dist._rank_not_in_group(cur_subgroup))

            for subgroup in subgroups:
                dist.destroy_process_group(subgroup)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @skip_if_no_gpu
        def test_new_subgroups_group_size_exceeds_world_size(self):
            with self.assertRaisesRegex(ValueError, "must not exceed"):
                dist.new_subgroups(100)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @require_world_size(4)
        @skip_if_lt_x_gpu(4)
        def test_new_subgroups_world_size_not_divisible_by_group_size(self):
            with self.assertRaisesRegex(
                ValueError, "The world size must be divisible by 'group_size'"
            ):
                dist.new_subgroups(3)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @require_world_size(4)
        @skip_if_lt_x_gpu(4)
        def test_new_subgroups_by_enumeration(self):
            group, group_id, rank = self._init_global_test()
            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
            device_id = rank_to_GPU[rank][0]
            cur_subgroup, subgroups = dist.new_subgroups_by_enumeration(
                ranks_per_subgroup_list=[[0, 2], [1, 3]]
            )
            if device_id >= 4:
                self.assertIsNone(cur_subgroup)
            else:
                self.assertEqual(cur_subgroup.size(), 2)
                self.assertEqual(len(subgroups), 2)
                if device_id == 0 or device_id == 2:
                    self.assertEqual(cur_subgroup, subgroups[0])
                else:
                    self.assertEqual(cur_subgroup, subgroups[1])

            for subgroup in subgroups:
                dist.destroy_process_group(subgroup)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @require_world_size(4)
        @skip_if_lt_x_gpu(4)
        def test_new_subgroups_by_enumeration_input_rank_exceeds_world_size(self):
            group, group_id, rank = self._init_global_test()
            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
            device_id = rank_to_GPU[rank][0]
            world_size = get_world_size(group_id)

            with self.assertRaisesRegex(
                RuntimeError,
                "The new group's rank should be within the world_size set by init_process_group",
            ):
                dist.new_subgroups_by_enumeration(
                    ranks_per_subgroup_list=[[0, 1], [world_size, 2]]
                )

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @skip_if_no_gpu
        def test_new_subgroups_by_enumeration_negative_input_rank(self):
            group, group_id, rank = self._init_global_test()

            with self.assertRaisesRegex(
                ValueError,
                "The new group's rank should be within the world_size set by init_process_group",
            ):
                dist.new_subgroups_by_enumeration(
                    ranks_per_subgroup_list=[[-1, -2], [-3, -4]]
                )

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @require_world_size(4)
        @skip_if_lt_x_gpu(4)
        def test_new_subgroups_overlap_not_allowed(self):
            with self.assertRaisesRegex(
                ValueError, "Rank 1 has appeared in both subgroup"
            ):
                dist.new_subgroups_by_enumeration(
                    ranks_per_subgroup_list=[[0], [1, 2], [1, 3]]
                )

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @skip_if_lt_x_gpu(2)
        def test_average_parameters(self):
            rank = dist.get_rank()
            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
            device_id = rank_to_GPU[rank][0]

            model = nn.Sequential(
                nn.Conv2d(3, 3, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Linear(1, 5, bias=False),
            ).cuda(device_id)
            # Test global model averaging
            for p in model.parameters():
                p.data = torch.ones_like(p.data)
            model_averaging_utils.average_parameters(
                params=model.parameters(), process_group=None
            )
            # Every element will be the same as the input.
            for p in model.parameters():
                self.assertEqual(p.data, torch.ones_like(p.data))

            # Test partial model averaging
            for p in model.parameters():
                p.data = torch.ones_like(p.data) * rank
            group_nccl = dist.new_group(ranks=[0, 1], backend="nccl")
            model_averaging_utils.average_parameters(
                params=model.parameters(), process_group=group_nccl
            )
            if not dist._rank_not_in_group(group_nccl):
                # Every element on device 0 or 1 should be the average of 0 and 1, i.e., 0.5.
                for p in model.parameters():
                    self.assertEqual(p.data, torch.ones_like(p.data) * 0.5)
            else:
                # Every element on device not in the subgroup should remain the same.
                for p in model.parameters():
                    self.assertEqual(p.data, torch.ones_like(p.data) * rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @skip_if_lt_x_gpu(2)
        def test_periodic_model_averager(self):
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
            device_id = rank_to_GPU[rank][0]

            model = nn.Linear(1, 5, bias=False).cuda(device_id)
            param = next(model.parameters())
            tensor = torch.ones_like(param.data) * rank
            expected_avg_tensor = (
                torch.ones_like(param.data) * sum(range(world_size)) / world_size
            )
            period = 4
            for warmup_steps in [12, 13, 14, 15]:
                averager = averagers.PeriodicModelAverager(
                    period=period, warmup_steps=warmup_steps
                )
                for step in range(0, 20):
                    # Reset the parameters at every step.
                    param.data = copy.deepcopy(tensor)
                    for params in model.parameters():
                        # mock grad
                        params.grad = torch.ones_like(param.data)
                    averager.average_parameters(model.parameters())
                    if step >= warmup_steps and (step - warmup_steps) % period == 0:
                        self.assertEqual(param.data, expected_avg_tensor)
                    else:
                        # No model averaging, so the parameters are not updated.
                        self.assertEqual(param.data, tensor)

        @skip_if_lt_x_gpu(2)
        def test_periodic_model_averager_param_group(self):
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
            device_id = rank_to_GPU[rank][0]

            model = nn.Linear(1, 5, bias=False).cuda(device_id)
            param = next(model.parameters())
            opt = torch.optim.SGD(model.parameters(), lr=0.1)

            period = 4
            for warmup_steps in [12, 13, 14, 15]:
                averager = averagers.PeriodicModelAverager(
                    period=period, warmup_steps=warmup_steps
                )
                for step in range(0, 20):
                    # Reset the parameters at every step.
                    for param_group in opt.param_groups:
                        for params in param_group["params"]:
                            # mock grad
                            params.grad = torch.ones_like(param.data) * rank
                            params.data = torch.ones_like(param.data) * rank
                    averager.average_parameters(opt.param_groups)
                    if step >= warmup_steps and (step - warmup_steps) % period == 0:
                        for param_group in opt.param_groups:
                            for params in param_group["params"]:
                                if params.grad is None:
                                    continue
                                self.assertEqual(
                                    param.data,
                                    torch.ones_like(param.data)
                                    * sum(range(world_size))
                                    / world_size,
                                )
                    else:
                        # No model averaging, so the parameters are not updated.
                        for param_group in opt.param_groups:
                            for params in param_group["params"]:
                                if params.grad is None:
                                    continue
                                self.assertEqual(
                                    param.data, torch.ones_like(param.data) * rank
                                )

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @skip_if_lt_x_gpu(2)
        def test_1_level_hierarchical_model_averager_equivalent_to_periodic_model_averager(
            self,
        ):
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
            device_id = rank_to_GPU[rank][0]

            model = nn.Linear(1, 5, bias=False).cuda(device_id)
            param = next(model.parameters())
            tensor = torch.ones_like(param.data) * rank
            expected_avg_tensor = (
                torch.ones_like(param.data) * sum(range(world_size)) / world_size
            )
            period = 4
            for warmup_steps in [12, 13, 14, 15]:
                averager = hierarchicalSGD.HierarchicalModelAverager(
                    # Run the global averaging at a period of 4,
                    # which is equivalent to the above periodic model averaging test case.
                    period_group_size_dict=OrderedDict([(period, world_size)]),
                    warmup_steps=warmup_steps,
                )

                averager = averagers.PeriodicModelAverager(
                    period=period, warmup_steps=warmup_steps
                )
                for step in range(0, 20):
                    # Reset the parameters at every step.
                    param.data = copy.deepcopy(tensor)
                    for params in model.parameters():
                        # mock grad
                        params.grad = torch.ones_like(param.data)
                    averager.average_parameters(model.parameters())
                    if step >= warmup_steps and (step - warmup_steps) % period == 0:
                        self.assertEqual(param.data, expected_avg_tensor)
                    else:
                        # No model averaging, so the parameters are not updated.
                        self.assertEqual(param.data, tensor)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @require_world_size(4)
        @skip_if_lt_x_gpu(4)
        def test_3_level_hierarchical_model_averager(self):
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
            device_id = rank_to_GPU[rank][0]

            model = nn.Linear(1, 5, bias=False).cuda(device_id)
            param = next(model.parameters())
            tensor = torch.ones_like(param.data) * rank
1237            # after the first 10 warmup steps,
1238            # run model averaging every 2 steps within each subgroup of size 2,
1239            # run model averaging every 4 steps within each subgroup of size 3,
1240            # and run the global model averaging every 8 steps.
1241            # If there is a conflict in model averaging at a step, only run the highest-level model averaging.
1242            warmup_steps = 10
1243            subgroup_size1 = 2
1244            subgroup_avg_period1 = 2
1245            subgroup_size2 = 4
1246            subgroup_avg_period2 = 4
1247            global_avg_period = 8
1248            period_group_size_dict = OrderedDict(
1249                [
1250                    (subgroup_avg_period1, subgroup_size1),
1251                    (subgroup_avg_period2, subgroup_size2),
1252                    (global_avg_period, world_size),
1253                ]
1254            )
1255            averager = hierarchicalSGD.HierarchicalModelAverager(
1256                period_group_size_dict=period_group_size_dict, warmup_steps=warmup_steps
1257            )
1258            self.assertEqual(dist.get_pg_count(), len(period_group_size_dict))
1259
1260            subgroup1 = averager.period_process_group_dict[subgroup_avg_period1]
1261            subgroup2 = averager.period_process_group_dict[subgroup_avg_period2]
1262            real_group_ranks_res1 = _get_pg_config(subgroup1)['ranks']
1263            real_group_ranks_res2 = _get_pg_config(subgroup2)['ranks']
1264
1265            expect_group_ranks_res1 = (
1266                rank // subgroup_size1 * subgroup_size1
1267                + np.array(list(range(subgroup_size1)))
1268            ).tolist()
1269            expect_group_ranks_res2 = (
1270                rank // subgroup_size2 * subgroup_size2
1271                + np.array(list(range(subgroup_size2)))
1272            ).tolist()
1273            self.assertEqual(real_group_ranks_res1, expect_group_ranks_res1)
1274            self.assertEqual(real_group_ranks_res2, expect_group_ranks_res2)
1275
1276            expected_avg_tensor_within_subgroup1 = (
1277                torch.ones_like(param.data)
1278                * sum(real_group_ranks_res1)
1279                / subgroup_size1
1280            )
1281            expected_avg_tensor_within_subgroup2 = (
1282                torch.ones_like(param.data)
1283                * sum(real_group_ranks_res2)
1284                / subgroup_size2
1285            )
1286            expected_global_avg_tensor = (
1287                torch.ones_like(param.data) * sum(range(world_size)) / world_size
1288            )
1289            for step in range(0, 25):
1290                # Reset the parameters at every step.
1291                param.data = copy.deepcopy(tensor)
1292                for params in model.parameters():
1293                    # mock grad
1294                    params.grad = torch.ones_like(param.data)
1295                averager.average_parameters(model.parameters())
1296                if step == 16 or step == 24:
1297                    # Run global model averaging when `step` can be divided by 8.
1298                    self.assertEqual(param.data, expected_global_avg_tensor)
1299                elif step == 12 or step == 20:
1300                    # Run model averaging within subgroup when `step` can be divided by 4 but not by 8.
1301                    self.assertEqual(param.data, expected_avg_tensor_within_subgroup2)
1302                elif step == 10 or step == 14 or step == 18 or step == 22:
1303                    # Run model averaging within subgroup when `step` can be divided by 2 but not by 4 or 8.
1304                    self.assertEqual(param.data, expected_avg_tensor_within_subgroup1)
1305                else:
1306                    # No model averaging, so the parameters are not updated.
1307                    self.assertEqual(param.data, tensor)
1308
1309        # Coalescing manager (sync mode)
1310        @skip_if_no_gpu
1311        @skip_but_pass_in_sandcastle_if(
1312            BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE,
1313            "Coalescing manager currently tests with NCCL only; internal test flaky"
1314        )
1315        def test_coalescing_manager(self):
1316            self._barrier()
1317            rank = dist.get_rank()
1318            world_size = dist.get_world_size()
1319            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
1320            device_id = rank_to_GPU[rank][0]
1321            torch.cuda.set_device(device_id)
1322            num_colls = 2
1323            size_per_coll = 8
1324            small_tensors = [
1325                torch.ones(size_per_coll, device=device_id) for _ in range(num_colls)
1326            ]
1327
1328            with dist._coalescing_manager():
1329                for i in range(num_colls):
1330                    dist.all_reduce(small_tensors[i])
1331
1332            big_tensor = torch.ones(num_colls * size_per_coll, device=device_id)
1333            dist.all_reduce(big_tensor)
1334
1335            for i in range(num_colls):
1336                self.assertEqual(
1337                    small_tensors[i],
1338                    big_tensor[i * size_per_coll : (i + 1) * size_per_coll]
1339                )
1340
1341            self._barrier()
1342
1343        # Coalescing manager (async mode)
1344        @skip_if_no_gpu
1345        @skip_but_pass_in_sandcastle_if(
1346            BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE,
1347            "Coalescing manager currently tests with NCCL only; internal test flaky"
1348        )
1349        def test_coalescing_manager_async(self):
1350            self._barrier()
1351            rank = dist.get_rank()
1352            world_size = dist.get_world_size()
1353            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
1354            device_id = rank_to_GPU[rank][0]
1355            torch.cuda.set_device(device_id)
1356            num_colls = 2
1357            size_per_coll = 8
1358            small_tensors = [
1359                torch.ones(size_per_coll, device=device_id) for _ in range(num_colls)
1360            ]
1361
1362            with dist._coalescing_manager(async_ops=True) as cm:
1363                for i in range(num_colls):
1364                    dist.all_reduce(small_tensors[i])
1365            cm.wait()
1366
1367            big_tensor = torch.ones(num_colls * size_per_coll, device=device_id)
1368            dist.all_reduce(big_tensor)
1369
1370            for i in range(num_colls):
1371                self.assertEqual(
1372                    small_tensors[i],
1373                    big_tensor[i * size_per_coll : (i + 1) * size_per_coll]
1374                )
1375
1376            self._barrier()
1377
1378        # NCCL Batch SEND RECV
1379        @skip_if_no_gpu
1380        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
1381        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
1382        def test_batch_isend_irecv_nccl(self):
1383            self._barrier()
1384            rank = dist.get_rank()
1385            world_size = dist.get_world_size()
1386            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
1387            device_id = rank_to_GPU[rank][0]
1388            torch.cuda.set_device(device_id)
1389            p2p_op_list = []
1390            recv_tensors = [None for _ in range(world_size)]
1391            expected_tensors = [None for _ in range(world_size)]
1392
1393            for val in ["1", "0"]:
1394                os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val
1395                for src in range(0, world_size):
1396                    send_tensor = _build_tensor(rank + 1, device_id=device_id).fill_(
1397                        src
1398                    )
1399                    recv_tensors[src] = _build_tensor(
1400                        src + 1, value=-1, device_id=device_id
1401                    ).fill_(-1)
1402                    expected_tensors[src] = _build_tensor(
1403                        src + 1, value=-1, device_id=device_id
1404                    ).fill_(rank)
1405                    recv_op = dist.P2POp(dist.irecv, recv_tensors[src], src)
1406                    p2p_op_list.append(recv_op)
1407                    send_op = dist.P2POp(dist.isend, send_tensor, src)
1408                    p2p_op_list.append(send_op)
1409
1410                reqs = dist.batch_isend_irecv(p2p_op_list)
1411                for req in reqs:
1412                    req.wait()
1413
1414                for src in range(0, world_size):
1415                    self.assertEqual(recv_tensors[src], expected_tensors[src])
1416
1417            self._barrier()
1418
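            # Ring exchange: every rank batches an isend to (rank + 1) % world_size
            # with an irecv from (rank - 1) % world_size in a single call.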
1419        @skip_if_no_gpu
1420        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
1421        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
1422        def test_batch_isend_irecv_ring_exchange_nccl(self):
1423            self._barrier()
1424            rank = dist.get_rank()
1425            world_size = dist.get_world_size()
1426            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
1427            device_id = rank_to_GPU[rank][0]
1428            torch.cuda.set_device(device_id)
1429            p2p_op_list = []
1430
1431            send_tensor = _build_tensor(world_size, device_id=device_id)
1432            recv_tensor = _build_tensor(world_size, value=-1, device_id=device_id)
1433            send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
1434            recv_op = dist.P2POp(
1435                dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size
1436            )
1437            reqs = dist.batch_isend_irecv([send_op, recv_op])
1438            for req in reqs:
1439                req.wait()
1440
1441            self._barrier()
1442
1443        @skip_if_no_gpu
1444        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
1445        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
1446        def test_batch_isend_irecv_self_nccl(self):
1447            self._barrier()
1448            # Ensure the process group has been fully initialized (needed by
1449            # the first sub-group batch_isend_irecv call)
1450            dist.barrier()
1451            rank = dist.get_rank()
1452            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
1453            device_id = rank_to_GPU[rank][0]
1454            p2p_op_list = []
1455
1456            if rank == 0:
1457                send_tensor = _build_tensor(rank + 1, device_id=device_id)
1458                recv_tensor = _build_tensor(rank + 1, value=-1, device_id=device_id)
1459                recv_op = dist.P2POp(dist.irecv, recv_tensor, 0)
1460                p2p_op_list.append(recv_op)
1461                send_op = dist.P2POp(dist.isend, send_tensor, 0)
1462                p2p_op_list.append(send_op)
1463
1464                reqs = dist.batch_isend_irecv(p2p_op_list)
1465                for req in reqs:
1466                    req.wait()
1467
1468            self._barrier()
1469
1470        @skip_if_no_gpu
1471        @skip_if_small_worldsize
1472        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
1473        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
1474        def test_batch_isend_irecv_no_rank_zero_nccl(self):
1475            self._barrier()
1476            # Ensure the process group has been fully initialized (needed by
1477            # the first sub-group batch_isend_irecv call)
1478            dist.barrier()
1479            rank = dist.get_rank()
1480            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
1481            device_id = rank_to_GPU[rank][0]
1482            torch.cuda.set_device(device_id)
1483            p2p_op_list = []
1484
1485            if rank == 1:
1486                peer = 2
1487            elif rank == 2:
1488                peer = 1
1489
1490            if rank in [1, 2]:
1491                send_tensor = _build_tensor(rank + 1, device_id=device_id)
1492                recv_tensor = _build_tensor(peer + 1, value=-1, device_id=device_id)
1493                recv_op = dist.P2POp(dist.irecv, recv_tensor, peer)
1494                p2p_op_list.append(recv_op)
1495                send_op = dist.P2POp(dist.isend, send_tensor, peer)
1496                p2p_op_list.append(send_op)
1497
1498                reqs = dist.batch_isend_irecv(p2p_op_list)
1499                for req in reqs:
1500                    req.wait()
1501
1502            self._barrier()
1503
1504        # GLOO Batch SEND RECV CPU
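            # Same batched point-to-point pattern as the NCCL tests above, but on CPU
            # tensors with the Gloo backend; each rank exchanges with every other rank.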
1505        @skip_but_pass_in_sandcastle_if(BACKEND != "gloo", "GLOO Batch Send Recv CPU")
1506        def test_batch_isend_irecv_gloo(self):
1507            self._barrier()
1508            rank = dist.get_rank()
1509            p2p_op_list = []
1510
1511            for src in range(0, dist.get_world_size()):
1512                if src == rank:
1513                    continue
1514                send_tensor = _build_tensor(rank + 1)
1515                recv_tensor = _build_tensor(src + 1, value=-1)
1516                recv_op = dist.P2POp(dist.irecv, recv_tensor, src)
1517                p2p_op_list.append(recv_op)
1518                send_op = dist.P2POp(dist.isend, send_tensor, src)
1519                p2p_op_list.append(send_op)
1520
1521            reqs = dist.batch_isend_irecv(p2p_op_list)
1522            for req in reqs:
1523                req.wait()
1524
1525            self._barrier()
1526
1527        # GLOO Batch SEND RECV CPU with provided tags
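            # Tag matching: the sender tags each message with its own rank and the
            # receiver requests tag=src, so every recv is paired with exactly one send.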
1528        @skip_but_pass_in_sandcastle_if(BACKEND != "gloo", "GLOO Batch Send Recv CPU")
1529        def test_batch_isend_irecv_gloo_tags(self):
1530            self._barrier()
1531            rank = dist.get_rank()
1532            p2p_op_list = []
1533
1534            for src in range(0, dist.get_world_size()):
1535                if src == rank:
1536                    continue
1537                send_tensor = _build_tensor(rank + 1)
1538                recv_tensor = _build_tensor(src + 1, value=-1)
1539                recv_op = dist.P2POp(dist.irecv, recv_tensor, src, tag=src)
1540                p2p_op_list.append(recv_op)
1541                send_op = dist.P2POp(dist.isend, send_tensor, src, tag=rank)
1542                p2p_op_list.append(send_op)
1543
1544            reqs = dist.batch_isend_irecv(p2p_op_list)
1545            for req in reqs:
1546                req.wait()
1547
1548            self._barrier()
1549
1550        # NCCL Batch SEND RECV Op Error
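            # The three error tests below exercise batch_isend_irecv input validation:
            # a non point-to-point op (dist.broadcast), a p2p_op_list whose elements are
            # not P2POp instances, and ops built against different process groups are
            # all rejected with a ValueError.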
1551        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
1552        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
1553        def test_batch_isend_irecv_op_err(self):
1554            self._barrier()
1555            rank = dist.get_rank()
1556            if rank == 0:
1557                rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
1558                device_id = rank_to_GPU[rank][0]
1559                with self.assertRaisesRegex(ValueError, "^Invalid ``op``"):
1560                    send_tensor = _build_tensor(rank + 1, device_id=device_id)
1561                    send_op = dist.P2POp(dist.broadcast, send_tensor, 1)
1562                    dist.batch_isend_irecv([send_op])
1563
1564        # NCCL Batch SEND RECV p2p_op_list Error
1565        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
1566        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
1567        def test_batch_isend_irecv_op_list_err(self):
1568            self._barrier()
1569            rank = dist.get_rank()
1570            if rank == 0:
1571                with self.assertRaisesRegex(ValueError, "^Invalid ``p2p_op_list``"):
1572                    dist.batch_isend_irecv([1, 2])
1573
1574        # NCCL Batch SEND RECV Mixed Backend Error
1575        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
1576        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
1577        def test_batch_isend_irecv_mixed_backend_err(self):
1578            self._barrier()
1579            rank = dist.get_rank()
1580            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
1581            device_id = rank_to_GPU[rank][0]
1582            group_gloo = dist.new_group(ranks=[0, 1], backend="gloo")
1583            group_nccl = dist.new_group(ranks=[0, 1], backend="nccl")
1584            if rank == 0:
1585                with self.assertRaisesRegex(
1586                    ValueError, "All ops need to use the same group"
1587                ):
1588                    send_tensor = _build_tensor(rank + 1)
1589                    send_op_gloo = dist.P2POp(dist.isend, send_tensor, 1, group_gloo)
1590                    send_op_nccl = dist.P2POp(dist.isend, send_tensor, 1, group_nccl)
1591                    dist.batch_isend_irecv([send_op_gloo, send_op_nccl])
1592
1593        # NCCL SEND RECV
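            # Helper for GPU point-to-point: when src == rank this process sends its
            # tensor to every other rank, otherwise it receives from src and compares
            # against the expected tensor. If a profiler context is supplied, the
            # recorded <backend>:send / <backend>:recv events are checked as well.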
1594        @skip_if_no_gpu
1595        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
1596        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
1597        def _test_send_recv_nccl(self, profiler_ctx=None):
1598            # TODO: now that NCCL send/recv is supported, there no longer seems to
1599            # be a need to test NCCL send/recv separately.
1600            rank = dist.get_rank()
1601            world_size = dist.get_world_size()
1602            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
1603            device_id = rank_to_GPU[rank][0]
1604            torch.cuda.set_device(device_id)
1605
1606            tensor = _build_tensor(rank + 1, device_id=device_id)
1607            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
1608            with ctx as prof:
1609                for src in range(0, world_size):
1610                    if src == rank:
1611                        # Send mode
1612                        for dst in range(0, world_size):
1613                            if dst == rank:
1614                                continue
1615                            dist.send(tensor, dst)
1616                    else:
1617                        # Recv mode
1618                        expected_tensor = _build_tensor(src + 1)
1619                        output_tensor = _build_tensor(
1620                            src + 1, value=-1, device_id=device_id
1621                        )
1622                        dist.recv(output_tensor, src)
1623                        self.assertEqual(output_tensor, expected_tensor)
1624
1625                self._barrier()
1626
1627            if profiler_ctx is not None:
1628                backend = dist.get_backend()
1629                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
1630                    for event_name in [f"{backend}:send", f"{backend}:recv"]:
1631                        events = get_profiling_event(event_name, prof, dedup_gpu_user_annotation=True)
1632                        self.assertTrue(events)
1633                        # Event order is not deterministic, so simply assert that each
1634                        # event's input shape is found in the expected list below.
1635                        expected_shapes = [
1636                            [[rank + 1] * 3] for rank in range(dist.get_world_size())
1637                        ]
1638                        for event in events:
1639                            self.assertTrue(event.input_shapes in expected_shapes)
1640
1641
1642        @skip_if_no_gpu
1643        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
1644        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
1645        def test_send_recv_nccl(self):
1646            self._test_send_recv_nccl()
1647
1648        @skip_if_no_gpu
1649        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
1650        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
1651        def test_send_recv_nccl_autograd_profiler(self):
1652            profiler_ctx = torch.autograd.profiler.profile(record_shapes=True)
1653            self._test_send_recv_nccl(profiler_ctx)
1654
1655        @skip_if_no_gpu
1656        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
1657        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
1658        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode causes hang")
1659        @skip_but_pass_in_sandcastle_if(
1660            IS_MACOS or IS_WINDOWS,
1661            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
1662        )
1663        def test_send_recv_nccl_torch_profiler(self):
1664            profiler_ctx = torch.profiler.profile(
1665                activities=[
1666                    torch.profiler.ProfilerActivity.CPU,
1667                    torch.profiler.ProfilerActivity.CUDA,
1668                ],
1669                record_shapes=True,
1670            )
1671            self._test_send_recv_nccl(profiler_ctx)
1672
1673        # SEND RECV
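            # CPU/backend-agnostic version of the send/recv helper above; the optional
            # profiler context lets the *_autograd_profiler and *_torch_profiler tests
            # below verify event counts and input shapes.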
1674        def _test_send_recv(self, profiler_ctx):
1675            rank = dist.get_rank()
1676            send_size = rank + 1
1677            tensor = _build_tensor(send_size)
1678            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
1679            with ctx as prof:
1680                for src in range(0, dist.get_world_size()):
1681                    if src == rank:
1682                        # Send mode
1683                        for dst in range(0, dist.get_world_size()):
1684                            if dst == rank:
1685                                continue
1686                            dist.send(tensor, dst)
1687                    else:
1688                        # Recv mode
1689                        recv_size = src + 1
1690                        expected_tensor = _build_tensor(recv_size)
1691                        output_tensor = _build_tensor(recv_size, value=-1)
1692                        dist.recv(output_tensor, src)
1693                        self.assertEqual(output_tensor, expected_tensor)
1694
1695            if profiler_ctx is not None:
1696                backend = dist.get_backend()
1697                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
1698                    for event_name in [f"{backend}:send", f"{backend}:recv"]:
1699                        events = get_profiling_event(event_name, prof)
1700                        # Each rank sends/recvs from all other ranks.
1701                        event_count = sum(e.count for e in events)
1702                        expected_event_count = dist.get_world_size() - 1
1703                        self.assertEqual(event_count, expected_event_count)
1704                        # Event order is not deterministic, so simply assert that each
1705                        # event's input shape is found in the expected list below.
1706                        expected_shapes = [
1707                            [[rank + 1] * 3] for rank in range(dist.get_world_size())
1708                        ]
1709                        for event in events:
1710                            self.assertTrue(event.is_async)
1711                            self.assertTrue(event.input_shapes in expected_shapes)
1712
1713        @skip_but_pass_in_sandcastle_if(
1714            BACKEND == "nccl", "Nccl send/recv tested by test_send_recv_nccl"
1715        )
1716        def test_send_recv(self):
1717            self._test_send_recv(profiler_ctx=None)
1718
1719        @skip_but_pass_in_sandcastle_if(
1720            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
1721        )
1722        def test_send_recv_autograd_profiler(self):
1723            autograd_profiler_ctx = _create_autograd_profiler()
1724            self._test_send_recv(profiler_ctx=autograd_profiler_ctx)
1725
1726        @skip_but_pass_in_sandcastle_if(
1727            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
1728        )
1729        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode causes hang")
1730        @skip_but_pass_in_sandcastle_if(
1731            IS_MACOS or IS_WINDOWS,
1732            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
1733        )
1734        def test_send_recv_torch_profiler(self):
1735            torch_profiler_ctx = _create_torch_profiler()
1736            return self._test_send_recv(profiler_ctx=torch_profiler_ctx)
1737
1738        # SEND RECV ANY SOURCE
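            # recv()/irecv() without an explicit src accept a message from any sender:
            # dist.recv() returns the sender's rank, while for irecv() the sender is
            # read from work._source_rank() after wait(). The observed sender ranks are
            # then all-gathered and checked to be globally balanced.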
1739        def _test_send_recv_any_source(self, profiler_ctx):
1740            rank = dist.get_rank()
1741            send_recv_size = 10
1742            tensor = _build_tensor(send_recv_size, value=rank)
1743            recv_ranks = []
1744            irecv_ranks = []
1745
1746            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
1747            with ctx as prof:
1748                for dst in range(0, dist.get_world_size()):
1749                    if dst == rank:
1750                        # Recv mode
1751                        for other_rank in range(0, dist.get_world_size()):
1752                            if other_rank == rank:
1753                                continue
1754
1755                            for recv in ["recv", "irecv"]:
1756                                output_tensor = _build_tensor(send_recv_size, value=-1)
1757
1758                                if recv == "recv":
1759                                    sender = dist.recv(output_tensor)
1760                                    recv_ranks.append(sender)
1761                                elif recv == "irecv":
1762                                    work = dist.irecv(output_tensor)
1763                                    work.wait()
1764                                    sender = work._source_rank()
1765                                    irecv_ranks.append(sender)
1766
1767                                # "sender" should be the rank of the sending process;
1768                                # assert that every value in the received tensor
1769                                # equals it.
1770                                self.assertTrue(output_tensor.eq(sender).all())
1771                    else:
1772                        # Send mode
1773                        dist.send(tensor, dst)  # recv
1774                        dist.send(tensor, dst)  # irecv
1775
1776            if profiler_ctx is not None:
1777                backend = dist.get_backend()
1778                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
1779                    for event_name in [f"{backend}:send", f"{backend}:recvAnySource"]:
1780                        events = get_profiling_event(event_name, prof)
1781                        # Each rank sends/recvs from other rank twice.
1782                        self.assertEqual(
1783                            sum(event.count for event in events),
1784                            2 * (dist.get_world_size() - 1),
1785                        )
1786                        for event in events:
1787                            self.assertTrue(event.is_async)
1788                            self.assertEqual(event.input_shapes, [[send_recv_size] * 3])
1789
1790                # Each rank would have 2 * (world_size - 1) sends, verify that
1791                # globally we receive the same amount on the other end.
1792                recv_ranks_tensor = torch.cat(
1793                    (torch.tensor(recv_ranks), torch.tensor(irecv_ranks)), 0
1794                )
1795                global_recv_ranks = [
1796                    torch.empty_like(recv_ranks_tensor)
1797                    for _ in range(dist.get_world_size())
1798                ]
1799                dist.all_gather(global_recv_ranks, recv_ranks_tensor)
1800                global_recv_ranks_list = []
1801                for tensor in global_recv_ranks:
1802                    global_recv_ranks_list += tensor.tolist()
1803
1804                from itertools import groupby
1805
1806                global_recv_ranks_list.sort()
1807                frequency = [
1808                    len(list(group)) for key, group in groupby(global_recv_ranks_list)
1809                ]
1810                self.assertEqual(dist.get_world_size(), len(frequency))
1811                self.assertEqual(
1812                    [2 * (dist.get_world_size() - 1)] * dist.get_world_size(), frequency
1813                )
1814                self._barrier()
1815
1816        @skip_but_pass_in_sandcastle_if(
1817            BACKEND in DistTestCases.skip_collective["sendrecv anysource"],
1818            f"{BACKEND} does not support send/recv from any source",
1819        )
1820        def test_send_recv_any_source(self):
1821            self._test_send_recv_any_source(profiler_ctx=None)
1822
1823        @skip_but_pass_in_sandcastle_if(
1824            BACKEND in DistTestCases.skip_collective["sendrecv anysource"],
1825            f"{BACKEND} does not support send/recv from any source",
1826        )
1827        def test_send_recv_any_source_autograd_profiler(self):
1828            autograd_profiler_ctx = _create_autograd_profiler()
1829            self._test_send_recv_any_source(profiler_ctx=autograd_profiler_ctx)
1830
1831        @skip_but_pass_in_sandcastle_if(
1832            BACKEND in DistTestCases.skip_collective["sendrecv anysource"],
1833            f"{BACKEND} does not support send/recv from any source",
1834        )
1835        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode causes hang")
1836        @skip_but_pass_in_sandcastle_if(
1837            IS_MACOS or IS_WINDOWS,
1838            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
1839        )
1840        def test_send_recv_any_source_torch_profiler(self):
1841            torch_profiler_ctx = _create_torch_profiler()
1842            return self._test_send_recv_any_source(profiler_ctx=torch_profiler_ctx)
1843
1844        # SEND RECV WITH TAG
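            # Tagged point-to-point: the sender tags with its own rank and the receiver
            # asks for tag=src, so the received tensor must be filled with src.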
1845        def _test_send_recv_with_tag(self, profiler_ctx):
1846            rank = dist.get_rank()
1847            world_size = dist.get_world_size()
1848            send_recv_size = 10
1849            tensor = _build_tensor(send_recv_size, value=rank)
1850            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
1851            with ctx as prof:
1852                for dst in range(0, world_size):
1853                    if dst == rank:
1854                        # Recv mode
1855                        for src in range(0, world_size):
1856                            if src == rank:
1857                                continue
1858                            output_tensor = _build_tensor(send_recv_size, value=-1)
1859                            dist.recv(output_tensor, src, tag=src)
1860                            self.assertTrue(output_tensor.eq(src).all())
1861                    else:
1862                        # Send mode
1863                        dist.send(tensor, dst, tag=rank)
1864
1865            if profiler_ctx is not None:
1866                backend = dist.get_backend()
1867                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
1868                    for event_name in [f"{backend}:send", f"{backend}:recv"]:
1869                        events = get_profiling_event(event_name, prof)
1870                        # Each rank sends/recvs from all other ranks
1871                        event_count = sum(e.count for e in events)
1872                        expected_event_count = dist.get_world_size() - 1
1873                        self.assertEqual(event_count, expected_event_count)
1874                        for event in events:
1875                            self.assertTrue(event.is_async)
1876                            self.assertEqual(event.name, event_name)
1877                            self.assertEqual(event.input_shapes, [[send_recv_size] * 3])
1878
1879        @skip_but_pass_in_sandcastle_if(
1880            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
1881        )
1882        def test_send_recv_with_tag(self):
1883            self._test_send_recv_with_tag(profiler_ctx=None)
1884
1885        @skip_but_pass_in_sandcastle_if(
1886            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
1887        )
1888        def test_send_recv_with_tag_autograd_profiler(self):
1889            autograd_profiler_ctx = _create_autograd_profiler()
1890            return self._test_send_recv_with_tag(profiler_ctx=autograd_profiler_ctx)
1891
1892        @skip_but_pass_in_sandcastle_if(
1893            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
1894        )
1895        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode causes hang")
1896        @skip_but_pass_in_sandcastle_if(
1897            IS_MACOS or IS_WINDOWS,
1898            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
1899        )
1900        def test_send_recv_with_tag_torch_profiler(self):
1901            torch_profiler_ctx = _create_torch_profiler()
1902            return self._test_send_recv_with_tag(profiler_ctx=torch_profiler_ctx)
1903
1904        # ISEND
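            # Non-blocking send: rank 0 posts isend()s to every other rank and waits on
            # the returned requests; all other ranks do a blocking recv() from rank 0.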
1905        def _test_isend(self, profiler_ctx):
1906            rank = dist.get_rank()
1907            world_size = dist.get_world_size()
1908            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
1909            with ctx as prof:
1910                if rank == 0:
1911                    requests = [
1912                        dist.isend(_build_tensor(dest, 10), dest)
1913                        for dest in range(1, world_size)
1914                    ]
1915                    for request in requests:
1916                        request.wait()
1917                        self.assertTrue(request.is_completed())
1918                else:
1919                    tensor = _build_tensor(rank, -1)
1920                    dist.recv(tensor, 0)
1921                    self.assertEqual(tensor, _build_tensor(rank, 10))
1922
1923                self._barrier()
1924
1925            if profiler_ctx is not None:
1926                backend = dist.get_backend()
1927                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
1928                    expected_event_name = (
1929                        f"{backend}:send" if rank == 0 else f"{backend}:recv"
1930                    )
1931                    events = get_profiling_event(expected_event_name, prof)
1932                    event_count = sum(e.count for e in events)
1933                    expected_count = dist.get_world_size() - 1 if rank == 0 else 1
1934                    self.assertEqual(expected_count, event_count)
1935                    # Event ordering is not guaranteed, so simply ensure the shapes are
1936                    # found in the following map.
1937                    expected_shapes = {
1938                        r: [[r] * 3] for r in range(1, dist.get_world_size())
1939                    }
1940                    for event in events:
1941                        self.assertTrue(event.is_async)
1942                        self.assertEqual(event.name, expected_event_name)
1943                        if rank == 0:
1944                            self.assertTrue(
1945                                event.input_shapes in expected_shapes.values()
1946                            )
1947                        else:
1948                            self.assertEqual(event.input_shapes, expected_shapes[rank])
1949
1950        @skip_but_pass_in_sandcastle_if(
1951            BACKEND == "nccl", "Nccl does not support isend"
1952        )
1953        def test_isend(self):
1954            self._test_isend(profiler_ctx=None)
1955
1956        @skip_but_pass_in_sandcastle_if(
1957            BACKEND == "nccl", "Nccl does not support isend"
1958        )
1959        def test_isend_autograd_profiler(self):
1960            autograd_profiler_ctx = _create_autograd_profiler()
1961            self._test_isend(profiler_ctx=autograd_profiler_ctx)
1962
1963        @skip_but_pass_in_sandcastle_if(
1964            BACKEND == "nccl", "Nccl does not support isend"
1965        )
1966        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode causes hang")
1967        @skip_but_pass_in_sandcastle_if(
1968            IS_MACOS or IS_WINDOWS,
1969            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
1970        )
1971        def test_isend_torch_profiler(self):
1972            torch_profiler_ctx = _create_torch_profiler()
1973            self._test_isend(profiler_ctx=torch_profiler_ctx)
1974
1975        # IRECV
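            # Mirror of the isend test: rank 0 posts irecv()s from every other rank and
            # waits on them, while the other ranks send with a blocking send().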
1976        @skip_but_pass_in_sandcastle_if(
1977            BACKEND == "nccl", "Nccl does not support irecv"
1978        )
1979        def test_irecv(self):
1980            rank = dist.get_rank()
1981            world_size = dist.get_world_size()
1982
1983            if rank == 0:
1984                expected_tensors = [
1985                    _build_tensor(src, -1) for src in range(1, world_size)
1986                ]
1987                requests = [
1988                    dist.irecv(expected_tensors[src - 1], src)
1989                    for src in range(1, world_size)
1990                ]
1991
1992                for src in range(1, world_size):
1993                    requests[src - 1].wait()
1994                    self.assertTrue(requests[src - 1].is_completed())
1995                    self.assertEqual(expected_tensors[src - 1], _build_tensor(src, 10))
1996            else:
1997                tensor = _build_tensor(rank, 10)
1998                dist.send(tensor, 0)
1999
2000            self._barrier()
2001
2002        # BROADCAST
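            # Broadcast helper: for several dtypes, each src in turn broadcasts a known
            # tensor and every other rank checks that it received an identical copy.
            # With with_options=True the helper calls group_id.broadcast([tensor], opts)
            # directly, setting BroadcastOptions.rootRank / rootTensor, instead of going
            # through dist.broadcast().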
2003        def _test_broadcast_helper(
2004            self,
2005            group,
2006            group_id,
2007            rank,
2008            cuda=False,
2009            rank_to_GPU=None,
2010            with_options=False,
2011        ):
2012            for dtype, value, requires_cuda in [
2013                (torch.float, -1e-10, False),
2014                (torch.double, -1e-100, False),
2015                (torch.half, -0.1, True),
2016                (torch.int8, -2, False),
2017                (torch.uint8, 129, False),
2018                (torch.int, -1e5, False),
2019                (torch.long, -1e15, False),
2020            ]:
2021                if requires_cuda and not cuda:
2022                    continue
2023                for src in group:
2024                    expected_tensor = _build_tensor(src + 1, value, dtype)
2025                    if cuda:
2026                        expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
2027                    if rank == src:
2028                        if with_options:
2029                            opts = dist.BroadcastOptions()
2030                            opts.rootTensor = 0
2031                            opts.rootRank = src
2032                            self.call_dist_op(
2033                                ":broadcast",
2034                                True,
2035                                group_id.broadcast,
2036                                [expected_tensor],
2037                                opts,
2038                            )
2039                        else:
2040                            self.call_dist_op(
2041                                ":broadcast",
2042                                False,
2043                                dist.broadcast,
2044                                expected_tensor,
2045                                src,
2046                                group_id,
2047                            )
2048                    else:
2049                        tensor = _build_tensor(src + 1, -1, dtype)
2050                        if cuda:
2051                            tensor = tensor.cuda(rank_to_GPU[rank][0])
2052                        if with_options:
2053                            opts = dist.BroadcastOptions()
2054                            opts.rootTensor = 0
2055                            opts.rootRank = src
2056                            self.call_dist_op(
2057                                ":broadcast", True, group_id.broadcast, [tensor], opts
2058                            )
2059                        else:
2060                            self.call_dist_op(
2061                                ":broadcast",
2062                                False,
2063                                dist.broadcast,
2064                                tensor,
2065                                src,
2066                                group_id,
2067                            )
2068                        self.assertEqual(tensor.size(), expected_tensor.size())
2069                        self.assertEqual(
2070                            tensor.ne(expected_tensor).max(), torch.tensor(False)
2071                        )
2072
2073            self._barrier()
2074
2075        @skip_but_pass_in_sandcastle_if(
2076            BACKEND == "nccl", "Nccl does not support CPU tensors"
2077        )
2078        def test_broadcast(self):
2079            group, group_id, rank = self._init_global_test()
2080            self._test_broadcast_helper(group, group_id, rank)
2081
2082        @skip_but_pass_in_sandcastle_if(
2083            BACKEND != "gloo" and BACKEND != "nccl",
2084            "Only Gloo and Nccl backend supports CUDA allReduce",
2085        )
2086        @skip_if_no_gpu
2087        def test_broadcast_cuda(self):
2088            group, group_id, rank = self._init_global_test()
2089            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
2090            device_id = rank_to_GPU[rank][0]
2091            torch.cuda.set_device(device_id)
2092            self._test_broadcast_helper(group, group_id, rank, True, rank_to_GPU)
2093
2094        @skip_if_small_worldsize
2095        @skip_but_pass_in_sandcastle_if(
2096            BACKEND == "nccl", "Nccl does not support CPU tensors"
2097        )
2098        def test_broadcast_group(self):
2099            group, group_id, rank = self._init_group_test()
2100            self._test_broadcast_helper(group, group_id, rank)
2101
2102        @skip_but_pass_in_sandcastle_if(
2103            BACKEND == "nccl", "Nccl does not support CPU tensors"
2104        )
2105        def test_broadcast_full_group(self):
2106            group, group_id, rank = self._init_full_group_test()
2107            self._test_broadcast_helper(group, group_id, rank)
2108
2109        @skip_but_pass_in_sandcastle_if(
2110            BACKEND != "nccl",
2111            "Only NCCL backend supports high priority stream",
2112        )
2113        @skip_if_no_gpu
2114        def test_nccl_high_priority_stream(self):
2115            group, _, rank = self._init_global_test()
2116            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
2117            device_id = rank_to_GPU[rank][0]
2118            torch.cuda.set_device(device_id)
2119
2120            new_port = str(MASTER_PORT + 1)
2121            os.environ["MASTER_PORT"] = new_port
2122            gen_iterator = dist.rendezvous("env://", rank, dist.get_world_size())
2123            store, rank, size = next(gen_iterator)
2124            store = dist.PrefixStore(new_port, store)
2125
2126            opts = dist.ProcessGroupNCCL.Options()
2127            opts.is_high_priority_stream = False
2128            group_id = dist.ProcessGroupNCCL(store, rank, size, opts)
2129
2130            self._test_broadcast_helper(group, group_id, rank, True, rank_to_GPU, True)
2131
2132        # REDUCE
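            # Reduce helper: the root (src) starts from master_value and every other
            # rank from worker_value; only the root's tensor is checked afterwards,
            # since dist.reduce() only guarantees the result on the destination rank.
            # For example, for SUM with world size N the expectation used below is
            # 2 + 10 * (N - 1).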
2133        def _test_reduce_helper(
2134            self,
2135            group,
2136            group_id,
2137            rank,
2138            op,
2139            master_value,
2140            worker_value,
2141            expected_value,
2142            cuda=False,
2143            rank_to_GPU=None,
2144        ):
2145            for src in group:
2146                tensor = _build_tensor(src + 1).fill_(
2147                    master_value if rank == src else worker_value
2148                )
2149                if cuda:
2150                    tensor = tensor.cuda(rank_to_GPU[rank][0])
2151                self.call_dist_op(
2152                    ":reduce",
2153                    False,
2154                    dist.reduce,
2155                    tensor,
2156                    src,
2157                    op,
2158                    group_id,
2159                    tensor_shapes=[tensor.shape],
2160                )
2161                if rank == src:
2162                    self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
2163
2164            self._barrier()
2165
2166        @skip_but_pass_in_sandcastle_if(
2167            BACKEND == "nccl", "Nccl does not support CPU tensors"
2168        )
2169        @skip_but_pass_in_sandcastle_if(
2170            BACKEND in DistTestCases.skip_collective["reduce"],
2171            f"{BACKEND} does not support reduce",
2172        )
2173        def test_reduce_sum(self):
2174            group, group_id, rank = self._init_global_test()
2175            self._test_reduce_helper(
2176                group,
2177                group_id,
2178                rank,
2179                dist.ReduceOp.SUM,
2180                2,
2181                10,
2182                2 + (10 * (len(group) - 1)),
2183            )
2184
2185        @skip_but_pass_in_sandcastle_if(
2186            BACKEND != "nccl", "Only Nccl supports CUDA reduce"
2187        )
2188        @skip_but_pass_in_sandcastle_if(
2189            BACKEND in DistTestCases.skip_collective["reduce"],
2190            f"{BACKEND} does not support reduce",
2191        )
2192        @skip_if_no_gpu
2193        def test_reduce_sum_cuda(self):
2194            group, group_id, rank = self._init_global_test()
2195            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
2196            device_id = rank_to_GPU[rank][0]
2197            torch.cuda.set_device(device_id)
2198            self._test_reduce_helper(
2199                group,
2200                group_id,
2201                rank,
2202                dist.ReduceOp.SUM,
2203                2,
2204                10,
2205                2 + 10 * (len(group) - 1),
2206                True,
2207                rank_to_GPU,
2208            )
2209
2210        @skip_but_pass_in_sandcastle_if(
2211            BACKEND == "nccl", "Nccl does not support CPU tensors"
2212        )
2213        @skip_but_pass_in_sandcastle_if(
2214            BACKEND in DistTestCases.skip_collective["reduce"],
2215            f"{BACKEND} does not support reduce",
2216        )
2217        def test_reduce_product(self):
2218            group, group_id, rank = self._init_global_test()
2219            self._test_reduce_helper(
2220                group,
2221                group_id,
2222                rank,
2223                dist.ReduceOp.PRODUCT,
2224                2,
2225                10,
2226                reduce(operator.mul, [10] * (len(group) - 1), 2),
2227            )
2228
2229        @skip_but_pass_in_sandcastle_if(
2230            BACKEND == "nccl", "Nccl does not support CPU tensors"
2231        )
2232        @skip_but_pass_in_sandcastle_if(
2233            BACKEND in DistTestCases.skip_collective["reduce"],
2234            f"{BACKEND} does not support reduce",
2235        )
2236        def test_reduce_min(self):
2237            group, group_id, rank = self._init_global_test()
2238            self._test_reduce_helper(
2239                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
2240            )
2241
2242        @skip_but_pass_in_sandcastle_if(
2243            BACKEND == "nccl", "Nccl does not support CPU tensors"
2244        )
2245        @skip_but_pass_in_sandcastle_if(
2246            BACKEND in DistTestCases.skip_collective["reduce"],
2247            f"{BACKEND} does not support reduce",
2248        )
2249        def test_reduce_max(self):
2250            group, group_id, rank = self._init_global_test()
2251            self._test_reduce_helper(
2252                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
2253            )
2254
2255        @skip_but_pass_in_sandcastle_if(
2256            BACKEND == "nccl", "Nccl does not support CPU tensors"
2257        )
2258        @skip_but_pass_in_sandcastle_if(
2259            BACKEND in DistTestCases.skip_collective["reduce"],
2260            f"{BACKEND} does not support reduce",
2261        )
2262        @skip_if_small_worldsize
2263        def test_reduce_group_sum(self):
2264            group, group_id, rank = self._init_group_test()
2265            self._test_reduce_helper(
2266                group,
2267                group_id,
2268                rank,
2269                dist.ReduceOp.SUM,
2270                2,
2271                10,
2272                2 + (10 * (len(group) - 1)),
2273            )
2274
2275        @skip_but_pass_in_sandcastle_if(
2276            BACKEND == "nccl", "Nccl does not support CPU tensors"
2277        )
2278        @skip_but_pass_in_sandcastle_if(
2279            BACKEND in DistTestCases.skip_collective["reduce"],
2280            f"{BACKEND} does not support reduce",
2281        )
2282        @skip_if_small_worldsize
2283        def test_reduce_group_product(self):
2284            group, group_id, rank = self._init_group_test()
2285            self._test_reduce_helper(
2286                group,
2287                group_id,
2288                rank,
2289                dist.ReduceOp.PRODUCT,
2290                2,
2291                10,
2292                reduce(operator.mul, [10] * (len(group) - 1), 2),
2293            )
2294
2295        @skip_but_pass_in_sandcastle_if(
2296            BACKEND == "nccl", "Nccl does not support CPU tensors"
2297        )
2298        @skip_but_pass_in_sandcastle_if(
2299            BACKEND in DistTestCases.skip_collective["reduce"],
2300            f"{BACKEND} does not support reduce",
2301        )
2302        @skip_if_small_worldsize
2303        def test_reduce_group_min(self):
2304            group, group_id, rank = self._init_group_test()
2305            self._test_reduce_helper(
2306                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
2307            )
2308
2309        @skip_but_pass_in_sandcastle_if(
2310            BACKEND == "nccl", "Nccl does not support CPU tensors"
2311        )
2312        @skip_but_pass_in_sandcastle_if(
2313            BACKEND in DistTestCases.skip_collective["reduce"],
2314            f"{BACKEND} does not support reduce",
2315        )
2316        @skip_if_small_worldsize
2317        def test_reduce_group_max(self):
2318            group, group_id, rank = self._init_group_test()
2319            self._test_reduce_helper(
2320                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
2321            )
2322
2323        @skip_but_pass_in_sandcastle_if(
2324            BACKEND == "nccl", "Nccl does not support CPU tensors"
2325        )
2326        @skip_but_pass_in_sandcastle_if(
2327            BACKEND in DistTestCases.skip_collective["reduce"],
2328            f"{BACKEND} does not support reduce",
2329        )
2330        def test_reduce_full_group_sum(self):
2331            group, group_id, rank = self._init_full_group_test()
2332            self._test_reduce_helper(
2333                group,
2334                group_id,
2335                rank,
2336                dist.ReduceOp.SUM,
2337                2,
2338                10,
2339                2 + (10 * (len(group) - 1)),
2340            )
2341
2342        @skip_but_pass_in_sandcastle_if(
2343            BACKEND == "nccl", "Nccl does not support CPU tensors"
2344        )
2345        @skip_but_pass_in_sandcastle_if(
2346            BACKEND in DistTestCases.skip_collective["reduce"],
2347            f"{BACKEND} does not support reduce",
2348        )
2349        def test_reduce_full_group_product(self):
2350            group, group_id, rank = self._init_full_group_test()
2351            self._test_reduce_helper(
2352                group,
2353                group_id,
2354                rank,
2355                dist.ReduceOp.PRODUCT,
2356                2,
2357                10,
2358                reduce(operator.mul, [10] * (len(group) - 1), 2),
2359            )
2360
2361        @skip_but_pass_in_sandcastle_if(
2362            BACKEND == "nccl", "Nccl does not support CPU tensors"
2363        )
2364        @skip_but_pass_in_sandcastle_if(
2365            BACKEND in DistTestCases.skip_collective["reduce"],
2366            f"{BACKEND} does not support reduce",
2367        )
2368        def test_reduce_full_group_min(self):
2369            group, group_id, rank = self._init_full_group_test()
2370            self._test_reduce_helper(
2371                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
2372            )
2373
2374        @skip_but_pass_in_sandcastle_if(
2375            BACKEND == "nccl", "Nccl does not support CPU tensors"
2376        )
2377        @skip_but_pass_in_sandcastle_if(
2378            BACKEND in DistTestCases.skip_collective["reduce"],
2379            f"{BACKEND} does not support reduce",
2380        )
2381        def test_reduce_full_group_max(self):
2382            group, group_id, rank = self._init_full_group_test()
2383            self._test_reduce_helper(
2384                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
2385            )
2386
2387        # REDUCE TWICE
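            # Same as the reduce helper, but a second dist.reduce() is issued through
            # call_dist_op's secondary_op_call so that two profiling events are expected.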
2388        def _test_reduce_twice_helper(
2389            self,
2390            group,
2391            group_id,
2392            rank,
2393            op,
2394            master_value,
2395            worker_value,
2396            expected_value,
2397            cuda=False,
2398            rank_to_GPU=None,
2399        ):
2400            for src in group:
2401                tensors = [
2402                    _build_tensor(src + 1).fill_(
2403                        master_value if rank == src else worker_value
2404                    )
2405                    for i in range(2)
2406                ]
2407                if cuda:
2408                    for i in range(2):
2409                        tensors[i] = tensors[i].cuda(rank_to_GPU[rank][0])
2410                self.call_dist_op(
2411                    ":reduce",
2412                    False,
2413                    dist.reduce,
2414                    tensors[0],
2415                    src,
2416                    op,
2417                    group_id,
2418                    secondary_op_call=lambda: dist.reduce(
2419                        tensors[1], src, op, group_id
2420                    ),
2421                    tensor_shapes=[tensors[0].shape],
2422                )
2423                if rank == src:
2424                    for tensor in tensors:
2425                        self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
2426
2427            self._barrier()
2428
2429        @skip_but_pass_in_sandcastle_if(
2430            BACKEND == "nccl", "Nccl does not support CPU tensors"
2431        )
2432        @skip_but_pass_in_sandcastle_if(
2433            BACKEND in DistTestCases.skip_collective["reduce"],
2434            f"{BACKEND} does not support reduce",
2435        )
2436        def test_reduce_sum_twice(self):
2437            group, group_id, rank = self._init_global_test()
2438            self._test_reduce_twice_helper(
2439                group,
2440                group_id,
2441                rank,
2442                dist.ReduceOp.SUM,
2443                2,
2444                10,
2445                2 + (10 * (len(group) - 1)),
2446            )
2447
2448        @skip_but_pass_in_sandcastle_if(
2449            BACKEND != "nccl", "Only Nccl supports CUDA reduce"
2450        )
2451        @skip_but_pass_in_sandcastle_if(
2452            BACKEND in DistTestCases.skip_collective["reduce"],
2453            f"{BACKEND} does not support reduce",
2454        )
2455        @skip_if_no_gpu
2456        def test_reduce_sum_cuda_twice(self):
2457            group, group_id, rank = self._init_global_test()
2458            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
2459            device_id = rank_to_GPU[rank][0]
2460            torch.cuda.set_device(device_id)
2461            self._test_reduce_twice_helper(
2462                group,
2463                group_id,
2464                rank,
2465                dist.ReduceOp.SUM,
2466                2,
2467                10,
2468                2 + 10 * (len(group) - 1),
2469                True,
2470                rank_to_GPU,
2471            )
2472
2473        @skip_but_pass_in_sandcastle_if(
2474            BACKEND != "nccl", "Only Nccl supports reduce_scatter_v"
2475        )
2476        @skip_but_pass_in_sandcastle_if(
2477            BACKEND in DistTestCases.skip_collective["reduce"],
2478            f"{BACKEND} does not support reduce",
2479        )
2480        @skip_if_no_gpu
2481        def test_reduce_scatter_v_cuda(self):
2482            self._barrier()
2483            group, group_id, rank = self._init_global_test()
2484            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
2485            device_id = rank_to_GPU[rank][0]
2486
2487            input_split_sizes = []
2488            for src in group:
2489                input_split_sizes.append(src + 1)
2490            start_len = sum(input_split_sizes[:rank])
2491            end_len = start_len + input_split_sizes[rank]
2492            sum_len = sum(input_split_sizes)
2493            master_value = 2
2494            worker_value = 10
2495
2496            for async_val in [True, False]:
2497                tensor = _build_tensor(sum_len, worker_value, device_id=device_id)
2498                tensor[start_len:end_len].fill_(master_value)
2499                out_tensor = (
2500                    torch.empty(
2501                        input_split_sizes[rank], sum_len, sum_len, dtype=torch.float
2502                    )
2503                    .fill_(-1)
2504                    .cuda(device_id)
2505                )
2506
2507                req = dist.reduce_scatter(
2508                    out_tensor,
2509                    list(torch.split(tensor, input_split_sizes)),
2510                    dist.ReduceOp.SUM,
2511                    group_id,
2512                    async_val,
2513                )
2514                if async_val:
2515                    req.wait()
2516
2517                expected_value = 2 + (10 * (len(group) - 1))
2518                expected_tensor = torch.empty(
2519                    input_split_sizes[rank], sum_len, sum_len, dtype=torch.float
2520                )
2521                expected_tensor = expected_tensor.fill_(expected_value).cuda(device_id)
2522
2523                self.assertEqual(out_tensor, expected_tensor)
2524            self._barrier()
2525
2526        # Test reduce_scatter_tensor accepting single tensor as input
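            # Every rank feeds the same arange input, so after a SUM reduce-scatter
            # rank r should hold torch.arange(r * size, (r + 1) * size) * world_size,
            # for both the concatenated and the stacked input layout.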
2527        def _reduce_scatter_tensor_helper(
2528            self, tensor_out, tensor_in, group_id, rank, cuda=True, rank_to_GPU=None
2529        ):
2530            if cuda:
2531                tensor_in = tensor_in.cuda(rank_to_GPU[rank][0])
2532                tensor_out = tensor_out.cuda(rank_to_GPU[rank][0])
2533            tensor_shapes = [tensor_out.shape]
2534            self.call_dist_op(
2535                ":reduce_scatter_tensor",
2536                False,
2537                dist.reduce_scatter_tensor,
2538                tensor_out,
2539                tensor_in,
2540                dist.ReduceOp.SUM,
2541                group_id,
2542                False,
2543                expect_event=False,
2544                tensor_shapes=tensor_shapes,
2545            )
2546            return tensor_out
2547
2548        @skip_but_pass_in_sandcastle_if(
2549            BACKEND != "nccl", "Only Nccl supports CUDA reduce_scatter_tensor"
2550        )
2551        @skip_if_no_gpu
2552        def test_reduce_scatter_tensor_cuda(self):
2553            group, group_id, rank = self._init_global_test()
2554            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
2555            size = 2
2556            tensor_out = torch.zeros(size, dtype=torch.int64)
2557
2558            # Concatenated input
2559            tensor_in = torch.arange(len(group) * size)
2560            tensor_out = self._reduce_scatter_tensor_helper(
2561                tensor_out, tensor_in, group_id, rank, True, rank_to_GPU
2562            )
2563            # Check result
2564            expected_tensor = torch.arange(rank * size, (rank + 1) * size) * len(group)
2565            self.assertEqual(tensor_out, expected_tensor)
2566            self._barrier()
2567
2568            # Stacked input
2569            tensor_in = torch.reshape(tensor_in, (len(group), size))
2570            tensor_out = self._reduce_scatter_tensor_helper(
2571                tensor_out, tensor_in, group_id, rank, True, rank_to_GPU
2572            )
2573            # Check result
2574            # Should be the same as the result in concatenated case
2575            self.assertEqual(tensor_out, expected_tensor)
2576            self._barrier()
2577
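            # call_dist_op runs `op` (plus an optional secondary_op_call) under the
            # autograd profiler, waits on the returned work objects when is_async is
            # set, and then checks the backend's profiling events: one event per op
            # call (unless DETAIL debug mode adds wrapper collectives), each async,
            # with the expected input shapes when tensor_shapes is provided.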
2578        def call_dist_op(
2579            self,
2580            profiling_title_postfix,
2581            is_async,
2582            op,
2583            *args,
2584            expect_event=True,
2585            secondary_op_call=None,
2586            profile_cuda=False,
2587            tensor_shapes=None,
2588            **kwargs,
2589        ):
2590            op_calls = [lambda: op(*args, **kwargs)]
2591            if secondary_op_call is not None:
2592                op_calls.append(secondary_op_call)
2593
2594            autograd_profiler_ctx = torch.autograd.profiler.profile(
2595                use_cuda=profile_cuda, record_shapes=True
2596            )
2597
2598            # TODO: move this test to use torch.profiler once kineto issues are
2599            # fixed internally.
2600            with autograd_profiler_ctx as prof:
2601                works = [op_call() for op_call in op_calls]
2602                if is_async:
2603                    for work in works:
2604                        work.wait()
2605
2606            if expect_event and dist.get_backend() in PROFILING_SUPPORTED_BACKENDS:
2607                # We are only interested in the backend's implementation, not the dispatcher wrapper.
2608                events = get_profiling_event(
2609                    dist.get_backend() + profiling_title_postfix, autograd_profiler_ctx
2610                )
2611                # DETAIL debug mode can use a pg wrapper that issues more collectives
2612                # under the hood
2613                if dist.get_debug_level() != dist.DebugLevel.DETAIL:
2614                    self.assertEqual(len(events), len(op_calls))
2615                for e in events:
2616                    self.assertTrue(e.is_async)
2617                    self.assertEqual(e.count, 1)
2618                    self.assertGreaterEqual(e.cpu_time, 0)
2619                    # Verify tensor shapes if given
2620                    # DETAIL debug mode can use a pg wrapper that issues more collectives
2621                    # under the hood
2622                    if (
2623                        tensor_shapes is not None
2624                        and dist.get_debug_level() != dist.DebugLevel.DETAIL
2625                    ):
2626                        self.assertEqual(
2627                            e.input_shapes,
2628                            tensor_shapes,
2629                            f"event shape: {e.input_shapes} vs tensor {tensor_shapes}",
2630                        )
2631
2632        # ALL REDUCE
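            # All-reduce helper: each src in turn seeds the tensor with master_value on
            # that rank and worker_value elsewhere, then all-reduces through
            # call_dist_op, which also validates the profiling events. Complex64 inputs
            # use torch.view_as_real(tensor).shape as the expected profiled shape,
            # presumably because the profiler records the real view of complex tensors.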
2633        def _test_all_reduce_helper(
2634            self,
2635            group,
2636            group_id,
2637            rank,
2638            op,
2639            master_value,
2640            worker_value,
2641            expected_value,
2642            cuda=False,
2643            rank_to_GPU=None,
2644            dtype=torch.float,
2645            async_op=False,
2646        ):
2647            for src in group:
2648                curr_value = master_value if rank == src else worker_value
2649
2650                tensor = _build_tensor(src + 1, dtype=dtype).fill_(curr_value)
2651                if cuda:
2652                    tensor = tensor.cuda(rank_to_GPU[rank][0])
2653                if tensor.dtype == torch.complex64:
2654                    tensor_shapes = [torch.view_as_real(tensor).shape]
2655                else:
2656                    tensor_shapes = [tensor.shape]
2657                self.call_dist_op(
2658                    ":all_reduce",
2659                    async_op,
2660                    dist.all_reduce,
2661                    tensor,
2662                    op,
2663                    group_id,
2664                    async_op=async_op,
2665                    tensor_shapes=tensor_shapes,
2666                )
2667                # Currently, only the Gloo backend has profiling tested with CUDA enabled.
2668                # Only run the CUDA profiling test for a single rank to speed things up,
2669                # since running with a different src_rank does not affect correctness.
2670                if (
2671                    src == 0
2672                    and cuda
2673                    and dist.get_backend() in CUDA_PROFILING_SUPPORTED_BACKENDS
2674                ):
2675                    self.call_dist_op(
2676                        ":all_reduce",
2677                        async_op,
2678                        dist.all_reduce,
2679                        tensor,
2680                        op,
2681                        group_id,
2682                        async_op=async_op,
2683                        profile_cuda=True,
2684                        tensor_shapes=tensor_shapes,
2685                    )
2686
2687            self._barrier()
2688
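        # Hypothetical helpers (illustration only, never called): the expected
        # values asserted by the SUM and PRODUCT tests below combine one
        # source-rank contribution with (world_size - 1) identical worker
        # contributions.
        @staticmethod
        def _expected_all_reduce_sum_sketch(master_value, worker_value, world_size):
            # e.g. 2 + 10 * (world_size - 1), as asserted in test_all_reduce_sum
            return master_value + worker_value * (world_size - 1)

        @staticmethod
        def _expected_all_reduce_product_sketch(master_value, worker_value, world_size):
            # e.g. reduce(operator.mul, [10] * (world_size - 1), 2) in test_all_reduce_product
            return master_value * worker_value ** (world_size - 1)
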
2689        @skip_but_pass_in_sandcastle_if(
2690            BACKEND == "nccl", "Nccl does not support CPU tensors"
2691        )
2692        def test_all_reduce_sum(self):
2693            group, group_id, rank = self._init_global_test()
2694            self._test_all_reduce_helper(
2695                group,
2696                group_id,
2697                rank,
2698                dist.ReduceOp.SUM,
2699                2,
2700                10,
2701                2 + (10 * (len(group) - 1)),
2702            )
2703
2704        @skip_but_pass_in_sandcastle_if(
2705            BACKEND == "nccl", "Nccl does not support CPU tensors"
2706        )
2707        def test_all_reduce_sum_async(self):
2708            group, group_id, rank = self._init_global_test()
2709            self._test_all_reduce_helper(
2710                group,
2711                group_id,
2712                rank,
2713                dist.ReduceOp.SUM,
2714                2,
2715                10,
2716                2 + (10 * (len(group) - 1)),
2717                async_op=True,
2718            )
2719
2720        @skip_but_pass_in_sandcastle_if(
2721            BACKEND != "gloo" and BACKEND != "nccl",
2722            "Only Gloo and NCCL backends will have CUDA allReduce tested",
2723        )
2724        @skip_if_no_gpu
2725        def test_all_reduce_sum_cuda(self):
2726            torch.cuda.set_device(self.rank)
2727            group, group_id, rank = self._init_global_test()
2728            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
2729            self._test_all_reduce_helper(
2730                group,
2731                group_id,
2732                rank,
2733                dist.ReduceOp.SUM,
2734                2,
2735                10,
2736                2 + (10 * (len(group) - 1)),
2737                True,
2738                rank_to_GPU,
2739            )
2740
2741        @skip_but_pass_in_sandcastle_if(
2742            BACKEND != "gloo" and BACKEND != "nccl",
2743            "Only Gloo and NCCL backends will have CUDA allReduce tested",
2744        )
2745        @skip_if_no_gpu
2746        def test_all_reduce_sum_cuda_async(self):
2747            torch.cuda.set_device(self.rank)
2748            group, group_id, rank = self._init_global_test()
2749            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
2750            self._test_all_reduce_helper(
2751                group,
2752                group_id,
2753                rank,
2754                dist.ReduceOp.SUM,
2755                2,
2756                10,
2757                2 + (10 * (len(group) - 1)),
2758                True,
2759                rank_to_GPU,
2760                async_op=True,
2761            )
2762
2763        @skip_but_pass_in_sandcastle_if(
2764            BACKEND == "nccl", "Nccl does not support CPU tensors"
2765        )
2766        def test_all_reduce_sum_complex(self):
2767            group, group_id, rank = self._init_global_test()
2768            self._test_all_reduce_helper(
2769                group,
2770                group_id,
2771                rank,
2772                dist.ReduceOp.SUM,
2773                complex(2, 3),
2774                complex(10, 11),
2775                complex(2, 3) + (complex(10, 11) * (len(group) - 1)),
2776                dtype=torch.cfloat,
2777            )
2778
2779        @skip_but_pass_in_sandcastle_if(
2780            BACKEND == "nccl", "Nccl does not support CPU tensors"
2781        )
2782        def test_all_reduce_complex_unsupported_ops(self):
2783            unsupported_ops = [
2784                dist.ReduceOp.MAX,
2785                dist.ReduceOp.MIN,
2786                dist.ReduceOp.PRODUCT,
2787                dist.ReduceOp.BAND,
2788                dist.ReduceOp.BOR,
2789                dist.ReduceOp.BXOR,
2790            ]
2791            group, group_id, rank = self._init_global_test()
2792            for unsupported_op in unsupported_ops:
2793                with self.assertRaisesRegex(
2794                    ValueError, "all_reduce does not support"
2795                ):
2796                    dist.all_reduce(
2797                        _build_tensor(1, dtype=torch.cfloat), unsupported_op, group_id
2798                    )
2799
2800        @skip_but_pass_in_sandcastle_if(
2801            BACKEND != "gloo" and BACKEND != "nccl",
2802            "Only Gloo and NCCL backends will have CUDA allReduce tested",
2803        )
2804        @skip_if_no_gpu
2805        def test_all_reduce_sum_cuda_complex(self):
2806            torch.cuda.set_device(self.rank)
2807            group, group_id, rank = self._init_global_test()
2808            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
2809            self._test_all_reduce_helper(
2810                group,
2811                group_id,
2812                rank,
2813                dist.ReduceOp.SUM,
2814                complex(2, 3),
2815                complex(10, 11),
2816                complex(2, 3) + (complex(10, 11) * (len(group) - 1)),
2817                True,
2818                rank_to_GPU,
2819                dtype=torch.cfloat,
2820            )
2821
2822        @skip_but_pass_in_sandcastle_if(
2823            BACKEND == "nccl", "Nccl does not support CPU tensors"
2824        )
2825        def test_all_reduce_product(self):
2826            group, group_id, rank = self._init_global_test()
2827            self._test_all_reduce_helper(
2828                group,
2829                group_id,
2830                rank,
2831                dist.ReduceOp.PRODUCT,
2832                2,
2833                10,
2834                reduce(operator.mul, [10] * (len(group) - 1), 2),
2835            )
2836
2837        @skip_but_pass_in_sandcastle_if(
2838            BACKEND == "nccl", "Nccl does not support CPU tensors"
2839        )
2840        def test_all_reduce_min(self):
2841            group, group_id, rank = self._init_global_test()
2842            self._test_all_reduce_helper(
2843                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
2844            )
2845
2846        @skip_but_pass_in_sandcastle_if(
2847            BACKEND == "nccl", "Nccl does not support CPU tensors"
2848        )
2849        def test_all_reduce_max(self):
2850            group, group_id, rank = self._init_global_test()
2851            self._test_all_reduce_helper(
2852                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
2853            )
2854
2855        @skip_if_small_worldsize
2856        @skip_but_pass_in_sandcastle_if(
2857            BACKEND == "nccl", "Nccl does not support CPU tensors"
2858        )
2859        def test_all_reduce_group_sum(self):
2860            group, group_id, rank = self._init_group_test()
2861            self._test_all_reduce_helper(
2862                group,
2863                group_id,
2864                rank,
2865                dist.ReduceOp.SUM,
2866                2,
2867                10,
2868                2 + (10 * (len(group) - 1)),
2869            )
2870
2871        @skip_if_small_worldsize
2872        @skip_but_pass_in_sandcastle_if(
2873            BACKEND == "nccl", "Nccl does not support CPU tensors"
2874        )
2875        def test_all_reduce_group_product(self):
2876            group, group_id, rank = self._init_group_test()
2877            self._test_all_reduce_helper(
2878                group,
2879                group_id,
2880                rank,
2881                dist.ReduceOp.PRODUCT,
2882                2,
2883                10,
2884                reduce(operator.mul, [10] * (len(group) - 1), 2),
2885            )
2886
2887        @skip_if_small_worldsize
2888        @skip_but_pass_in_sandcastle_if(
2889            BACKEND == "nccl", "Nccl does not support CPU tensors"
2890        )
2891        def test_all_reduce_group_min(self):
2892            group, group_id, rank = self._init_group_test()
2893            self._test_all_reduce_helper(
2894                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
2895            )
2896
2897        @skip_if_small_worldsize
2898        @skip_but_pass_in_sandcastle_if(
2899            BACKEND == "nccl", "Nccl does not support CPU tensors"
2900        )
2901        def test_all_reduce_group_max(self):
2902            group, group_id, rank = self._init_group_test()
2903            self._test_all_reduce_helper(
2904                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
2905            )
2906
2907        @skip_but_pass_in_sandcastle_if(
2908            BACKEND == "nccl", "Nccl does not support CPU tensors"
2909        )
2910        def test_all_reduce_full_group_sum(self):
2911            group, group_id, rank = self._init_full_group_test()
2912            self._test_all_reduce_helper(
2913                group,
2914                group_id,
2915                rank,
2916                dist.ReduceOp.SUM,
2917                2,
2918                10,
2919                2 + (10 * (len(group) - 1)),
2920            )
2921
2922        @skip_but_pass_in_sandcastle_if(
2923            BACKEND == "nccl", "Nccl does not support CPU tensors"
2924        )
2925        def test_all_reduce_full_group_product(self):
2926            group, group_id, rank = self._init_full_group_test()
2927            self._test_all_reduce_helper(
2928                group,
2929                group_id,
2930                rank,
2931                dist.ReduceOp.PRODUCT,
2932                2,
2933                10,
2934                reduce(operator.mul, [10] * (len(group) - 1), 2),
2935            )
2936
2937        @skip_but_pass_in_sandcastle_if(
2938            BACKEND == "nccl", "Nccl does not support CPU tensors"
2939        )
2940        def test_all_reduce_full_group_min(self):
2941            group, group_id, rank = self._init_full_group_test()
2942            self._test_all_reduce_helper(
2943                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
2944            )
2945
2946        @skip_but_pass_in_sandcastle_if(
2947            BACKEND == "nccl", "Nccl does not support CPU tensors"
2948        )
2949        def test_all_reduce_full_group_max(self):
2950            group, group_id, rank = self._init_full_group_test()
2951            self._test_all_reduce_helper(
2952                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
2953            )
2954
2955        # SPARSE ALL REDUCE
2956        def _test_sparse_all_reduce_sum(self, fn):
2957            group, group_id, rank = self._init_global_test()
2958
2959            tests = simple_sparse_reduce_tests(
2960                rank, dist.get_world_size(), num_inputs=1
2961            )
2962            for (inputs, outputs) in tests:
2963                tensors = [fn(input) for input in inputs]
2964                dist.all_reduce(tensors[0], dist.ReduceOp.SUM, group_id)
2965                self.assertEqual(tensors[0], outputs[0])
2966
2967        @skip_but_pass_in_sandcastle_if(
2968            BACKEND != "gloo", "Only the Gloo backend supports sparse all reduce"
2969        )
2970        def test_sparse_all_reduce_sum(self):
2971            self._test_sparse_all_reduce_sum(lambda t: t)
2972
2973        @skip_but_pass_in_sandcastle_if(
2974            BACKEND != "gloo", "Only the Gloo backend supports sparse all reduce"
2975        )
2976        @skip_if_no_gpu
2977        def test_sparse_all_reduce_sum_cuda(self):
2978            self._test_sparse_all_reduce_sum(lambda t: t.clone().cuda())
2979
2980        # ALL REDUCE - COALESCED
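        # Each *_test_cases helper below returns per-tensor lists of
        # (master_values, worker_values, expected_values, dtypes) that
        # _test_all_reduce_coalesced_helper unpacks for a given group size.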
2981        @staticmethod
2982        def _all_reduce_coalesced_sum_test_cases(group_size):
2983            return (
2984                [2, 3, complex(2, 3)],
2985                [10, 11, complex(10, 11)],
2986                [
2987                    2 + 10 * (group_size - 1),
2988                    3 + 11 * (group_size - 1),
2989                    complex(2, 3) + complex(10, 11) * (group_size - 1),
2990                ],
2991                [torch.float, torch.float, torch.cfloat],
2992            )
2993
2994        @staticmethod
2995        def _all_reduce_coalesced_product_test_cases(group_size):
2996            return (
2997                [1, 2],
2998                [3, 4],
2999                [1 * 3 ** (group_size - 1), 2 * 4 ** (group_size - 1)],
3000                [torch.float, torch.float],
3001            )
3002
3003        @staticmethod
3004        def _all_reduce_coalesced_min_test_cases(group_size):
3005            return (
3006                [1, 4],
3007                [2, 3],
3008                [1, 3],
3009                [torch.float, torch.float],
3010            )
3011
3012        @staticmethod
3013        def _all_reduce_coalesced_max_test_cases(group_size):
3014            return (
3015                [1, 4],
3016                [2, 3],
3017                [2, 4],
3018                [torch.float, torch.float],
3019            )
3020
3021        @skip_but_pass_in_sandcastle_if(
3022            BACKEND == "nccl", "Nccl does not support CPU tensors"
3023        )
3024        def test_all_reduce_coalesced_max_complex_unsupported(self):
3025            group, group_id, rank = self._init_global_test()
3026            with self.assertRaisesRegex(ValueError, "all_reduce does not support"):
3027                dist.all_reduce_coalesced(
3028                    [_build_tensor(1, dtype=torch.cfloat)], dist.ReduceOp.MAX, group_id
3029                )
3030
3031        def _test_all_reduce_coalesced_helper(
3032            self,
3033            group,
3034            group_id,
3035            rank,
3036            op,
3037            cuda=False,
3038            rank_to_GPU=None,
3039        ):
3040            test_case_func = {
3041                dist.ReduceOp.SUM: self._all_reduce_coalesced_sum_test_cases,
3042                dist.ReduceOp.PRODUCT: self._all_reduce_coalesced_product_test_cases,
3043                dist.ReduceOp.MIN: self._all_reduce_coalesced_min_test_cases,
3044                dist.ReduceOp.MAX: self._all_reduce_coalesced_max_test_cases,
3045            }[op]
3046
3047            master_values, worker_values, expected_values, dtypes = test_case_func(
3048                len(group)
3049            )
3050
3051            for src in group:
3052                curr_values = master_values if rank == src else worker_values
3053                tensors = [
3054                    _build_tensor(src + 1, val, dtype=dtype)
3055                    for dtype, val in zip(dtypes, curr_values)
3056                ]
3057                if cuda:
3058                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
3059                tensor_shapes = []
3060                for tensor in tensors:
3061                    if tensor.dtype == torch.complex64:
3062                        tensor_shapes.append(torch.view_as_real(tensor).shape)
3063                    else:
3064                        tensor_shapes.append(tensor.shape)
3065                self.call_dist_op(
3066                    ":all_reduce",
3067                    False,
3068                    dist.all_reduce_coalesced,
3069                    tensors,
3070                    op,
3071                    group_id,
3072                    tensor_shapes=tensor_shapes,
3073                )
3074                expected_tensors = [
3075                    _build_tensor(src + 1, expected_value, dtype=dtype)
3076                    for dtype, expected_value in zip(dtypes, expected_values)
3077                ]
3078                self.assertEqual(tensors, expected_tensors)
3079
3080            self._barrier()
3081
3082        @require_backend_is_available({"gloo"})
3083        def test_all_reduce_coalesced_sum(self):
3084            group, group_id, rank = self._init_global_test()
3085            self._test_all_reduce_coalesced_helper(
3086                group,
3087                group_id,
3088                rank,
3089                dist.ReduceOp.SUM,
3090                cuda=False,
3091                rank_to_GPU=None,
3092            )
3093
3094        @require_backend_is_available({"gloo"})
3095        def test_all_reduce_coalesced_product(self):
3096            group, group_id, rank = self._init_global_test()
3097            self._test_all_reduce_coalesced_helper(
3098                group,
3099                group_id,
3100                rank,
3101                dist.ReduceOp.PRODUCT,
3102                cuda=False,
3103                rank_to_GPU=None,
3104            )
3105
3106        @require_backend_is_available({"gloo"})
3107        def test_all_reduce_coalesced_min(self):
3108            group, group_id, rank = self._init_global_test()
3109            self._test_all_reduce_coalesced_helper(
3110                group,
3111                group_id,
3112                rank,
3113                dist.ReduceOp.MIN,
3114                cuda=False,
3115                rank_to_GPU=None,
3116            )
3117
3118        @require_backend_is_available({"gloo"})
3119        def test_all_reduce_coalesced_max(self):
3120            group, group_id, rank = self._init_global_test()
3121            self._test_all_reduce_coalesced_helper(
3122                group, group_id, rank, dist.ReduceOp.MAX, cuda=False, rank_to_GPU=None
3123            )
3124
3125        @skip_if_small_worldsize
3126        @require_backend_is_available({"gloo"})
3127        def test_all_reduce_coalesced_group_sum(self):
3128            group, group_id, rank = self._init_group_test()
3129            self._test_all_reduce_coalesced_helper(
3130                group, group_id, rank, dist.ReduceOp.SUM, cuda=False, rank_to_GPU=None
3131            )
3132
3133        @skip_if_small_worldsize
3134        @require_backend_is_available({"gloo"})
3135        def test_all_reduce_coalesced_group_product(self):
3136            group, group_id, rank = self._init_group_test()
3137            self._test_all_reduce_coalesced_helper(
3138                group,
3139                group_id,
3140                rank,
3141                dist.ReduceOp.PRODUCT,
3142                cuda=False,
3143                rank_to_GPU=None,
3144            )
3145
3146        @skip_if_small_worldsize
3147        @require_backend_is_available({"gloo"})
3148        def test_all_reduce_coalesced_group_min(self):
3149            group, group_id, rank = self._init_group_test()
3150            self._test_all_reduce_coalesced_helper(
3151                group, group_id, rank, dist.ReduceOp.MIN, cuda=False, rank_to_GPU=None
3152            )
3153
3154        @skip_if_small_worldsize
3155        @require_backend_is_available({"gloo"})
3156        def test_all_reduce_coalesced_group_max(self):
3157            group, group_id, rank = self._init_group_test()
3158            self._test_all_reduce_coalesced_helper(
3159                group, group_id, rank, dist.ReduceOp.MAX, cuda=False, rank_to_GPU=None
3160            )
3161
3162        @require_backend_is_available({"gloo"})
3163        def test_all_reduce_coalesced_full_group_sum(self):
3164            group, group_id, rank = self._init_full_group_test()
3165            self._test_all_reduce_coalesced_helper(
3166                group, group_id, rank, dist.ReduceOp.SUM, cuda=False, rank_to_GPU=None
3167            )
3168
3169        @require_backend_is_available({"gloo"})
3170        def test_all_reduce_coalesced_full_group_product(self):
3171            group, group_id, rank = self._init_full_group_test()
3172            self._test_all_reduce_coalesced_helper(
3173                group,
3174                group_id,
3175                rank,
3176                dist.ReduceOp.PRODUCT,
3177                cuda=False,
3178                rank_to_GPU=None,
3179            )
3180
3181        @require_backend_is_available({"gloo"})
3182        def test_all_reduce_coalesced_full_group_min(self):
3183            group, group_id, rank = self._init_full_group_test()
3184            self._test_all_reduce_coalesced_helper(
3185                group,
3186                group_id,
3187                rank,
3188                dist.ReduceOp.MIN,
3189                cuda=False,
3190                rank_to_GPU=None,
3191            )
3192
3193        @require_backend_is_available({"gloo"})
3194        def test_all_reduce_coalesced_full_group_max(self):
3195            group, group_id, rank = self._init_full_group_test()
3196            self._test_all_reduce_coalesced_helper(
3197                group, group_id, rank, dist.ReduceOp.MAX, cuda=False, rank_to_GPU=None
3198            )
3199
3200        # SCATTER
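        # dist.scatter sends scatter_list[i] from the source rank to rank i;
        # every rank (including the source) receives exactly one tensor into
        # its output buffer, and non-source ranks leave scatter_list empty
        # (or omit it).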
3201        def _test_scatter_helper(
3202            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
3203        ):
3204            for dest in group:
3205                tensor = _build_tensor(dest + 1, -1, dtype=dtype)
3206                expected_tensor = _build_tensor(dest + 1, rank, dtype=dtype)
3207                tensors = (
3208                    [_build_tensor(dest + 1, i, dtype=dtype) for i in group]
3209                    if rank == dest
3210                    else []
3211                )
3212                if cuda:
3213                    tensor = tensor.cuda(rank_to_GPU[rank][0])
3214                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
3215                if dtype == torch.complex64:
3216                    tensor_shapes = [torch.view_as_real(t).shape for t in tensors]
3217                else:
3218                    tensor_shapes = [t.shape for t in tensors]
3219                self.call_dist_op(
3220                    ":scatter",
3221                    False,
3222                    dist.scatter,
3223                    tensor,
3224                    src=dest,
3225                    scatter_list=tensors,
3226                    group=group_id,
3227                    expect_event=False,
3228                    tensor_shapes=tensor_shapes,
3229                )
3230                self.assertEqual(tensor, expected_tensor)
3231
3232            self._barrier()
3233
3234        @skip_but_pass_in_sandcastle_if(
3235            BACKEND == "nccl", "Nccl does not support CPU tensors"
3236        )
3237        @skip_but_pass_in_sandcastle_if(
3238            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
3239        )
3240        def test_scatter_checks(self):
3241            group, group_id, rank = self._init_global_test()
3242            one = torch.ones([1])
3243
3244            # Specify scatter_list argument only on source rank.
3245            output = one.clone() * -1
3246            if rank == 0:
3247                scatter_list = [one.clone() * i for i in group]
3248                dist.scatter(output, src=0, scatter_list=scatter_list)
3249            else:
3250                dist.scatter(output, src=0)
3251            self.assertEqual(output, one * rank)
3252
3253            # Don't specify src argument.
3254            output = one.clone() * -1
3255            if rank == 0:
3256                scatter_list = [one.clone() * i for i in group]
3257                dist.scatter(output, scatter_list=scatter_list)
3258            else:
3259                dist.scatter(output)
3260            self.assertEqual(output, one * rank)
3261
3262        @skip_but_pass_in_sandcastle_if(
3263            BACKEND == "nccl", "Nccl does not support CPU tensors"
3264        )
3265        @skip_but_pass_in_sandcastle_if(
3266            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
3267        )
3268        def test_scatter(self):
3269            group, group_id, rank = self._init_global_test()
3270            self._test_scatter_helper(group, group_id, rank)
3271
3272        @skip_but_pass_in_sandcastle_if(
3273            BACKEND != "nccl", "Only Nccl supports CUDA scatter"
3274        )
3275        @skip_if_no_gpu
3276        def test_scatter_cuda(self):
3277            group, group_id, rank = self._init_global_test()
3278            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
3279            self._test_scatter_helper(group, group_id, rank, True, rank_to_GPU)
3280
3281        @skip_but_pass_in_sandcastle_if(
3282            BACKEND == "nccl", "Nccl does not support CPU tensors"
3283        )
3284        @skip_but_pass_in_sandcastle_if(
3285            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
3286        )
3287        def test_scatter_complex(self):
3288            group, group_id, rank = self._init_global_test()
3289            self._test_scatter_helper(group, group_id, rank, dtype=torch.cfloat)
3290
3291        @skip_but_pass_in_sandcastle_if(
3292            BACKEND != "nccl", "Only Nccl supports CUDA scatter"
3293        )
3294        @skip_if_no_gpu
3295        def test_scatter_cuda_complex(self):
3296            group, group_id, rank = self._init_global_test()
3297            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
3298            self._test_scatter_helper(
3299                group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
3300            )
3301
3302        @skip_but_pass_in_sandcastle_if(
3303            BACKEND == "nccl", "Nccl does not support CPU tensors"
3304        )
3305        @skip_but_pass_in_sandcastle_if(
3306            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
3307        )
3308        @skip_if_small_worldsize
3309        def test_scatter_group(self):
3310            group, group_id, rank = self._init_group_test()
3311            self._test_scatter_helper(group, group_id, rank)
3312
3313        @skip_but_pass_in_sandcastle_if(
3314            BACKEND == "nccl", "Nccl does not support CPU tensors"
3315        )
3316        @skip_but_pass_in_sandcastle_if(
3317            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
3318        )
3319        def test_scatter_full_group(self):
3320            group, group_id, rank = self._init_full_group_test()
3321            self._test_scatter_helper(group, group_id, rank)
3322
3323        # GATHER
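        # dist.gather collects each rank's input tensor into gather_list on
        # the destination rank, so gather_list[i] ends up holding rank i's
        # tensor; non-destination ranks leave gather_list empty (or omit it).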
3324        def _test_gather_helper(
3325            self, group, group_id, rank, cuda=False, rank_to_GPU=None
3326        ):
3327            for dest in group:
3328                tensor = _build_tensor(dest + 1, rank)
3329                tensors = (
3330                    [_build_tensor(dest + 1, -1) for i in group] if rank == dest else []
3331                )
3332                if cuda:
3333                    tensor = tensor.cuda(rank_to_GPU[rank][0])
3334                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
3335                self.call_dist_op(
3336                    ":gather",
3337                    False,
3338                    dist.gather,
3339                    tensor,
3340                    dst=dest,
3341                    gather_list=tensors,
3342                    group=group_id,
3343                    expect_event=False,
3344                    tensor_shapes=[tensors[0].shape] if len(tensors) > 0 else None,
3345                )
3346                if rank == dest:
3347                    expected_tensors = [_build_tensor(dest + 1, i) for i in group]
3348                    for t1, t2 in zip(tensors, expected_tensors):
3349                        self.assertEqual(t1, t2)
3350
3351            self._barrier()
3352
3353        @skip_but_pass_in_sandcastle_if(
3354            BACKEND == "nccl", "Nccl does not support CPU tensors"
3355        )
3356        @skip_but_pass_in_sandcastle_if(
3357            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
3358        )
3359        def test_gather_checks(self):
3360            group, group_id, rank = self._init_global_test()
3361            one = torch.ones([1])
3362
3363            # Specify gather_list argument only on destination rank.
3364            if rank == 0:
3365                gather_list = [one.clone() for _ in group]
3366                dist.gather(one * rank, dst=0, gather_list=gather_list)
3367                for i in group:
3368                    self.assertEqual(gather_list[i], one * i)
3369            else:
3370                dist.gather(one * rank, dst=0)
3371
3372            # Don't specify dst argument.
3373            if rank == 0:
3374                gather_list = [one.clone() for _ in group]
3375                dist.gather(one * rank, gather_list=gather_list)
3376                for i in group:
3377                    self.assertEqual(gather_list[i], one * i)
3378            else:
3379                dist.gather(one * rank)
3380
3381        @skip_but_pass_in_sandcastle_if(
3382            BACKEND == "nccl", "Nccl does not support CPU tensors"
3383        )
3384        @skip_but_pass_in_sandcastle_if(
3385            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
3386        )
3387        def test_gather(self):
3388            group, group_id, rank = self._init_global_test()
3389            self._test_gather_helper(group, group_id, rank)
3390
3391        @skip_but_pass_in_sandcastle_if(
3392            BACKEND != "nccl", "Only Nccl supports CUDA gather"
3393        )
3394        @skip_if_no_gpu
3395        def test_gather_cuda(self):
3396            group, group_id, rank = self._init_global_test()
3397            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
3398            self._test_gather_helper(group, group_id, rank, True, rank_to_GPU)
3399
3400        @skip_but_pass_in_sandcastle_if(
3401            BACKEND == "nccl", "Nccl does not support CPU tensors"
3402        )
3403        @skip_but_pass_in_sandcastle_if(
3404            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
3405        )
3406        @skip_if_small_worldsize
3407        def test_gather_group(self):
3408            group, group_id, rank = self._init_group_test()
3409            self._test_gather_helper(group, group_id, rank)
3410
3411        @skip_but_pass_in_sandcastle_if(
3412            BACKEND == "nccl", "Nccl does not support CPU tensors"
3413        )
3414        @skip_but_pass_in_sandcastle_if(
3415            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
3416        )
3417        def test_gather_full_group(self):
3418            group, group_id, rank = self._init_full_group_test()
3419            self._test_gather_helper(group, group_id, rank)
3420
3421        # ALL GATHER
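        # dist.all_gather is a gather whose result is replicated on every
        # rank: after the call, tensors[i] holds rank i's input on all ranks.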
3422        def _test_all_gather_helper(
3423            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
3424        ):
3425            for dest in group:
3426                tensor = _build_tensor(dest + 1, rank, dtype=dtype)
3427                tensors = [_build_tensor(dest + 1, -1, dtype=dtype) for i in group]
3428                allgather = dist.all_gather
3429                if cuda:
3430                    tensor = tensor.cuda(rank_to_GPU[rank][0])
3431                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
3432                if tensors[0].dtype == torch.complex64:
3433                    tensor_shapes = [torch.view_as_real(tensors[0]).shape]
3434                else:
3435                    tensor_shapes = [tensors[0].shape]
3436                self.call_dist_op(
3437                    ":all_gather",
3438                    False,
3439                    allgather,
3440                    tensors,
3441                    tensor,
3442                    group_id,
3443                    False,
3444                    tensor_shapes=tensor_shapes,
3445                )
3446
3447                expected_tensors = [
3448                    _build_tensor(dest + 1, i, dtype=dtype) for i in group
3449                ]
3450                for t1, t2 in zip(tensors, expected_tensors):
3451                    self.assertEqual(t1, t2)
3452
3453            self._barrier()
3454
3455        @skip_but_pass_in_sandcastle_if(
3456            BACKEND == "nccl", "Nccl does not support CPU tensors"
3457        )
3458        def test_all_gather(self):
3459            group, group_id, rank = self._init_global_test()
3460            self._test_all_gather_helper(group, group_id, rank)
3461
3462        @skip_but_pass_in_sandcastle_if(
3463            BACKEND != "nccl", "Only Nccl supports CUDA all gather"
3464        )
3465        @skip_if_no_gpu
3466        def test_all_gather_cuda(self):
3467            group, group_id, rank = self._init_global_test()
3468            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
3469            self._test_all_gather_helper(group, group_id, rank, True, rank_to_GPU)
3470
3471        @skip_but_pass_in_sandcastle_if(
3472            BACKEND == "nccl", "Nccl does not support CPU tensors"
3473        )
3474        def test_all_gather_complex(self):
3475            group, group_id, rank = self._init_global_test()
3476            self._test_all_gather_helper(group, group_id, rank, dtype=torch.cfloat)
3477
3478        @skip_but_pass_in_sandcastle_if(
3479            BACKEND != "nccl", "Only Nccl supports CUDA all gather"
3480        )
3481        @skip_if_no_gpu
3482        def test_all_gather_cuda_complex(self):
3483            group, group_id, rank = self._init_global_test()
3484            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
3485            self._test_all_gather_helper(
3486                group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
3487            )
3488
3489        @skip_if_small_worldsize
3490        @skip_but_pass_in_sandcastle_if(
3491            BACKEND == "nccl", "Nccl does not support CPU tensors"
3492        )
3493        def test_all_gather_group(self):
3494            group, group_id, rank = self._init_group_test()
3495            self._test_all_gather_helper(group, group_id, rank)
3496
3497        @skip_but_pass_in_sandcastle_if(
3498            BACKEND == "nccl", "Nccl does not support CPU tensors"
3499        )
3500        def test_all_gather_full_group(self):
3501            group, group_id, rank = self._init_full_group_test()
3502            self._test_all_gather_helper(group, group_id, rank)
3503
3504        @skip_but_pass_in_sandcastle_if(
3505            BACKEND != "nccl", "Only Nccl supports all_gather_v"
3506        )
3507        @skip_if_no_gpu
3508        def test_all_gather_v_cuda(self):
3509            self._barrier()
3510            group, group_id, rank = self._init_global_test()
3511            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
3512            device_id = rank_to_GPU[rank][0]
3513
3514            output_split_sizes = []
3515            for dst in group:
3516                output_split_sizes.append(dst + 1)
3517            sum_len = sum(output_split_sizes)
3518            value = 2
3519
3520            for async_val in [True, False]:
3521                tensor = (
3522                    torch.empty(
3523                        output_split_sizes[rank], sum_len, sum_len, dtype=torch.float
3524                    )
3525                    .fill_(value)
3526                    .cuda(device_id)
3527                )
3528                out_tensor = _build_tensor(sum_len, -1, device_id=device_id)
3529
3530                req = dist.all_gather(
3531                    list(torch.split(out_tensor, output_split_sizes)),
3532                    tensor,
3533                    group_id,
3534                    async_val,
3535                )
3536                if async_val:
3537                    req.wait()
3538
3539                expected_value = value
3540                expected_tensor = _build_tensor(
3541                    sum_len, expected_value, device_id=device_id
3542                )
3543
3544                self.assertEqual(out_tensor, expected_tensor)
3545            self._barrier()
3546
3547        # Test all_gather accepting a single tensor as output
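        # The output may be laid out either as a concatenation along dim 0
        # ([world_size * n, ...]) or as a stack with a new leading dimension
        # ([world_size, n, ...]); both layouts are exercised by the two tests
        # below.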
3548        def _all_gather_into_tensor_helper(
3549            self, tensor_out, tensor_in, group_id, rank, cuda=True, rank_to_GPU=None
3550        ):
3551            if cuda:
3552                tensor_in = tensor_in.cuda(rank_to_GPU[rank][0])
3553                tensor_out = tensor_out.cuda(rank_to_GPU[rank][0])
3554            if tensor_out.dtype == torch.complex64:
3555                tensor_shapes = [torch.view_as_real(tensor_in).shape]
3556            else:
3557                tensor_shapes = [tensor_in.shape]
3558            self.call_dist_op(
3559                ":all_gather_into_tensor",
3560                False,
3561                dist.all_gather_into_tensor,
3562                tensor_out,
3563                tensor_in,
3564                group_id,
3565                False,
3566                expect_event=False,
3567                tensor_shapes=tensor_shapes,
3568            )
3569            return tensor_out
3570
3571        @skip_but_pass_in_sandcastle_if(
3572            BACKEND != "nccl", "Only Nccl supports CUDA all_gather_into_tensor"
3573        )
3574        @skip_if_no_gpu
3575        def test_all_gather_into_cat_tensor_cuda(self):
3576            group, group_id, rank = self._init_global_test()
3577            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
3578            size = 2
3579            tensor_in = torch.ones([size, size]) * rank
3580            # Concatenated output
3581            tensor_out = torch.ones([len(group) * size, size]) * (-1)
3582            tensor_out = self._all_gather_into_tensor_helper(
3583                tensor_out, tensor_in, group_id, rank, True, rank_to_GPU
3584            )
3585
3586            # Check result
3587            # Concatenate all blocks into a bigger tensor
3588            expected_tensor = torch.cat([torch.ones([size, size]) * i for i in group])
3589            self.assertEqual(tensor_out, expected_tensor)
3590            self._barrier()
3591
3592        @skip_but_pass_in_sandcastle_if(
3593            BACKEND != "nccl", "Only Nccl supports CUDA all_gather_into_tensor"
3594        )
3595        @skip_if_no_gpu
3596        def test_all_gather_into_stack_tensor_cuda(self):
3597            group, group_id, rank = self._init_global_test()
3598            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
3599            size = 2
3600            tensor_in = torch.ones([size, size]) * rank
3601            # Stacked output
3602            tensor_out = torch.ones([len(group), size, size]) * (-1)
3603            tensor_out = self._all_gather_into_tensor_helper(
3604                tensor_out, tensor_in, group_id, rank, True, rank_to_GPU
3605            )
3606
3607            # Check result
3608            # Stack all blocks into a bigger tensor
3609            expected_tensor = torch.stack([torch.ones([size, size]) * i for i in group])
3610            self.assertEqual(tensor_out, expected_tensor)
3611            self._barrier()
3612
3613        def _run_all_gather_coalesced_and_verify(
3614            self, output_tensor_lists, input_tensors, expected_tensors, group_id
3615        ):
3616            """
3617            Helper that runs all_gather_coalesced and returns True if the
3618            output matches expectations.
3619            """
3620            tensor_shapes = []
3621            for input_tensor in input_tensors:
3622                if input_tensor.dtype == torch.complex64:
3623                    tensor_shapes.append(torch.view_as_real(input_tensor).shape)
3624                else:
3625                    tensor_shapes.append(input_tensor.shape)
3626            self.call_dist_op(
3627                ":all_gather",
3628                False,
3629                dist.all_gather_coalesced,
3630                output_tensor_lists,
3631                input_tensors,
3632                group_id,
3633                tensor_shapes=tensor_shapes,
3634            )
3635
3636            for l1, l2 in zip(output_tensor_lists, expected_tensors):
3637                for t1, t2 in zip(l1, l2):
3638                    if not torch.equal(t1, t2):
3639                        return False
3640            return True
3641
3642        def _test_all_gather_coalesced_helper(
3643            self, group, group_id, rank, dtype=torch.float
3644        ):
3645            # TODO: Instead, we should probably go through the _rank_not_in_group
3646            # mechanism to disable sending tensors.
3647            if group_id is not None:
3648                for test_case_id in range(2, 5):
3649                    # Make sure we create tensors of incompatible sizes, e.g.
3650                    # [1], [2x2], [3x3x3] ... to be sent in one batch
3651                    input_tensors = [
3652                        _build_multidim_tensor(
3653                            tensor_id, tensor_id, rank + tensor_id, dtype=dtype
3654                        )
3655                        for tensor_id in range(1, test_case_id)
3656                    ]
3657                    output_tensor_lists = [
3658                        [
3659                            _build_multidim_tensor(
3660                                tensor_id, tensor_id, -1, dtype=dtype
3661                            )
3662                            for tensor_id in range(1, test_case_id)
3663                        ]
3664                        for _ in group
3665                    ]
3666                    expected_tensors = [
3667                        [
3668                            _build_multidim_tensor(
3669                                tensor_id, tensor_id, rank_iter + tensor_id, dtype=dtype
3670                            )
3671                            for tensor_id in range(1, test_case_id)
3672                        ]
3673                        for rank_iter in group
3674                    ]
3675                    assert self._run_all_gather_coalesced_and_verify(
3676                        output_tensor_lists, input_tensors, expected_tensors, group_id
3677                    ), "output tensors do not match expected outputs"
3678
3679            self._barrier()
3680
3681        @skip_but_pass_in_sandcastle_if(
3682            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
3683            f"{BACKEND} does not support all_gather_coalesced",
3684        )
3685        def test_all_gather_coalesced_simple(self):
3686            group, group_id, rank = self._init_global_test()
3687            self._test_all_gather_coalesced_helper(group, group_id, rank)
3688
3689        @skip_but_pass_in_sandcastle_if(
3690            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
3691            f"{BACKEND} does not support all_gather_coalesced",
3692        )
3693        def test_all_gather_coalesced_complex(self):
3694            group, group_id, rank = self._init_global_test()
3695            self._test_all_gather_coalesced_helper(
3696                group, group_id, rank, dtype=torch.cfloat
3697            )
3698
3699        @skip_if_small_worldsize
3700        @skip_but_pass_in_sandcastle_if(
3701            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
3702            f"{BACKEND} does not support all_gather_coalesced",
3703        )
3704        def test_all_gather_coalesced_group(self):
3705            group, group_id, rank = self._init_group_test()
3706            self._test_all_gather_coalesced_helper(group, group_id, rank)
3707
3708        @skip_but_pass_in_sandcastle_if(
3709            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
3710            f"{BACKEND} does not support all_gather_coalesced",
3711        )
3712        def test_all_gather_coalesced_full_group(self):
3713            group, group_id, rank = self._init_full_group_test()
3714            self._test_all_gather_coalesced_helper(group, group_id, rank)
3715
3716        @skip_but_pass_in_sandcastle_if(
3717            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
3718            f"{BACKEND} does not support all_gather_coalesced",
3719        )
3720        def test_all_gather_coalesced_with_empty(self):
3721            group, group_id, rank = self._init_global_test()
3722            input_tensors = [
3723                rank * torch.ones([2, 2]),
3724                torch.ones([0]),
3725                (rank + 1) * torch.ones([3, 3]),
3726                torch.ones([0]),
3727                torch.ones([0]),
3728            ]
3729            output_tensors_lists = [
3730                [
3731                    -1 * torch.ones([2, 2]),
3732                    -1 * torch.ones([0]),
3733                    -1 * torch.ones([3, 3]),
3734                    -1 * torch.ones([0]),
3735                    -1 * torch.ones([0]),
3736                ]
3737                for _ in group
3738            ]
3739            expected_tensors = [
3740                [
3741                    r * torch.ones([2, 2]),
3742                    torch.ones([0]),
3743                    (r + 1) * torch.ones([3, 3]),
3744                    torch.ones([0]),
3745                    torch.ones([0]),
3746                ]
3747                for r in group
3748            ]
3749            assert self._run_all_gather_coalesced_and_verify(
3750                output_tensors_lists, input_tensors, expected_tensors, group_id
3751            )
3752            self._barrier()
3753
3754        # AllToAll
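        # dist.all_to_all_single splits the input along dim 0 into world_size
        # chunks, sends chunk j to rank j, and writes the chunk received from
        # each rank into the output in rank order; dist.all_to_all does the
        # same with explicit per-rank input and output tensor lists.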
3755        def _test_all_to_all_single_equal_split_helper(
3756            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
3757        ):
3758            if group_id is not None:
3759                size = len(group)
3760                in_tensor = torch.ones([size, size], dtype=dtype) * rank
3761                expected_tensor = torch.cat(
3762                    [torch.ones([1, size], dtype=dtype) * i for i in group]
3763                )
3764                out_tensor = torch.ones([size, size], dtype=dtype) * -1
3765                if cuda:
3766                    in_tensor = in_tensor.cuda(rank_to_GPU[rank][0])
3767                    expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
3768                    out_tensor = out_tensor.cuda(rank_to_GPU[rank][0])
3769                if dtype == torch.complex64:
3770                    tensor_shapes = [torch.view_as_real(in_tensor).shape]
3771                else:
3772                    tensor_shapes = [in_tensor.shape]
3773                self.call_dist_op(
3774                    ":all_to_all",
3775                    False,
3776                    dist.all_to_all_single,
3777                    out_tensor,
3778                    in_tensor,
3779                    group=group_id,
3780                    tensor_shapes=tensor_shapes,
3781                )
3782                self.assertEqual(out_tensor, expected_tensor)
3783            self._barrier()
3784
3785        def _test_all_to_all_single_unequal_split_helper(
3786            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
3787        ):
3788            if group_id is not None:
3789                size = len(group)
3790                in_splits = [i + 1 for i in group]
3791                out_splits = [rank + 1 for _ in group]
3792                in_tensor = torch.ones([sum(in_splits), size], dtype=dtype) * rank
3793                out_tensor = torch.ones([(rank + 1) * size, size], dtype=dtype)
3794                expected_tensor = torch.cat(
3795                    [torch.ones([rank + 1, size], dtype=dtype) * i for i in group]
3796                )
3797                if cuda:
3798                    in_tensor = in_tensor.cuda(rank_to_GPU[rank][0])
3799                    expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
3800                    out_tensor = out_tensor.cuda(rank_to_GPU[rank][0])
3801                dist.all_to_all_single(
3802                    out_tensor, in_tensor, out_splits, in_splits, group=group_id
3803                )
3804                self.assertEqual(out_tensor, expected_tensor)
3805            self._barrier()
3806
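        # Hypothetical helper (illustration only, never called): the split
        # sizes implied by the unequal-split scheme above, where rank r sends
        # (j + 1) rows to each rank j and receives (r + 1) rows from every
        # rank, so its output holds (r + 1) * world_size rows in total.
        @staticmethod
        def _unequal_split_sizes_sketch(rank, world_size):
            in_splits = [j + 1 for j in range(world_size)]  # rows sent to rank j
            out_splits = [rank + 1] * world_size  # rows received from each rank
            return in_splits, out_splits
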
3807        def _test_all_to_all_helper(
3808            self,
3809            group,
3810            group_id,
3811            rank,
3812            cuda=False,
3813            rank_to_GPU=None,
3814            dtype=torch.float,
3815        ):
3816            if group_id is not None:
3817                size = len(group)
3818                in_splits = [i + 1 for i in group]
3819                in_tensors = [
3820                    torch.ones([in_splits[i], size], dtype=dtype) * rank
3821                    for i, _ in enumerate(group)
3822                ]
3823                out_tensors = [
3824                    torch.ones([(rank + 1), size], dtype=dtype) for _ in group
3825                ]
3826                expected_tensors = [
3827                    torch.ones([rank + 1, size], dtype=dtype) * i for i in group
3828                ]
3829                if cuda:
3830                    in_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in in_tensors]
3831                    expected_tensors = [
3832                        t.cuda(rank_to_GPU[rank][0]) for t in expected_tensors
3833                    ]
3834                    out_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in out_tensors]
3835                dist.all_to_all(out_tensors, in_tensors, group=group_id)
3836                for t1, t2 in zip(out_tensors, expected_tensors):
3837                    self.assertEqual(t1, t2)
3838            self._barrier()
3839
3840        @skip_but_pass_in_sandcastle_if(
3841            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
3842        )
3843        def test_all_to_all_single_equal_split(self):
3844            group, group_id, rank = self._init_global_test()
3845            self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
3846
3847        @skip_but_pass_in_sandcastle_if(
3848            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
3849        )
3850        @skip_if_no_gpu
3851        def test_all_to_all_single_equal_split_cuda(self):
3852            group, group_id, rank = self._init_global_test()
3853            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
3854            self._test_all_to_all_single_equal_split_helper(
3855                group,
3856                group_id,
3857                rank,
3858                True,
3859                rank_to_GPU,
3860            )
3861
3862        @skip_but_pass_in_sandcastle_if(
3863            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
3864        )
3865        def test_all_to_all_single_equal_split_complex(self):
3866            group, group_id, rank = self._init_global_test()
3867            self._test_all_to_all_single_equal_split_helper(
3868                group, group_id, rank, dtype=torch.cfloat
3869            )
3870
3871        @skip_but_pass_in_sandcastle_if(
3872            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
3873        )
3874        @skip_if_no_gpu
3875        def test_all_to_all_single_equal_split_cuda_complex(self):
3876            group, group_id, rank = self._init_global_test()
3877            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
3878            self._test_all_to_all_single_equal_split_helper(
3879                group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
3880            )
3881
3882        @skip_but_pass_in_sandcastle_if(
3883            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
3884        )
3885        def test_all_to_all_single_unequal_split(self):
3886            group, group_id, rank = self._init_global_test()
3887            self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
3888
3889        @skip_but_pass_in_sandcastle_if(
3890            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
3891        )
3892        @skip_if_no_gpu
3893        def test_all_to_all_single_unequal_split_cuda(self):
3894            group, group_id, rank = self._init_global_test()
3895            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
3896            self._test_all_to_all_single_unequal_split_helper(
3897                group,
3898                group_id,
3899                rank,
3900                True,
3901                rank_to_GPU,
3902            )
3903
3904        @skip_but_pass_in_sandcastle_if(
3905            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
3906        )
3907        def test_all_to_all_single_unequal_split_complex(self):
3908            group, group_id, rank = self._init_global_test()
3909            self._test_all_to_all_single_unequal_split_helper(
3910                group, group_id, rank, dtype=torch.cfloat
3911            )
3912
3913        @skip_but_pass_in_sandcastle_if(
3914            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
3915        )
3916        @skip_if_no_gpu
3917        def test_all_to_all_single_unequal_split_cuda_complex(self):
3918            group, group_id, rank = self._init_global_test()
3919            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
3920            self._test_all_to_all_single_unequal_split_helper(
3921                group,
3922                group_id,
3923                rank,
3924                True,
3925                rank_to_GPU,
3926                dtype=torch.cfloat,
3927            )
3928
3929        @skip_but_pass_in_sandcastle_if(
3930            BACKEND != "mpi", "Only MPI supports all_to_all"
3931        )
3932        def test_all_to_all(self):
3933            group, group_id, rank = self._init_global_test()
3934            self._test_all_to_all_helper(group, group_id, rank)
3935
3936        @skip_but_pass_in_sandcastle_if(
3937            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all"
3938        )
3939        @skip_if_rocm_multiprocess
3940        def test_all_to_all_cuda(self):
3941            group, group_id, rank = self._init_global_test()
3942            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
3943            self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)
3944
3945        @skip_but_pass_in_sandcastle_if(
3946            BACKEND != "mpi", "Only MPI supports all_to_all"
3947        )
3948        def test_all_to_all_complex(self):
3949            group, group_id, rank = self._init_global_test()
3950            self._test_all_to_all_helper(group, group_id, rank, dtype=torch.cfloat)
3951
3952        @skip_but_pass_in_sandcastle_if(
3953            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all"
3954        )
3955        @skip_if_rocm_multiprocess
3956        def test_all_to_all_cuda_complex(self):
3957            group, group_id, rank = self._init_global_test()
3958            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
3959            self._test_all_to_all_helper(
3960                group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
3961            )
3962
3963        @skip_but_pass_in_sandcastle_if(
3964            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
3965        )
3966        @skip_if_small_worldsize
3967        def test_all_to_all_single_equal_split_group(self):
3968            group, group_id, rank = self._init_group_test()
3969            self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
3970
3971        @skip_but_pass_in_sandcastle_if(
3972            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
3973        )
3974        @skip_if_no_gpu
3975        @skip_if_small_worldsize
3976        def test_all_to_all_single_equal_split_group_cuda(self):
3977            group, group_id, rank = self._init_group_test()
3978            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
3979            self._test_all_to_all_single_equal_split_helper(
3980                group,
3981                group_id,
3982                rank,
3983                True,
3984                rank_to_GPU,
3985            )
3986
3987        @skip_but_pass_in_sandcastle_if(
3988            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
3989        )
3990        @skip_if_small_worldsize
3991        def test_all_to_all_single_unequal_split_group(self):
3992            group, group_id, rank = self._init_group_test()
3993            self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
3994
3995        @skip_but_pass_in_sandcastle_if(
3996            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all_single"
3997        )
3998        @skip_if_no_gpu
3999        @skip_if_small_worldsize
4000        def test_all_to_all_single_unequal_split_group_cuda(self):
4001            group, group_id, rank = self._init_group_test()
4002            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
4003            self._test_all_to_all_single_unequal_split_helper(
4004                group,
4005                group_id,
4006                rank,
4007                True,
4008                rank_to_GPU,
4009            )
4010
4011        @skip_but_pass_in_sandcastle_if(
4012            BACKEND != "mpi", "Only MPI supports all_to_all"
4013        )
4014        @skip_if_small_worldsize
4015        def test_all_to_all_group(self):
4016            group, group_id, rank = self._init_group_test()
4017            self._test_all_to_all_helper(group, group_id, rank)
4018
4019        @skip_but_pass_in_sandcastle_if(
4020            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all"
4021        )
4022        @skip_if_small_worldsize
4023        @skip_if_rocm_multiprocess
4024        def test_all_to_all_group_cuda(self):
4025            group, group_id, rank = self._init_group_test()
4026            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
4027            self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)
4028
4029        @skip_but_pass_in_sandcastle_if(
4030            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
4031        )
4032        def test_all_to_all_single_equal_split_full_group(self):
4033            group, group_id, rank = self._init_full_group_test()
4034            self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
4035
4036        @skip_but_pass_in_sandcastle_if(
4037            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all_single"
4038        )
4039        @skip_if_no_gpu
4040        def test_all_to_all_single_equal_split_full_group_cuda(self):
4041            group, group_id, rank = self._init_full_group_test()
4042            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
4043            self._test_all_to_all_single_equal_split_helper(
4044                group,
4045                group_id,
4046                rank,
4047                True,
4048                rank_to_GPU,
4049            )
4050
4051        @skip_but_pass_in_sandcastle_if(
4052            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
4053        )
4054        def test_all_to_all_single_unequal_split_full_group(self):
4055            group, group_id, rank = self._init_full_group_test()
4056            self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
4057
4058        @skip_but_pass_in_sandcastle_if(
4059            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all_single"
4060        )
4061        @skip_if_no_gpu
4062        def test_all_to_all_single_unequal_split_full_group_cuda(self):
4063            group, group_id, rank = self._init_full_group_test()
4064            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
4065            self._test_all_to_all_single_unequal_split_helper(
4066                group,
4067                group_id,
4068                rank,
4069                True,
4070                rank_to_GPU,
4071            )
4072
4073        @skip_but_pass_in_sandcastle_if(
4074            BACKEND != "mpi", "Only MPI supports all_to_all"
4075        )
4076        def test_all_to_all_full_group(self):
4077            group, group_id, rank = self._init_full_group_test()
4078            self._test_all_to_all_helper(group, group_id, rank)
4079
4080        @skip_but_pass_in_sandcastle_if(
4081            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all"
4082        )
4083        @skip_if_rocm_multiprocess
4084        def test_all_to_all_full_group_cuda(self):
4085            group, group_id, rank = self._init_full_group_test()
4086            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
4087            self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)
4088
4089        # BARRIER
4090        def _test_barrier_helper(
4091            self, group, group_id, rank, cuda=False, rank_to_GPU=None
4092        ):
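                # Strategy: each rank in turn broadcasts a target timestamp
                # (now + WAIT_TIME) and sleeps past it before entering the
                # barrier; every other rank checks that barrier() did not return
                # before that timestamp, i.e. that the barrier actually blocked.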
4093            WAIT_TIME = 0.3  # seconds
4094
4095            for dest in group:
4096                expected_time = torch.DoubleTensor(1).fill_(0.0)
4097                if cuda:
4098                    expected_time = expected_time.cuda(rank_to_GPU[rank][0])
4099                if dest == rank:
4100                    expected_time.fill_(time.time() + WAIT_TIME)
4101                    dist.broadcast(expected_time, dest, group_id)
4102                    time.sleep(WAIT_TIME + 0.1)  # sleep a little bit longer
4103                    dist.barrier(group_id)
4104                else:
4105                    dist.broadcast(expected_time, dest, group_id)
4106                    dist.barrier(group_id)
4107                    self.assertGreaterAlmostEqual(
4108                        float(time.time()),
4109                        float(expected_time[0]),
4110                        msg="destination rank: %d, my rank: %d" % (dest, rank)
4111                        + " (if you see this failure, please report in #14554)",
4112                    )
4113
4114            # Use higher timeout for the instance where the test runs
4115            # against a subgroup and uses a CUDA tensor for expected time.
4116            # The CUDA initialization for the participating processes can
4117            # take long enough for the barrier timeout to trigger on the
4118            # process that doesn't participate in the group.
4119            self._barrier(timeout=20)
4120
4121        @skip_if_no_gpu
4122        @skip_but_pass_in_sandcastle_if(
4123            BACKEND == "mpi", "MPI doesn't support GPU barrier"
4124        )
4125        @skip_but_pass_in_sandcastle_if(
4126            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
4127        )
4128        def test_barrier_cuda(self):
4129            group, group_id, rank = self._init_global_test()
4130            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
4131            self._test_barrier_helper(group, group_id, rank, True, rank_to_GPU)
4132
4133        @skip_if_small_worldsize
4134        @skip_if_no_gpu
4135        @skip_but_pass_in_sandcastle_if(
4136            BACKEND == "mpi", "MPI doesn't support GPU barrier"
4137        )
4138        def test_barrier_group_cuda(self):
4139            group, group_id, rank = self._init_group_test()
4140            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
4141            self._test_barrier_helper(group, group_id, rank, True, rank_to_GPU)
4142
4143        @skip_if_small_worldsize
4144        @skip_if_no_gpu
4145        @skip_but_pass_in_sandcastle_if(
4146            BACKEND == "mpi", "MPI doesn't support GPU barrier"
4147        )
4148        def test_barrier_full_group_cuda(self):
4149            group, group_id, rank = self._init_full_group_test()
4150            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
4151            self._test_barrier_helper(group, group_id, rank, True, rank_to_GPU)
4152
4153        @skip_but_pass_in_sandcastle_if(
4154            BACKEND in DistTestCases.skip_collective["cpu barrier"],
4155            f"{BACKEND} does not support CPU barrier",
4156        )
4157        def test_barrier(self):
4158            group, group_id, rank = self._init_global_test()
4159            self._test_barrier_helper(group, group_id, rank)
4160
4161        @skip_if_small_worldsize
4162        @skip_but_pass_in_sandcastle_if(
4163            BACKEND in DistTestCases.skip_collective["cpu barrier"],
4164            f"{BACKEND} does not support CPU barrier",
4165        )
4166        def test_barrier_group(self):
4167            group, group_id, rank = self._init_group_test()
4168            self._test_barrier_helper(group, group_id, rank)
4169
4170        @skip_but_pass_in_sandcastle_if(
4171            BACKEND in DistTestCases.skip_collective["cpu barrier"],
4172            f"{BACKEND} does not support CPU barrier",
4173        )
4174        def test_barrier_full_group(self):
4175            group, group_id, rank = self._init_full_group_test()
4176            self._test_barrier_helper(group, group_id, rank)
4177
4178        def _model_step(self, model):
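                # In lieu of an optimizer, apply each accumulated gradient directly
                # to its parameter (param += grad) and then drop the gradient, so the
                # baseline and DDP models are stepped identically.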
4179            for param in model.parameters():
4180                if param.grad is not None:
4181                    with torch.no_grad():
4182                        param += param.grad
4183                    param.grad = None
4184
4185        def _model_step_with_zero_grad(self, model):
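                # Same as _model_step, but zero the gradient in place instead of
                # setting it to None, so existing gradient storage (e.g. a DDP
                # bucket view) is reused across iterations.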
4186            for param in model.parameters():
4187                if param.grad is not None:
4188                    with torch.no_grad():
4189                        param += param.grad
4190                    param.grad.requires_grad_(False)
4191                    param.grad.zero_()
4192
4193        def _prepare_dummy_data(self, local_bs):
4194            # The global batch size for DDP must be divisible by WORLD_SIZE, hence world_size * local_bs
4195            world_size = int(os.environ["WORLD_SIZE"])
4196            global_bs = world_size * local_bs
4197            input_cpu = torch.randn(global_bs, 2)
4198            target = torch.randn(global_bs, 4)
4199            loss = nn.MSELoss()
4200            return global_bs, input_cpu, target, loss
4201
4202        # END TO END TEST FOR DISTRIBUTEDDATAPARALLEL
4203        def _test_DDP_helper(
4204            self, model, input_var, target, loss, scale_factor=1.0, memory_format=None
4205        ):
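                # Run one forward/backward pass. scale_factor lets callers rescale the
                # loss (see _test_DDP_niter) so that DDP's gradient averaging matches
                # the single-process baseline.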
4206            model.train()
4207            output = model(input_var)
4208            l = loss(output, target) * scale_factor
4209            l.backward()
4210            if memory_format is not None:
4211                self.assertTrue(output.is_contiguous(memory_format=memory_format))
4212
4213        def _assert_equal_param(self, param_gpu, param_DDP):
4214            self.assertEqual(len(param_gpu), len(param_DDP))
4215            for p_gpu, p_DDP in zip(param_gpu, param_DDP):
4216                self.assertEqual(p_gpu, p_DDP)
4217
4218        def _test_DDP_niter(
4219            self,
4220            model_base,
4221            model_DDP,
4222            input,
4223            target,
4224            loss,
4225            local_bs,
4226            rank,
4227            batch_size,
4228            test_save,
4229            offset=None,
4230            world_size=0,
4231            zero_grad=False,
4232            memory_format=None,
4233            n_iter=5,
4234        ):
4235            for idx in range(n_iter):
4236                # single cpu/gpu training
4237                self._test_DDP_helper(
4238                    model_base, input, target, loss, memory_format=memory_format
4239                )
4240
4241                if offset is None:
4242                    offset = rank * local_bs
4243
4244                # DDP training, DDP scatters subsets of input_cpu to nodes/GPUs
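                    # The per-rank loss is scaled by world_size * local_bs / batch_size so
                    # that, after DDP averages gradients across ranks, they match the
                    # baseline model's gradients computed on the full batch.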
4245                self._test_DDP_helper(
4246                    model_DDP,
4247                    input[offset : offset + local_bs],
4248                    target[offset : offset + local_bs],
4249                    loss,
4250                    world_size * local_bs / batch_size if world_size != 0 else 1,
4251                    memory_format=memory_format,
4252                )
4253
4254                # Update weights and run a second iteration to shake out errors
4255                if zero_grad:
4256                    self._model_step_with_zero_grad(model_base)
4257                    self._model_step_with_zero_grad(model_DDP)
4258                else:
4259                    self._model_step(model_base)
4260                    self._model_step(model_DDP)
4261                self._assert_equal_param(
4262                    list(model_base.parameters()), list(model_DDP.module.parameters())
4263                )
4264
4265                # Shuffle the input so that DDP input is different
4266                input = input[torch.randperm(batch_size)]
4267
4268                # save the model in the middle and reload
4269                if test_save and idx == 2 and INIT_METHOD.startswith("file://"):
4270                    with tempfile.NamedTemporaryFile() as tmp:
4271                        if sys.platform == "win32":
4272                            torch.save(model_DDP, tmp)
4273                            tmp.seek(0)
4274                            # weights_only=False as this is legacy code that saves the model
4275                            model_DDP = torch.load(tmp, weights_only=False)
4276                        else:
4277                            torch.save(model_DDP, tmp.name)
4278                            # weights_only=False as this is legacy code that saves the model
4279                            model_DDP = torch.load(tmp.name, weights_only=False)
4280
4281            with tempfile.TemporaryFile() as tmp_file:
4282                torch.save(model_DDP, tmp_file)
4283                tmp_file.seek(0)
4284                # weights_only=False as this is legacy code that saves the model
4285                saved_model = torch.load(tmp_file, weights_only=False)
4286            for k in model_DDP.state_dict():
4287                self.assertEqual(model_DDP.state_dict()[k], saved_model.state_dict()[k])
4288
4289        def _test_DistributedDataParallel(
4290            self,
4291            gpu_subset,
4292            rank,
4293            output_device=None,
4294            gradient_as_bucket_view=False,
4295            static_graph=False,
4296            set_static_graph_twice=False,
4297        ):
4298            # Run a simple end-to-end DDP model, using the result of the
4299            # single-node model as the baseline.
4300
4301            # cpu training setup
4302            model = DDP_NET
4303
4304            # single gpu training setup
4305            model_gpu = copy.deepcopy(model)
4306            model_gpu.cuda(gpu_subset[0])
4307
4308            # DDP training setup
4309            model_DDP = copy.deepcopy(model)
4310            model_DDP.cuda(gpu_subset[0])
4311            model_DDP = nn.parallel.DistributedDataParallel(
4312                model_DDP,
4313                device_ids=gpu_subset,
4314                gradient_as_bucket_view=gradient_as_bucket_view,
4315                static_graph=static_graph,
4316            )
4317
4318            if set_static_graph_twice:
4319                model_DDP._set_static_graph()
4320
4321            # test that the DDP model can be serialized and deserialized
4322            with tempfile.NamedTemporaryFile() as tmp:
4323                if sys.platform == "win32":
4324                    torch.save(model_DDP, tmp)
4325                    tmp.seek(0)
4326                    # weights_only=False as this is legacy code that saves the model
4327                    model_DDP = torch.load(tmp, weights_only=False)
4328                else:
4329                    torch.save(model_DDP, tmp.name)
4330                    # weights_only=False as this is legacy code that saves the model
4331                    model_DDP = torch.load(tmp.name, weights_only=False)
4332
4333            # dummy data initialization
4334            local_bs = len(gpu_subset)
4335            global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)
4336
4337            # check that the two models' parameters match over 5 iterations
4338            self._test_DDP_niter(
4339                model_gpu,
4340                model_DDP,
4341                input_cpu.cuda(gpu_subset[0]),
4342                target.cuda(gpu_subset[0]),
4343                loss,
4344                local_bs,
4345                rank,
4346                global_bs,
4347                True,
4348            )
4349            self._barrier()
4350
4351        def _test_DistributedDataParallelCPU(self, gradient_as_bucket_view=False):
4352            # Run a simple end-to-end DDP-CPU model, using the result of the
4353            # single-node model as the baseline.
4354            group, group_id, rank = self._init_global_test()
4355
4356            # cpu training setup
4357            model_base = DDP_NET
4358
4359            # DDP-CPU training setup
4360            model_DDP = copy.deepcopy(model_base)
4361            model_DDP = nn.parallel.DistributedDataParallel(
4362                model_DDP, gradient_as_bucket_view=gradient_as_bucket_view
4363            )
4364
4365            # dummy data initialization
4366            local_bs = 2
4367            global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)
4368
4369            # check that the two models' parameters match over 5 iterations
4370            self._test_DDP_niter(
4371                model_base,
4372                model_DDP,
4373                input_cpu,
4374                target,
4375                loss,
4376                local_bs,
4377                rank,
4378                global_bs,
4379                False,
4380                zero_grad=True,
4381            )
4382            self._barrier()
4383
4384            return model_DDP
4385
4386        @skip_but_pass_in_sandcastle_if(
4387            BACKEND == "nccl", "nccl does not support DDP on CPU models"
4388        )
4389        def test_DistributedDataParallelCPU(self):
4390            self._test_DistributedDataParallelCPU()
4391
4392        @skip_but_pass_in_sandcastle_if(
4393            BACKEND == "nccl", "nccl does not support DDP on CPU models"
4394        )
4395        def test_DistributedDataParallelCPU_grad_is_view(self):
4396            self._test_DistributedDataParallelCPU(gradient_as_bucket_view=True)
4397
4398        @skip_but_pass_in_sandcastle_if(
4399            BACKEND not in DistTestCases.backend_feature["ddp"],
4400            f"The {BACKEND} backend does not support DistributedDataParallel",
4401        )
4402        def test_DistributedDataParallel_requires_grad(self):
4403            # a module without gradients shouldn't be accepted
4404            self.assertRaises(
4405                RuntimeError, lambda: nn.parallel.DistributedDataParallel(nn.Module())
4406            )
4407            self._barrier()
4408
4409        @skip_but_pass_in_sandcastle_if(
4410            BACKEND not in DistTestCases.backend_feature["ddp"],
4411            f"The {BACKEND} backend does not support DistributedDataParallel",
4412        )
4413        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
4414        def test_ddp_zero_output_features(self):
4415            class ToyModel(nn.Module):
4416                def __init__(self) -> None:
4417                    super().__init__()
4418                    self.net1 = nn.Linear(10, 10)
4419                    self.relu = nn.ReLU()
4420                    self.net2 = nn.Linear(10, 0)
4421
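                # Constructing and wrapping the model should succeed even though
                # net2 has zero output features (i.e. zero-element parameters).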
4422            model = ToyModel().to(self.rank)
4423            ddp_model = nn.parallel.DistributedDataParallel(
4424                model, device_ids=[self.rank]
4425            )
4426
4427        @skip_but_pass_in_sandcastle_if(BACKEND == "nccl", "Gloo-only test")
4428        def test_ddp_create_graph(self):
4429            class Model(nn.Module):
4430                def __init__(self) -> None:
4431                    super().__init__()
4432                    self.p = nn.Parameter(torch.tensor(1.0))
4433
4434                def forward(self):
4435                    return self.p.pow(2)
4436
4437            model = Model()
4438            ddp_model = torch.nn.parallel.DistributedDataParallel(model)
4439            for _ in range(6):
4440                # Verify DDP doesn't throw when run with create_graph=True.
4441                # Although we do warn about potential issues, please see
4442                # https://github.com/pytorch/pytorch/issues/63929 for details.
4443                ddp_model().backward(create_graph=True)
4444                # grad tensors should require grad.
4445                self.assertTrue(
4446                    all(param.requires_grad for param in ddp_model.parameters())
4447                )
4448
4449        @skip_but_pass_in_sandcastle_if(
4450            BACKEND not in DistTestCases.backend_feature["ddp"],
4451            f"The {BACKEND} backend does not support DistributedDataParallel",
4452        )
4453        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
4454        def test_DistributedDataParallel_non_default_stream(self):
4455            stream = torch.cuda.Stream(self.rank)
4456            rank = self.rank
4457            with torch.cuda.stream(stream):
4458                net = torch.nn.parallel.DistributedDataParallel(
4459                    torch.nn.Linear(1, 1, bias=False).cuda(rank), device_ids=[rank]
4460                )
4461                for i in range(1000):
4462                    # Clear gradients manually
4463                    grad = net.module.weight.grad
4464                    if grad is not None:
4465                        grad.requires_grad_(False)
4466                        grad.zero_()
4467                    # Forward + BW
4468                    batch = torch.tensor([rank]).float().cuda(rank)
4469                    loss = net(batch).sum()
4470                    loss.backward()
4471                    # For each worker, the gradient on the weight should be worker_rank.
4472                    grad = net.module.weight.grad
4473                    avg = grad.clone()
4474                    # Since DDP already averaged the gradient, all-reducing it and dividing
4475                    # by world_size should give the same average on every rank. If not, a
4476                    # worker has not written back the averaged gradient before this call.
4477                    dist.all_reduce(avg)
4478                    world_size = int(os.environ["WORLD_SIZE"])
4479                    avg.div_(world_size)
4480                    expected_grad = sum(i for i in range(world_size)) / world_size
4481                    self.assertEqual(
4482                        avg[0, 0],
4483                        expected_grad,
4484                        msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}",
4485                    )
4486
4487        @skip_but_pass_in_sandcastle_if(
4488            BACKEND not in DistTestCases.backend_feature["cuda"],
4489            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
4490        )
4491        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
4492        def test_ddp_comm_hook_logging(self):
4493            hooks = [
4494                default.allreduce_hook,
4495                default.fp16_compress_hook,
4496                powerSGD.powerSGD_hook,
4497                powerSGD.batched_powerSGD_hook,
4498                quantization_hooks.quantization_pertensor_hook,
4499                quantization_hooks.quantization_perchannel_hook,
4500            ]
4501
4502            cpp_builtin_hooks = [
4503                dist.BuiltinCommHookType.ALLREDUCE,
4504                dist.BuiltinCommHookType.FP16_COMPRESS,
4505            ]
4506
4507            for hook in hooks:
4508                ddp_model = torch.nn.parallel.DistributedDataParallel(
4509                    torch.nn.Linear(1, 1, bias=False).cuda(self.rank),
4510                    device_ids=[self.rank],
4511                )
4512                ddp_logging_data = ddp_model._get_ddp_logging_data()
4513                # Hook not registered yet, so should be empty
4514                self.assertEqual(ddp_logging_data.get("comm_hook"), None)
4515                ddp_model.register_comm_hook(None, hook)
4516                ddp_logging_data = ddp_model._get_ddp_logging_data()
4517                self.assertEqual(ddp_logging_data.get("comm_hook"), hook.__qualname__)
4518
4519            for hook in cpp_builtin_hooks:
4520                ddp_model = torch.nn.parallel.DistributedDataParallel(
4521                    torch.nn.Linear(1, 1, bias=False).cuda(self.rank),
4522                    device_ids=[self.rank],
4523                )
4524                ddp_logging_data = ddp_model._get_ddp_logging_data()
4525                # Hook not registered yet, so should be empty
4526                self.assertEqual(ddp_logging_data.get("comm_hook"), None)
4527                ddp_model._register_builtin_comm_hook(hook)
4528                ddp_logging_data = ddp_model._get_ddp_logging_data()
4529                self.assertEqual(ddp_logging_data.get("comm_hook"), str(hook))
4530
4531            # No hook registered
4532            ddp_model = torch.nn.parallel.DistributedDataParallel(
4533                torch.nn.Linear(1, 1, bias=False).cuda(self.rank),
4534                device_ids=[self.rank],
4535            )
4536            ddp_logging_data = ddp_model._get_ddp_logging_data()
4537            # Hook not registered yet, so should be empty
4538            self.assertEqual(ddp_logging_data.get("comm_hook"), None)
4539            # Even after running forward/backward passes, comm_hook should still be unset
4540            for i in range(2):
4541                inp = torch.ones(1, 1, device=self.rank)
4542                loss = ddp_model(inp).sum()
4543                loss.backward()
4544
4545            ddp_logging_data = ddp_model._get_ddp_logging_data()
4546            # Note: DETAIL debug mode logs DDP logging data to stdout and
4547            # thus accesses std::map, which fills in a default value for the
4548            # type if it didn't exist.
4549            self.assertEqual(ddp_logging_data.get("comm_hook", ""), "")
4550
4551        def _test_ddp_hook_with_optimizer_parity(
4552            self,
4553            grad_as_bucket_view,
4554            static_graph,
4555            optim_cls,
4556            optimize_subset,
4557            *functional_optim_args,
4558            **functional_optim_kwargs,
4559        ):
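                # Parity test: a DDP model whose optimizer is fused with the allreduce
                # (via _register_fused_optim) should end up with the same parameters as
                # a DDP model that runs the same optimizer after backward.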
4560            rank = self.rank
4561            torch.cuda.set_device(rank)
4562            torch.manual_seed(rank)
4563            torch.cuda.manual_seed(rank)
4564            models_to_test = [
4565                (LargeNet(), torch.randn(1, 1000).cuda()),
4566            ]
4567            if HAS_TORCHVISION:
4568                models_to_test.append(
4569                    (torchvision.models.resnet50(), torch.randn(1, 3, 3, 1000).cuda())
4570                )
4571            for (model, inp) in models_to_test:
4572                # Enable determinism in cudnn operators
4573                with torch.backends.cudnn.flags(
4574                    enabled=True, deterministic=True, benchmark=False
4575                ):
4576                    # Create DDP model that runs optimizer in fused fashion.
4577                    ddp_model_with_optimizer_hook = (
4578                        torch.nn.parallel.DistributedDataParallel(
4579                            copy.deepcopy(model).cuda(),
4580                            device_ids=[self.rank],
4581                            gradient_as_bucket_view=grad_as_bucket_view,
4582                            static_graph=static_graph,
4583                        )
4584                    )
4585
4586                    # Create DDP model with no hook that does optimizer after
4587                    # backward.
4588                    ddp_model_with_no_hook = torch.nn.parallel.DistributedDataParallel(
4589                        copy.deepcopy(model).cuda(),
4590                        device_ids=[self.rank],
4591                        gradient_as_bucket_view=grad_as_bucket_view,
4592                        static_graph=static_graph,
4593                    )
4594                    hook_params = ddp_model_with_optimizer_hook.parameters()
4595                    no_hook_params = ddp_model_with_no_hook.parameters()
4596                    if optimize_subset:
4597                        hook_params = list(hook_params)
4598                        no_hook_params = list(no_hook_params)
4599                        self.assertGreater(len(hook_params), 0)
4600                        hook_params = [hook_params[0]]
4601                        no_hook_params = [no_hook_params[0]]
4602
4603                    # Register a fused optimizer that will run optimizer in step
4604                    # with allreduce.
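                        # (Roughly: the optimizer step is applied as part of the DDP
                        # communication hook, after the gradients have been allreduced.)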
4605
4606                    if optimize_subset:
4607                        # API where optim_params is specified.
4608                        ddp_model_with_optimizer_hook._register_fused_optim(
4609                            optim_cls,
4610                            *functional_optim_args,
4611                            optim_params=hook_params,
4612                            **functional_optim_kwargs,
4613                        )
4614                    else:
4615                        # API where optim_params is omitted
4616                        ddp_model_with_optimizer_hook._register_fused_optim(
4617                            optim_cls,
4618                            *functional_optim_args,
4619                            **functional_optim_kwargs,
4620                        )
4621
4622                    optimizer_no_hook = optim_cls(
4623                        no_hook_params,
4624                        *functional_optim_args,
4625                        **functional_optim_kwargs,
4626                    )
4627
4628                    # Verify parameters are equal initially.
4629                    for hook_param, allreduce_param in zip(
4630                        ddp_model_with_optimizer_hook.parameters(),
4631                        ddp_model_with_no_hook.parameters(),
4632                    ):
4633                        self.assertEqual(hook_param, allreduce_param)
4634
4635                    # Save old parameters to later verify optimizer modified them.
4636                    opt_hook_init_params = copy.deepcopy(
4637                        list(ddp_model_with_optimizer_hook.parameters())
4638                    )
4639
4640                    # Run optimizer with hook model.
4641                    for i in range(6):
4642                        ddp_model_with_optimizer_hook.zero_grad()
4643                        out = ddp_model_with_optimizer_hook(inp)
4644                        loss = out.sum()
4645                        loss.backward()
4646
4647                    dist.barrier()
4648
4649                    # Run regular model.
4650                    for i in range(6):
4651                        ddp_model_with_no_hook.zero_grad()
4652                        out = ddp_model_with_no_hook(inp)
4653                        loss = out.sum()
4654                        loss.backward()
4655                        optimizer_no_hook.step()
4656
4657                    dist.barrier()
4658
4659                    # Now verify parameters are equal.
4660                    for hook_param, allreduce_param in zip(
4661                        ddp_model_with_optimizer_hook.parameters(),
4662                        ddp_model_with_no_hook.parameters(),
4663                    ):
4664                        self.assertEqual(hook_param, allreduce_param)
4665
4666                    # Verify optimizer modified appropriate parameter set,
4667                    # otherwise they'd be trivially equal above.
4668                    if optimize_subset:
4669                        self.assertNotEqual(
4670                            opt_hook_init_params[0],
4671                            next(iter(ddp_model_with_optimizer_hook.parameters())),
4672                        )
4673                        # Untouched params should be equal
4674                        self.assertEqual(
4675                            opt_hook_init_params[1:],
4676                            list(ddp_model_with_optimizer_hook.parameters())[1:],
4677                        )
4678                    else:
4679                        self.assertNotEqual(
4680                            opt_hook_init_params,
4681                            list(ddp_model_with_optimizer_hook.parameters()),
4682                        )
4683                    dist.barrier()
4684
4685        """
4686        # Commenting out the following 3 tests as they cause Sandcastle jobs to fail
4687        # Failure signature:
4688        # AttributeError: type object 'TestDistBackendWithSpawn' has no attribute 'test_ddp_hook_with_optimizer_parity_adamw
4689
4690        from torch.testing._internal.common_utils import parametrize
4691
4692        @skip_but_pass_in_sandcastle_if(
4693            BACKEND == "nccl" or BACKEND == "ucc",
4694            "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259",
4695        )
4696        @skip_if_lt_x_gpu(2)
4697        @parametrize("grad_as_bucket_view", [True, False])
4698        @parametrize("static_graph", [True, False])
4699        @parametrize("optimize_subset", [True, False])
4700        def test_ddp_hook_with_optimizer_parity_adamw(
4701            self,
4702            grad_as_bucket_view,
4703            static_graph,
4704            optimize_subset,
4705        ):
4706            adamw_lr = 1e-2
4707            adamw_betas = (0.9, 0.99)
4708            adamw_eps = 1e-6
4709            self._test_ddp_hook_with_optimizer_parity(
4710                grad_as_bucket_view,
4711                static_graph,
4712                torch.optim.AdamW,
4713                optimize_subset,
4714                adamw_lr,
4715                betas=adamw_betas,
4716                eps=adamw_eps,
4717            )
4718
4719        @skip_but_pass_in_sandcastle_if(
4720            BACKEND == "nccl" or BACKEND == "ucc",
4721            "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259",
4722        )
4723        @skip_if_lt_x_gpu(2)
4724        @parametrize("optimize_subset", [True, False])
4725        def test_ddp_hook_with_optimizer_parity_adam(self, optimize_subset):
4726            adam_lr = 1e-2
4727            adam_betas = (0.9, 0.99)
4728            adam_eps = 1e-6
4729            self._test_ddp_hook_with_optimizer_parity(
4730                True,  # grad as bucket view
4731                False,  # static graph
4732                torch.optim.Adam,
4733                optimize_subset,
4734                adam_lr,
4735                betas=adam_betas,
4736                eps=adam_eps,
4737            )
4738
4739        @skip_but_pass_in_sandcastle_if(
4740            BACKEND == "nccl" or BACKEND == "ucc",
4741            "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259",
4742        )
4743        @skip_if_lt_x_gpu(2)
4744        @parametrize("optimize_subset", [True, False])
4745        def test_ddp_hook_with_optimizer_parity_sgd(self, optimize_subset):
4746            sgd_lr = 1e-2
4747            sgd_momentum = 0.9
4748            sgd_weight_decay = 0.01
4749            # Not testing grad_as_bucket_view and static_graph as they are
4750            # tested in AdamW test above.
4751            self._test_ddp_hook_with_optimizer_parity(
4752                True,  # grad as bucket view
4753                False,  # static_graph
4754                torch.optim.SGD,
4755                optimize_subset,
4756                sgd_lr,
4757                momentum=sgd_momentum,
4758                weight_decay=sgd_weight_decay,
4759            )
4760        """
4761
4762        @skip_if_lt_x_gpu(2)
4763        def test_get_data_parallel_params(self):
4764            torch.cuda.set_device(self.rank)
4765            model = TwoLinLayerNet().cuda()
4766            # Parameters to ignore are in the format {module_name}.{param_name}
4767            params_to_ignore = ["a.weight"]
4768            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
4769                model, params_to_ignore
4770            )
4771            ddp_model = torch.nn.parallel.DistributedDataParallel(
4772                model, device_ids=[self.rank]
4773            )
4774            dp_params = torch.nn.parallel.DistributedDataParallel._get_data_parallel_params(
4775                model, named_params=True
4776            )
4777            for name, _ in dp_params:
4778                self.assertNotEqual(f"module.{params_to_ignore[0]}", name)
4779
4780            # test named_params=False; just check that it returns the expected
4781            # number of parameters.
4782            num_ddp_params = len(list(model.parameters())) - 1
4783            count = 0
4784            dp_params = torch.nn.parallel.DistributedDataParallel._get_data_parallel_params(model, named_params=False)
4785            for _ in dp_params:
4786                count += 1
4787            self.assertEqual(count, num_ddp_params)
4788
4789        def _test_ddp_apply_optim_in_backward(
4790            self,
4791            optim_cls,
4792            optim_kwargs,
4793            init_before,
4794            gradient_as_bucket_view=True,
4795        ):
4796            # Need to seed to ensure inputs are unique across ranks. Otherwise,
4797            # allreduce won't have any effect.
4798            torch.manual_seed(self.rank)
4799            torch.cuda.manual_seed(self.rank)
4800            torch.cuda.set_device(self.rank)
4801
4802            # Test a simple linear as well as a ResNet model.
4803            models_to_test = [
4804                nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)).cuda()
4805            ]
4806            if HAS_TORCHVISION:
4807                models_to_test.append(torchvision.models.resnet50().cuda())
4808
4809            for j, model in enumerate(models_to_test):
4810                model_optim_in_bwd = copy.deepcopy(model)
4811                model = nn.parallel.DistributedDataParallel(
4812                    model,
4813                    device_ids=[self.rank],
4814                    gradient_as_bucket_view=gradient_as_bucket_view,
4815                )
4816                optim = optim_cls(model.parameters(), **optim_kwargs)
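                    # _apply_optimizer_in_backward runs a per-parameter optimizer step as
                    # each gradient becomes ready during backward; the gradient is then
                    # released, which is why p2.grad is expected to be None below.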
4817                if init_before:
4818                    _apply_optimizer_in_backward(
4819                        optimizer_class=optim_cls,
4820                        params=model_optim_in_bwd.parameters(),
4821                        optimizer_kwargs=optim_kwargs,
4822                    )
4823                model_optim_in_bwd = nn.parallel.DistributedDataParallel(
4824                    model_optim_in_bwd,
4825                    device_ids=[self.rank],
4826                    gradient_as_bucket_view=gradient_as_bucket_view,
4827                )
4828                if not init_before:
4829                    _apply_optimizer_in_backward(
4830                        optimizer_class=optim_cls,
4831                        params=model_optim_in_bwd.parameters(),
4832                        optimizer_kwargs=optim_kwargs,
4833                    )
4834
4835                for p1, p2 in zip(model.parameters(), model_optim_in_bwd.parameters()):
4836                    self.assertEqual(p1, p2, "Parameters not initially equal!")
4837                # Enable determinism in cudnn operators
4838                with torch.backends.cudnn.flags(
4839                    enabled=True, deterministic=True, benchmark=False
4840                ):
4841                    for i in range(8):
4842                        inp = (
4843                            torch.randn(1, 3, 1000, 1000, device="cuda")
4844                            if j == 1
4845                            else torch.randn(10, 3, device="cuda")
4846                        )
4847                        model(inp).sum().backward()
4848                        optim.step()
4849                        model_optim_in_bwd(
4850                            inp
4851                        ).sum().backward()  # runs optimizer as well
4852                        for p1, p2 in zip(
4853                            model.parameters(), model_optim_in_bwd.parameters()
4854                        ):
4855                            self.assertEqual(
4856                                p1, p2, f"Params not equal at iteration {i}"
4857                            )
4858                            self.assertTrue(
4859                                p2.grad is None,
4860                                f"Optim in backward grad is not None at {i}",
4861                            )
4862
4863                        # Use set_to_none for the regular optimizer to match the
4864                        # optimizer-in-backward case.
4865                        optim.zero_grad(set_to_none=True)
4866
4867        @skip_if_lt_x_gpu(2)
4868        def test_ddp_apply_optim_in_backward(self):
4869            for optim_cls, init_before in itertools.product(
4870                [torch.optim.SGD, torch.optim.Adam], [True, False]
4871            ):
4872                with self.subTest(optim_cls=optim_cls):
4873                    self._test_ddp_apply_optim_in_backward(
4874                        optim_cls=optim_cls,
4875                        optim_kwargs={"lr": 0.03},
4876                        init_before=init_before,
4877                    )
4878
4879        @skip_if_lt_x_gpu(2)
4880        def test_ddp_apply_optim_in_backward_grad_as_bucket_view_false(self):
4881            for init_before in [True, False]:
4882                self._test_ddp_apply_optim_in_backward(
4883                    optim_cls=torch.optim.SGD,
4884                    optim_kwargs={"lr": 0.03},
4885                    init_before=init_before,
4886                    gradient_as_bucket_view=False,
4887                )
4888
4889        @skip_if_lt_x_gpu(2)
4890        def test_ddp_apply_optim_in_backward_ignored_params(self):
4891            torch.cuda.set_device(self.rank)
4892            for init_before in [True, False]:
4893                with self.subTest(init_before=init_before):
4894                    torch.manual_seed(self.rank)
4895                    torch.cuda.manual_seed(self.rank)
4896                    model = TwoLinLayerNet()
4897                    # Parameters to ignore are in the format {module_name}.{param_name}
4898                    params_to_ignore = ["a.weight"]
4899                    torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
4900                        model, params_to_ignore
4901                    )
4902                    if init_before:
4903                        _apply_optimizer_in_backward(
4904                            optimizer_class=torch.optim.SGD,
4905                            params=model.parameters(),
4906                            optimizer_kwargs={"lr": 0.03},
4907                        )
4908                    net = torch.nn.parallel.DistributedDataParallel(
4909                        model.cuda(self.rank),
4910                        device_ids=[self.rank],
4911                    )
4912                    if not init_before:
4913                        _apply_optimizer_in_backward(
4914                            optimizer_class=torch.optim.SGD,
4915                            params=model.parameters(),
4916                            optimizer_kwargs={"lr": 0.03},
4917                        )
4918                    inp = torch.randn(1, 10)
4919                    a, b = net(inp)
4920                    (a.transpose(0, 1) @ b).sum().backward()
4921                    # a.weight did not go through allreduce, so optimizer acted on local
4922                    # gradient, which should be different across ranks. Remaining params
4923                    # should be equal.
4924                    models = [None for _ in range(dist.get_world_size())]
4925                    dist.all_gather_object(models, model)
4926                    rank0_model, remainder = models[0], models[1:]
4927                    for m in remainder:
4928                        self.assertNotEqual(rank0_model.a.weight, m.a.weight)
4929                        self.assertEqual(
4930                            list(rank0_model.b.parameters()), list(m.b.parameters())
4931                        )
4932                        self.assertEqual(rank0_model.a.bias, m.a.bias)
4933
4934        def _get_fp16_config(self) -> _MixedPrecision:
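                # Cast parameters, gradient reduction, and buffers all to fp16 for
                # DDP's native mixed-precision path.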
4935            return _MixedPrecision(
4936                param_dtype=torch.float16,
4937                reduce_dtype=torch.float16,
4938                buffer_dtype=torch.float16,
4939            )
4940
4941        @skip_if_lt_x_gpu(2)
4942        def test_ddp_native_mixed_precision_ignored_params(self):
4943            rank = self.rank
4944            torch.manual_seed(rank)
4945            torch.cuda.manual_seed(rank)
4946            torch.cuda.set_device(rank)
4947            model = TwoLinLayerNet()
4948            model.register_buffer("buffer", torch.ones(5))
4949            # Parameters to ignore are in the format {module_name}.{param_name}
4950            to_ignore = ["a.weight", "buffer"]
4951            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
4952                model, to_ignore,
4953            )
4954            mp_config = self._get_fp16_config()
4955            net = torch.nn.parallel.DistributedDataParallel(
4956                model.to(rank),
4957                device_ids=[rank],
4958                mixed_precision=mp_config,
4959                gradient_as_bucket_view=True,
4960            )
4961            to_ignore = [f"module.{name}" for name in to_ignore]
4962            expected_ignored = len(to_ignore)
4963            n_ignored = 0
4964            # ignored params should not have _mp_param or _fp_param fields.
4965            for (n, p) in itertools.chain(net.named_parameters(), net.named_buffers()):
4966                if n in to_ignore:
4967                    n_ignored += 1
4968                    self.assertFalse(hasattr(p, '_mp_param'))
4969                    self.assertFalse(hasattr(p, '_fp_param'))
4970                else:
4971                    self.assertEqual(mp_config.param_dtype, p._mp_param.dtype)
4972                    self.assertEqual(torch.float32, p._fp_param.dtype)
4973
4974            self.assertEqual(expected_ignored, n_ignored)
4975
4976        def _test_ddp_native_mixed_precision(
4977            self, gradient_as_bucket_view, set_grad_to_none
4978        ):
4979            rank = self.rank
4980            torch.manual_seed(rank)
4981            torch.cuda.manual_seed(rank)
4982            torch.cuda.set_device(rank)
4983            inp = torch.randn(10, 1)
4984            mp_config = self._get_fp16_config()
4985
4986            class MyModel(torch.nn.Module):
4987                def __init__(self) -> None:
4988                    super().__init__()
4989                    self.m = torch.nn.Linear(1, 5)
4990                    self.register_buffer('buffer', torch.randn(1, 2))
4991                    self.p = torch.nn.Parameter(
4992                        torch.randn(10, 5), requires_grad=False
4993                    )
4994
4995                def forward(self_, x):  # noqa: B902
4996                    params = self_.m.parameters()
4997                    for p in params:
4998                        self.assertEqual(mp_config.param_dtype, p.dtype)
4999
5000                    self.assertEqual(self_.buffer.dtype, mp_config.buffer_dtype)
5001
5002                    self.assertEqual(mp_config.param_dtype, x.dtype)
5003                    return self_.m(x) + self_.p
5004
5005            m = MyModel()
5006
5007            net = torch.nn.parallel.DistributedDataParallel(
5008                m.to(rank),
5009                device_ids=[rank],
5010                mixed_precision=mp_config,
5011                gradient_as_bucket_view=gradient_as_bucket_view,
5012            )
5013            # Buffers are cast in the constructor.
5014            self.assertEqual(net.module.buffer.dtype, mp_config.buffer_dtype)
5015            # Each param should have an mp_param in the lower precision, and
5016            # an fp_param in the higher precision.
5017            for p in net.parameters():
5018                self.assertEqual(mp_config.param_dtype, p._mp_param.dtype)
5019                self.assertEqual(torch.float32, p._fp_param.dtype)
5020
5021            for i in range(6):
5022                loss = net(inp).sum()
5023                loss.backward()
5024                # Verify gradient synchronization and that params and grads are fp32.
5025                for n, param in net.named_parameters():
5026                    self.assertEqual(param.dtype, torch.float32)
5027                    if param.grad is None:
5028                        assert n == 'module.p'  # Only param that doesn't require grad
5029                    else:
5030                        self.assertEqual(param.grad.dtype, torch.float32)
5031                        tensor_list = [
5032                            torch.zeros_like(param.grad)
5033                            for _ in range(dist.get_world_size(net.process_group))
5034                        ]
5035                        dist.all_gather(tensor_list, param.grad)
5036                        g, rest = tensor_list[0], tensor_list[1:]
5037                        self.assertEqual(g.dtype, torch.float32)
5038                        for g_ in rest:
5039                            self.assertEqual(g_.dtype, torch.float32)
5040                            self.assertEqual(g, g_)
5041                net.zero_grad(set_to_none=set_grad_to_none)
5042
5043        @skip_if_lt_x_gpu(2)
5044        def test_ddp_native_mixed_precision_no_grad_as_bucket_view_no_set_grad_none(self):
5045            self._test_ddp_native_mixed_precision(
5046                gradient_as_bucket_view=False,
5047                set_grad_to_none=False,
5048            )
5049
5050        @skip_if_lt_x_gpu(2)
5051        def test_ddp_native_mixed_precision_grad_as_bucket_view_no_set_grad_none(self):
5052            self._test_ddp_native_mixed_precision(
5053                gradient_as_bucket_view=True,
5054                set_grad_to_none=False,
5055            )
5056
5057        @skip_if_lt_x_gpu(2)
5058        def test_ddp_native_mixed_precision_grad_as_bucket_view_set_grad_to_none(self):
5059            self._test_ddp_native_mixed_precision(
5060                gradient_as_bucket_view=True, set_grad_to_none=True
5061            )
5062
5063        @skip_if_lt_x_gpu(2)
5064        def test_ddp_native_mixed_precision_no_grad_as_bucket_view_set_grad_to_none(self):
5065            self._test_ddp_native_mixed_precision(
5066                gradient_as_bucket_view=False, set_grad_to_none=True
5067            )
5068
5069        def _test_ddp_hook_parity(self, state, hook, num_validated_iters=100):
5070            rank = self.rank
5071            m = torch.nn.Linear(1, 5)
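                # `state` may be a comm-hook state object that wraps a process group
                # (e.g. a PowerSGD state), or a bare process group / None; extract the
                # process group either way.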
5072            try:
5073                process_group = state.process_group
5074            except AttributeError:
5075                process_group = state
5076
5077            net_with_hook = torch.nn.parallel.DistributedDataParallel(
5078                copy.deepcopy(m).to(rank),
5079                device_ids=[rank],
5080                process_group=process_group,
5081            )
5082            net_with_hook.register_comm_hook(state=state, hook=hook)
5083            net_without_hook = torch.nn.parallel.DistributedDataParallel(
5084                copy.deepcopy(m).to(rank),
5085                device_ids=[rank],
5086                process_group=process_group,
5087            )
5088            for i in range(100):
5089                # Clear gradients manually.
5090                for g in [
5091                    net_without_hook.module.weight.grad,
5092                    net_with_hook.module.weight.grad,
5093                ]:
5094                    if g is not None:
5095                        g.requires_grad_(False)
5096                        g.zero_()
5097                # Forward + BW
5098                batch = torch.tensor([rank]).float().cuda(rank)
5099                loss = net_without_hook(batch).sum()
5100                loss.backward()
5101                # For each worker, the gradient on the weight should be worker_rank.
5102                grad = net_without_hook.module.weight.grad
5103                avg = grad.clone()
5104                expected_grad = (
5105                    sum(i for i in range(dist.get_world_size())) / dist.get_world_size()
5106                )
5107                loss_hook = net_with_hook(batch).sum()
5108                loss_hook.backward()
5109                grad_hook = net_with_hook.module.weight.grad
5110                avg_hook = grad_hook.clone()
5111
5112                if i < num_validated_iters:
5113                    # Verify hook grad with expected.
5114                    self.assertEqual(
5115                        avg_hook[0, 0].item(),
5116                        expected_grad,
5117                        msg=f"Expected hook grad of {expected_grad} but got {avg_hook[0, 0]}",
5118                    )
5119                    # Verify hook grad with vanilla allreduce
5120                    self.assertEqual(
5121                        avg_hook[0, 0],
5122                        avg[0, 0],
5123                        msg=f"Expected hook grad to be close to allreduce {avg[0, 0]}, but got {avg_hook[0, 0]}",
5124                    )
5125
5126        @skip_but_pass_in_sandcastle_if(
5127            BACKEND not in DistTestCases.backend_feature["cuda"],
5128            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
5129        )
5130        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
5131        def test_ddp_hook_parity_allreduce(self):
5132            self._test_ddp_hook_parity(state=None, hook=default.allreduce_hook)
5133
5134        @skip_but_pass_in_sandcastle_if(
5135            BACKEND not in DistTestCases.backend_feature["cuda"],
5136            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
5137        )
5138        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
5139        def test_ddp_hook_parity_allreduce_process_group(self):
5140            # process_group is passed in to both DDP and comm. hook
5141            world_size = dist.get_world_size()
5142            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
5143            gpus = [rank_to_GPU[int(r)][0] for r in range(world_size)]
5144            process_group = torch.distributed.new_group(gpus)
5145            self._test_ddp_hook_parity(state=process_group, hook=default.allreduce_hook)
5146
5147        @skip_but_pass_in_sandcastle_if(
5148            BACKEND not in DistTestCases.backend_feature["cuda"],
5149            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
5150        )
5151        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
5152        def test_ddp_hook_parity_powerSGD(self):
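                # Exercise PowerSGD both with and without warm start (reusing the
                # low-rank factors from the previous iteration as the starting point).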
5153            for warm_start in [True, False]:
5154                powersgd_state = powerSGD.PowerSGDState(
5155                    process_group=None,
5156                    matrix_approximation_rank=1,
5157                    start_powerSGD_iter=2,
5158                    warm_start=warm_start,
5159                )
5160                self._test_ddp_hook_parity(
5161                    state=powersgd_state, hook=powerSGD.powerSGD_hook
5162                )
5163
5164        @skip_but_pass_in_sandcastle_if(
5165            BACKEND not in DistTestCases.backend_feature["cuda"],
5166            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
5167        )
5168        @skip_but_pass_in_sandcastle_if(
5169            NO_MULTIPROCESSING_SPAWN,
5170            "Disabled for environments that don't support "
5171            "multiprocessing with spawn start method",
5172        )
5173        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
5174        def test_ddp_hook_parity_post_localSGD(self):
5175            # Although local SGD starts at iteration 10, since we still use the global process group to run it,
5176            # post-localSGD actually still allreduces gradients globally for the remaining iterations.
5177            state = post_localSGD.PostLocalSGDState(
5178                process_group=None, subgroup=dist.group.WORLD, start_localSGD_iter=10
5179            )
5180            self._test_ddp_hook_parity(
5181                state=state, hook=post_localSGD.post_localSGD_hook
5182            )
5183            # Only validate the warmup iterations before local SGD is applied,
5184            # because when `post_local_gradient_allreduce` is disabled, the gradients will not be synchronized at all.
5185            # Note that in practice a model averager has to be applied to perform model averaging
5186            # (see the illustrative sketch after this test), so local gradient averaging is not necessary.
5187            start_localSGD_iter = 10
5188            state = post_localSGD.PostLocalSGDState(
5189                process_group=None,
5190                subgroup=dist.group.WORLD,
5191                start_localSGD_iter=start_localSGD_iter,
5192                post_local_gradient_allreduce=False,
5193            )
5194            self._test_ddp_hook_parity(
5195                state=state,
5196                hook=post_localSGD.post_localSGD_hook,
5197                num_validated_iters=start_localSGD_iter,
5198            )
5199
5200            # When `subgroup` is None, it defaults to the intra-node subgroup on each node.
5201            # For this single-node test environment, the intra-node process group is equivalent to
5202            # the global process group.
5203            if self.world_size == dist.get_world_size():
5204                state = post_localSGD.PostLocalSGDState(
5205                    process_group=None, subgroup=None, start_localSGD_iter=10
5206                )
5207                self._test_ddp_hook_parity(
5208                    state=state, hook=post_localSGD.post_localSGD_hook
5209                )
5210
5211            # Since local SGD starts later than the total number of 100 iterations,
5212            # no local SGD is actually executed, and we don't even need to provide a subgroup in this case.
5213            state = post_localSGD.PostLocalSGDState(
5214                process_group=None, subgroup=None, start_localSGD_iter=1000
5215            )
5216            self._test_ddp_hook_parity(
5217                state=state, hook=post_localSGD.post_localSGD_hook
5218            )
5219
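        # Illustrative sketch (not exercised by the tests above): how the pieces
        # validated here are typically combined in user code. The name
        # `_example_post_localSGD_step` is hypothetical and exists only to document
        # the intended wiring of the comm hook with a periodic model averager.
        def _example_post_localSGD_step(
            self, ddp_model, optimizer, averager, loss_fn, input, target
        ):
            # The hook is registered once, right after constructing DDP:
            #   state = post_localSGD.PostLocalSGDState(
            #       process_group=None, subgroup=None, start_localSGD_iter=100
            #   )
            #   ddp_model.register_comm_hook(state, post_localSGD.post_localSGD_hook)
            # Then each training step looks like a regular step...
            optimizer.zero_grad()
            loss_fn(ddp_model(input), target).backward()
            optimizer.step()
            # ...followed by periodic model averaging, since the hook stops
            # allreducing gradients globally once local SGD starts.
            averager.average_parameters(ddp_model.parameters())
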
5220        def _prepare_single_device_module(
5221            self,
5222            rank,
5223            process_group,
5224            devices,
5225            device_ids,
5226            global_batch_size,
5227            gradient_as_bucket_view=False,
5228        ):
5229            model = Net()
5230            device = devices[0] if devices else torch.device("cuda:%d" % rank)
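            # Use a tiny bucket cap so gradients are split across many small
            # buckets, exercising the reducer's bucketing logic.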
5231            ddp_model = DistributedDataParallel(
5232                copy.deepcopy(model).to(device),
5233                device_ids=device_ids,
5234                process_group=process_group,
5235                bucket_cap_mb=0.001,
5236                gradient_as_bucket_view=gradient_as_bucket_view,
5237            )
5238
5239            model.to(device)
5240
5241            input = torch.randn(global_batch_size, 2).to(device)
5242            target = torch.randn(global_batch_size, 4).to(device)
5243
5244            return model, ddp_model, input, target
5245
5246        def _prepare_cpu_module(
5247            self,
5248            process_group,
5249            global_batch_size,
5250            gradient_as_bucket_view=False,
5251        ):
5252            model = Net()
5253            ddp_model = DistributedDataParallel(
5254                copy.deepcopy(model),
5255                process_group=process_group,
5256                bucket_cap_mb=0.001,
5257                gradient_as_bucket_view=gradient_as_bucket_view,
5258            )
5259            input = torch.randn(global_batch_size, 2)
5260            target = torch.randn(global_batch_size, 4)
5261            return model, ddp_model, input, target
5262
5263        def _test_accumulate_gradients_no_sync(
5264            self, num_iters=2, ddp_comm_hook=None, gradient_as_bucket_view=False
5265        ):
5266            """
5267            This is the recommended way to implement gradient accumulation (see the
5268            illustrative ``no_sync`` sketch after this helper). If a ``ddp_comm_hook``
5269            is specified, it is also registered on the ``ddp_model``. The hook fed into
5270            this function should not change the resulting gradients.
5271            """
5272            group, group_id, rank = self._init_global_test()
5273            world_size = get_world_size()
5274
5275            # FIXME: Add testing for gloo/CUDA
5276            if BACKEND == "mpi" or BACKEND == "gloo":
5277                global_batch_size = world_size
5278                local_batch_size = 1
5279                model, ddp_model, input, target = self._prepare_cpu_module(
5280                    group_id, global_batch_size, gradient_as_bucket_view
5281                )
5282
5283            if BACKEND == "nccl":
5284                rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
5285                int_devices = rank_to_GPU[rank][:1]
5286                devices = [torch.device("cuda:" + str(i)) for i in int_devices]
5287                global_batch_size = world_size
5288                local_batch_size = len(devices)
5289                model, ddp_model, input, target = self._prepare_single_device_module(
5290                    rank,
5291                    group_id,
5292                    devices,
5293                    devices,
5294                    global_batch_size,
5295                    gradient_as_bucket_view,
5296                )
5297
5298            if ddp_comm_hook is not None:
5299                ddp_model.register_comm_hook(group_id, ddp_comm_hook)
5300
5301            def step_model(model, input, target):
5302                model.train()
5303                output = model(input)
5304                loss = F.mse_loss(output, target.to(output.device))
5305                loss.backward()
5306
5307            # Ensure gradient accumulation works under no_grad => no grads are accumulated.
5308            with torch.no_grad():
5309                with ddp_model.no_sync():
5310                    ddp_model.train()
5311                    ddp_model(input)
5312
5313            # check two model parameters over num_iters iterations
5314            for iteration in range(num_iters):
5315                step_model(model, input, target)
5316
5317                ddp_input = input[
5318                    rank * local_batch_size : (rank + 1) * local_batch_size
5319                ]
5320                ddp_target = target[
5321                    rank * local_batch_size : (rank + 1) * local_batch_size
5322                ]
5323
5324                if iteration % 2 == 0:
5325                    # accumulate grads locally
5326                    with ddp_model.no_sync():
5327                        step_model(ddp_model, ddp_input, ddp_target)
5328                else:
5329                    # sync grads
5330                    step_model(ddp_model, ddp_input, ddp_target)
5331
5332                for i, j in zip(model.parameters(), ddp_model.parameters()):
5333                    if not i.requires_grad:
5334                        continue
5335                    if iteration % 2 == 0:
5336                        self.assertNotEqual(i.grad, j.grad)
5337                    else:
5338                        self.assertEqual(i.grad, j.grad)
5339
5340                # Shuffle the input so that DDP input is different
5341                torch.manual_seed(1337 + iteration)
5342                input = input[torch.randperm(global_batch_size)]
5343
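        # Illustrative sketch referenced in the docstring above: the canonical
        # `no_sync()` gradient-accumulation pattern that the helper validates.
        # The name `_example_no_sync_accumulation` is hypothetical and is not
        # used by any test in this file.
        def _example_no_sync_accumulation(
            self, ddp_model, optimizer, batches, loss_fn, accumulate_every=2
        ):
            for step, (inp, tgt) in enumerate(batches):
                # Skip gradient synchronization on all but every
                # `accumulate_every`-th micro-batch.
                sync_now = (step + 1) % accumulate_every == 0
                ctx = nullcontext() if sync_now else ddp_model.no_sync()
                with ctx:
                    loss_fn(ddp_model(inp), tgt).backward()
                if sync_now:
                    optimizer.step()
                    optimizer.zero_grad()
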
5344        @skip_but_pass_in_sandcastle_if(
5345            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
5346            "get_future is only supported on mpi, nccl and gloo",
5347        )
5348        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
5349        def test_accumulate_gradients_no_sync(self):
5350            """
5351            Runs _test_accumulate_gradients_no_sync using default inputs
5352            """
5353            self._test_accumulate_gradients_no_sync()
5354
5355        @skip_but_pass_in_sandcastle_if(
5356            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
5357            "get_future is only supported on mpi, nccl and gloo",
5358        )
5359        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
5360        def test_accumulate_gradients_no_sync_grad_is_view(self):
5361            """
5362            Runs _test_accumulate_gradients_no_sync with gradient_as_bucket_view=True
5363            """
5364            self._test_accumulate_gradients_no_sync(gradient_as_bucket_view=True)
5365
5366        @skip_but_pass_in_sandcastle_if(
5367            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
5368            "get_future is only supported on mpi, nccl and gloo",
5369        )
5370        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
5371        def test_accumulate_gradients_no_sync_allreduce_hook(self):
5372            """
5373            Runs multiple iterations of _test_accumulate_gradients_no_sync
5374            using an allreduce hook and validates that the future result is properly
5375            passed as gradients to the reducer.
5376            """
5377
5378            world_size = get_world_size()
5379
5380            def allreduce_hook(
5381                group_id: object, bucket: dist.GradBucket
5382            ) -> torch.futures.Future[torch.Tensor]:
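                # Pre-divide by world_size so the summed allreduce yields the average.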
5383                tensors = [bucket.buffer() / world_size]
5384                return (
5385                    group_id.allreduce(tensors)
5386                    .get_future()
5387                    .then(lambda fut: fut.value()[0])
5388                )
5389
5390            self._test_accumulate_gradients_no_sync(
5391                num_iters=4, ddp_comm_hook=allreduce_hook
5392            )
5393
5394        @skip_but_pass_in_sandcastle_if(
5395            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
5396            "get_future is only supported on mpi, nccl and gloo",
5397        )
5398        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
5399        def test_accumulate_gradients_no_sync_allreduce_with_then_hook(self):
5400            """
5401            Runs multiple iterations of _test_accumulate_gradients_no_sync using an allreduce
5402            hook that also chains ``then`` callbacks. The first callback multiplies the result
5403            by 2, and the second divides the result by 2 * world_size. It validates that the
5404            final result is properly passed as gradients to the reducer.
5405            """
5406
5407            world_size = get_world_size()
5408
5409            def allreduce_with_then_hook(
5410                group_id: object, bucket: dist.GradBucket
5411            ) -> torch.futures.Future[torch.Tensor]:
5412                fut = group_id.allreduce([bucket.buffer()]).get_future()
5413
5414                def mult(fut):
5415                    # Multiply the result by 2.
5416                    return 2 * fut.wait()[0]
5417
5418                def div(fut):
5419                    # Divide the result by 2 * world_size.
5420                    return fut.wait() / (2 * world_size)
5421
5422                return fut.then(mult).then(div)
5423
5424            self._test_accumulate_gradients_no_sync(
5425                num_iters=4, ddp_comm_hook=allreduce_with_then_hook
5426            )
5427
5428        @skip_but_pass_in_sandcastle_if(
5429            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
5430            "get_future is only supported on mpi, nccl and gloo",
5431        )
5432        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
5433        def test_get_future(self):
5434            def mult(fut):
5435                return [t * 3 for t in fut.wait()]
5436
5437            def add(fut):
5438                return [t + 1 for t in fut.wait()]
5439
5440            group, group_id, rank = self._init_global_test()
5441            input = _build_tensor(3, 2)
5442            if BACKEND == "nccl":
5443                rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
5444                device_id = rank_to_GPU[rank][0]
5445                input = input.to(device_id)
5446            fut = group_id.allreduce([input]).get_future()
5447            res = fut.then(mult).then(add).wait()
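            # allreduce sums the value 2 across all ranks, then the chained
            # callbacks multiply by 3 and add 1.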
5448            expected = _build_tensor(3, 2 * len(group) * 3 + 1)
5449
5450            self.assertEqual(res[0], expected)
5451
5452        @skip_but_pass_in_sandcastle_if(
5453            BACKEND not in DistTestCases.backend_feature["ddp"],
5454            f"The {BACKEND} backend does not support DistributedDataParallel",
5455        )
5456        @skip_if_no_gpu
5457        def test_DistributedDataParallel(self):
5458            group, group_id, rank = self._init_global_test()
5459            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
5460            gpus = list(rank_to_GPU[rank])
5461
5462            for use_bucket_view, static_graph in itertools.product(
5463                (False, True), (False, True)
5464            ):
5465                self._test_DistributedDataParallel(
5466                    gpu_subset=gpus,
5467                    rank=rank,
5468                    gradient_as_bucket_view=use_bucket_view,
5469                    static_graph=static_graph,
5470                )
5471
5472                # test set static graph twice
5473                self._test_DistributedDataParallel(
5474                    gpu_subset=gpus,
5475                    rank=rank,
5476                    gradient_as_bucket_view=use_bucket_view,
5477                    static_graph=static_graph,
5478                    set_static_graph_twice=True,
5479                )
5480
5481                # test output_device
5482                self._test_DistributedDataParallel(
5483                    gpu_subset=gpus,
5484                    rank=rank,
5485                    output_device=torch.device("cuda"),
5486                    gradient_as_bucket_view=use_bucket_view,
5487                    static_graph=static_graph,
5488                )
5489
5490                # test device_ids
5491                gpus_list = [torch.device("cuda:" + str(i)) for i in gpus]
5492                self._test_DistributedDataParallel(
5493                    gpu_subset=gpus_list,
5494                    rank=rank,
5495                    output_device=torch.device("cuda"),
5496                    gradient_as_bucket_view=use_bucket_view,
5497                    static_graph=static_graph,
5498                )
5499
5500        def _test_DistributedDataParallel_with_amp(self, grad_is_view=False):
5501            torch.manual_seed(31415)
5502            # Creates model and optimizer in default precision
5503            model = copy.deepcopy(DDP_NET).cuda()
5504            optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
5505
5506            # Creates a GradScaler once at the beginning of training.
5507            scaler = GradScaler()
5508
5509            ddp_model = nn.parallel.DistributedDataParallel(
5510                model, device_ids=[self.rank], gradient_as_bucket_view=grad_is_view
5511            )
5512
5513            input = torch.randn(dist.get_world_size() * 2, 2).cuda()
5514            target = torch.randn(dist.get_world_size() * 2, 4).cuda()
5515            loss_fn = nn.MSELoss()
5516
5517            # Verify grads are None before training
5518            for p in ddp_model.parameters():
5519                self.assertTrue(p is not None)
5520                self.assertTrue(p.grad is None)
5521
5522            for idx in range(20):
5523                optimizer.zero_grad()
5524                # Runs the forward pass with autocasting.
5525                with autocast():
5526                    output = ddp_model(input)
5527                    loss = loss_fn(output, target)
5528
5529                # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
5530                # Backward passes under autocast are not recommended.
5531                # Backward ops run in the same dtype autocast chose for corresponding forward ops.
5532                scaler.scale(loss).backward()
5533
5534                # Verify grads are not None and are valid during training
5535                for p in ddp_model.parameters():
5536                    if p.requires_grad:
5537                        self.assertTrue(p.grad is not None)
5538                        self.assertFalse(p.grad.isnan().any())
5539                        self.assertFalse(p.grad.isinf().any())
5540
5541                # scaler.step() first unscales the gradients of the optimizer's assigned params.
5542                # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
5543                # otherwise, optimizer.step() is skipped.
5544                scaler.step(optimizer)
5545
5546                # Updates the scale for next iteration.
5547                scaler.update()
5548
5549                # Shuffle the input so that DDP input is different
5550                torch.manual_seed(1337 + idx)
5551                input = input[torch.randperm(dist.get_world_size() * 2)]
5552
5553            return ddp_model
5554
5555        @skip_but_pass_in_sandcastle_if(
5556            BACKEND not in DistTestCases.backend_feature["ddp"],
5557            f"The {BACKEND} backend does not support DistributedDataParallel",
5558        )
5559        @skip_if_no_gpu
5560        def test_DistributedDataParallel_with_amp_and_grad_is_view(self):
5561            torch.cuda.set_device(self.rank)
5562            ddp_model_grad_not_view = self._test_DistributedDataParallel_with_amp(
5563                grad_is_view=False
5564            )
5565            ddp_model_grad_is_view = self._test_DistributedDataParallel_with_amp(
5566                grad_is_view=True
5567            )
5568            for i, j in zip(
5569                ddp_model_grad_not_view.parameters(),
5570                ddp_model_grad_is_view.parameters(),
5571            ):
5572                self.assertEqual(i, j)
5573
5574        def _test_DistributedDataParallel_SyncBatchNorm(
5575            self,
5576            gpu_subset,
5577            rank,
5578            local_bs,
5579            global_bs,
5580            offset,
5581            output_device=None,
5582            affine=True,
5583        ):
5584            # Run a simple end-to-end DDP model and use the result of the
5585            # single-node model as the baseline
5586
5587            # cpu training setup
5588            model = BN_NET if affine else BN_NET_NO_AFFINE
5589
5590            # single gpu training setup
5591            model_gpu = copy.deepcopy(model)
5592            model_gpu.cuda(gpu_subset[0])
5593
5594            # DDP training setup
5595            model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))
5596            model_DDP.cuda(gpu_subset[0])
5597            model_DDP = nn.parallel.DistributedDataParallel(
5598                model_DDP, device_ids=gpu_subset
5599            )
5600
5601            # Test that the DDP-wrapped model can be serialized and deserialized
5602            with tempfile.NamedTemporaryFile() as tmp:
5603                if sys.platform == "win32":
5604                    torch.save(model_DDP, tmp)
5605                    tmp.seek(0)
5606                    # weights_only=False as this is legacy code that saves the model
5607                    model_DDP = torch.load(tmp, weights_only=False)
5608                else:
5609                    torch.save(model_DDP, tmp.name)
5610                    # weights_only=False as this is legacy code that saves the model
5611                    model_DDP = torch.load(tmp.name, weights_only=False)
5612
5613            # data initialization
5614            input_cpu = torch.randn(global_bs, 2)
5615            target = torch.randn(global_bs, 4)
5616            loss = nn.MSELoss()
5617
5618            # check two model parameters over 5 iterations
5619            self._test_DDP_niter(
5620                model_gpu,
5621                model_DDP,
5622                input_cpu.cuda(gpu_subset[0]),
5623                target.cuda(gpu_subset[0]),
5624                loss,
5625                local_bs,
5626                rank,
5627                global_bs,
5628                True,
5629                offset,
5630                dist.get_world_size(),
5631                5 if affine else 2,
5632            )
5633            self._barrier()
5634
5635        def _test_post_localSGD_optimizer_parity(self, create_averager, grad_is_view):
5636            learning_rate = 0.03
5637
5638            net = torch.nn.parallel.DistributedDataParallel(
5639                copy.deepcopy(DDP_NET).cuda(),
5640                device_ids=[self.rank],
5641                gradient_as_bucket_view=grad_is_view,
5642            )
5643            averager = create_averager()
5644            opt = torch.optim.SGD(net.parameters(), lr=learning_rate)
5645
5646            net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel(
5647                copy.deepcopy(DDP_NET).cuda(),
5648                device_ids=[self.rank],
5649                gradient_as_bucket_view=grad_is_view,
5650            )
5651            # A process group cannot be pickled in some environments,
5652            # so we cannot deep copy an averager. See:
5653            # https://github.com/pytorch/pytorch/pull/74737#pullrequestreview-922487496
5654            averager2 = create_averager()
5655            post_localSGD_opt = self._create_post_localSGD_optimizer(
5656                net_using_post_localSGD_opt, learning_rate, averager2
5657            )
5658
5659            input = torch.randn(dist.get_world_size() * 2, 2).cuda()
5660            target = torch.randn(dist.get_world_size() * 2, 4).cuda()
5661            loss_fn = nn.MSELoss()
5662
5663            for _ in range(20):
5664                self._perform_a_train_step(opt, net, loss_fn, input, target)
5665                averager.average_parameters(net.parameters())
5666
5667                self._perform_a_train_step(
5668                    post_localSGD_opt,
5669                    net_using_post_localSGD_opt,
5670                    loss_fn,
5671                    input,
5672                    target,
5673                )
5674                for p1, p2 in zip(
5675                    net.parameters(), net_using_post_localSGD_opt.parameters()
5676                ):
5677                    self.assertEqual(p1.data, p2.data)
5678
5679            # Also check if the built-in step counters are the same to prevent a bug like #74737.
5680            self.assertEqual(averager.step, averager2.step)
5681
5682        def _create_periodic_model_averager(self):
5683            return averagers.PeriodicModelAverager(period=4, warmup_steps=10)
5684
5685        def _create_post_localSGD_optimizer(self, net, learning_rate, averager):
5686            return post_localSGD_optimizer.PostLocalSGDOptimizer(
5687                optim=torch.optim.SGD(net.parameters(), lr=learning_rate),
5688                averager=averager,
5689            )
5690
5691        def _perform_a_train_step(self, optimizer, net, loss_fn, input, target):
5692            optimizer.zero_grad()
5693            output = net(input)
5694            loss = loss_fn(output, target)
5695            loss.backward()
5696            optimizer.step()
5697
5698        def _test_post_localSGD_optimizer_step_reload(
5699            self, create_averager, chkpt_file
5700        ):
5701            learning_rate = 0.03
5702
5703            net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel(
5704                copy.deepcopy(DDP_NET).cuda(), device_ids=[self.rank]
5705            )
5706
5707            averager = create_averager()
5708            post_localSGD_opt = self._create_post_localSGD_optimizer(
5709                net_using_post_localSGD_opt, learning_rate, averager
5710            )
5711
5712            averager2 = create_averager()
5713            dummy_post_localSGD_opt = self._create_post_localSGD_optimizer(
5714                net_using_post_localSGD_opt, learning_rate, averager2
5715            )
5716
5717            input = torch.randn(dist.get_world_size() * 2, 2).cuda()
5718            target = torch.randn(dist.get_world_size() * 2, 4).cuda()
5719            loss_fn = nn.MSELoss()
5720
5721            for _ in range(20):
5722                self._perform_a_train_step(
5723                    post_localSGD_opt,
5724                    net_using_post_localSGD_opt,
5725                    loss_fn,
5726                    input,
5727                    target,
5728                )
5729
5730            if self.rank == 0:
5731                torch.save(
5732                    {"optimizer_state_dict": post_localSGD_opt.state_dict()}, chkpt_file
5733                )
5734
5735            dist.barrier()
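            # Remap tensors saved from rank 0's GPU onto this rank's GPU.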
5736            map_location = {"cuda:%d" % 0: "cuda:%d" % self.rank}
5737            checkpoint = torch.load(chkpt_file, map_location=map_location)
5738            dummy_post_localSGD_opt.load_state_dict(checkpoint["optimizer_state_dict"])
5739
5740            # Check that we didn't hit the trivial case
5741            self.assertNotEqual(averager2.step, 0)
5742            # Check that the dummy averager was initialized to the correct value
5743            self.assertEqual(averager.step, averager2.step)
5744
5745            # Remove the 'step' entry from the checkpoint
5746            # and make sure it is no longer in the state dictionary.
5747            del checkpoint["optimizer_state_dict"]["step"]
5748            self.assertNotIn("step", checkpoint["optimizer_state_dict"])
5749
5750            # Check that loading a checkpoint without a 'step' entry raises a warning
5751            with self.assertWarnsRegex(
5752                expected_warning=UserWarning,
5753                expected_regex="Loaded state dict does not contain a step counter for an averager. "
5754                "Setting step counter to 0.",
5755            ):
5756                dummy_post_localSGD_opt.load_state_dict(
5757                    checkpoint["optimizer_state_dict"]
5758                )
5759
5760            self.assertEqual(averager2.step, 0)
5761
5762        @skip_if_lt_x_gpu(2)
5763        @skip_but_pass_in_sandcastle_if(
5764            BACKEND not in DistTestCases.backend_feature["ddp"],
5765            f"The {BACKEND} backend does not support DistributedDataParallel",
5766        )
5767        def test_post_localSGD_optimizer_parity(self):
5768            torch.cuda.set_device(self.rank)
5769            self._test_post_localSGD_optimizer_parity(
5770                self._create_periodic_model_averager,
5771                grad_is_view=False,
5772            )
5773
5774        @skip_if_lt_x_gpu(2)
5775        @skip_but_pass_in_sandcastle_if(
5776            BACKEND not in DistTestCases.backend_feature["ddp"],
5777            f"The {BACKEND} backend does not support DistributedDataParallel",
5778        )
5779        def test_post_localSGD_optimizer_parity_grad_is_view(self):
5780            torch.cuda.set_device(self.rank)
5781            self._test_post_localSGD_optimizer_parity(
5782                self._create_periodic_model_averager,
5783                grad_is_view=True,
5784            )
5785
5786        def _create_hierarchical_model_averager(self):
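            # Average within subgroups of size 2 every 2 steps, and across the
            # entire world (group size == world size) every 4 steps.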
5787            period_group_size_dict = OrderedDict([(2, 2), (4, dist.get_world_size())])
5788            return hierarchicalSGD.HierarchicalModelAverager(
5789                period_group_size_dict=period_group_size_dict, warmup_steps=4
5790            )
5791
5792        @skip_if_lt_x_gpu(4)
5793        @skip_if_odd_worldsize
5794        @skip_but_pass_in_sandcastle_if(
5795            BACKEND not in DistTestCases.backend_feature["ddp"],
5796            f"The {BACKEND} backend does not support DistributedDataParallel",
5797        )
5798        def test_post_localSGD_optimizer_parity_with_hierarchical_sgd(self):
5799            torch.cuda.set_device(self.rank)
5800            self._test_post_localSGD_optimizer_parity(
5801                self._create_hierarchical_model_averager,
5802                grad_is_view=False,
5803            )
5804
5805        @skip_if_lt_x_gpu(4)
5806        @skip_if_odd_worldsize
5807        @skip_but_pass_in_sandcastle_if(
5808            BACKEND not in DistTestCases.backend_feature["ddp"],
5809            f"The {BACKEND} backend does not support DistributedDataParallel",
5810        )
5811        def test_post_localSGD_optimizer_parity_with_hierarchical_sgd_grad_is_view(
5812            self,
5813        ):
5814            torch.cuda.set_device(self.rank)
5815            self._test_post_localSGD_optimizer_parity(
5816                self._create_hierarchical_model_averager,
5817                grad_is_view=True,
5818            )
5819
5820        @skip_if_lt_x_gpu(2)
5821        @skip_but_pass_in_sandcastle_if(
5822            BACKEND not in DistTestCases.backend_feature["ddp"],
5823            f"The {BACKEND} backend does not support DistributedDataParallel",
5824        )
5825        def test_post_localSGD_optimizer_step_reload(self):
5826            torch.cuda.set_device(self.rank)
5827            with _rank_temp_file() as tmp_file:
5828                self._test_post_localSGD_optimizer_step_reload(
5829                    self._create_periodic_model_averager, tmp_file
5830                )
5831
5832        @skip_but_pass_in_sandcastle_if(
5833            BACKEND not in DistTestCases.backend_feature["ddp"],
5834            f"The {BACKEND} backend does not support DistributedDataParallel",
5835        )
5836        @skip_if_no_gpu
5837        def test_DistributedDataParallel_SyncBatchNorm_Channels_Last(self):
5838            self._test_DistributedDataParallel_SyncBatchNorm_with_memory_format(
5839                torch.channels_last
5840            )
5841            self._test_DistributedDataParallel_SyncBatchNorm_with_memory_format(
5842                torch.channels_last_3d
5843            )
5844
5845        def _test_DistributedDataParallel_SyncBatchNorm_with_memory_format(
5846            self, memory_format
5847        ):
5848            group, group_id, rank = self._init_global_test()
5849            num_processes = dist.get_world_size()
5850            local_bs = 2
5851            bs_offset = int(rank * 2)
5852            global_bs = int(num_processes * 2)
5853
5854            model = ONLY_SBN_NET
5855            model_gpu = copy.deepcopy(model).cuda(rank)
5856            model_DDP = nn.parallel.DistributedDataParallel(
5857                model_gpu, device_ids=[rank]
5858            )
5859
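            # 4-D (NCHW) input for channels_last; append a depth dim (5-D NCDHW)
            # for channels_last_3d.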
5860            shapes = [global_bs, 2, 4, 4] + (
5861                [] if memory_format is torch.channels_last else [4]
5862            )
5863
5864            input_gpu = (
5865                torch.randn(*shapes, dtype=torch.float)
5866                .cuda(rank)
5867                .to(memory_format=memory_format)
5868            )
5869            target_gpu = (
5870                torch.randn(*shapes, dtype=torch.float)
5871                .cuda(rank)
5872                .to(memory_format=memory_format)
5873            )
5874            loss = nn.MSELoss()
5875
5876            # check two model parameters over 5 iterations
5877            self._test_DDP_niter(
5878                model_gpu,
5879                model_DDP,
5880                input_gpu,
5881                target_gpu,
5882                loss,
5883                local_bs,
5884                rank,
5885                global_bs,
5886                True,
5887                bs_offset,
5888                dist.get_world_size(),
5889                memory_format=memory_format,
5890            )
5891            self._barrier()
5892
5893        @skip_but_pass_in_sandcastle_if(
5894            BACKEND not in DistTestCases.backend_feature["ddp"],
5895            f"The {BACKEND} backend does not support DistributedDataParallel",
5896        )
5897        @skip_if_no_gpu
5898        def test_DistributedDataParallel_SyncBatchNorm(self):
5899            group, group_id, rank = self._init_global_test()
5900            world_size = dist.get_world_size()
5901            # DDP does not support replicating BN layers within a process, hence
5902            # testing with one module replica per process
5903            gpus = [rank]
5904
5905            local_bs = 2
5906            bs_offset = int(rank * 2)
5907            global_bs = int(world_size * 2)
5908
5909            self._test_DistributedDataParallel_SyncBatchNorm(
5910                gpu_subset=gpus,
5911                rank=rank,
5912                local_bs=local_bs,
5913                global_bs=global_bs,
5914                offset=bs_offset,
5915            )
5916
5917            # test output_device
5918            self._test_DistributedDataParallel_SyncBatchNorm(
5919                gpu_subset=gpus,
5920                rank=rank,
5921                local_bs=local_bs,
5922                global_bs=global_bs,
5923                offset=bs_offset,
5924                output_device=torch.device("cuda"),
5925            )
5926
5927            # test device_ids
5928            gpus = [torch.device("cuda:" + str(i)) for i in gpus]
5929            self._test_DistributedDataParallel_SyncBatchNorm(
5930                gpu_subset=gpus,
5931                rank=rank,
5932                local_bs=local_bs,
5933                global_bs=global_bs,
5934                offset=bs_offset,
5935                output_device=torch.device("cuda"),
5936            )
5937
5938        @skip_but_pass_in_sandcastle_if(
5939            BACKEND not in DistTestCases.backend_feature["ddp"],
5940            f"The {BACKEND} backend does not support DistributedDataParallel",
5941        )
5942        @skip_if_no_gpu
5943        def test_DistributedDataParallel_SyncBatchNorm_No_Affine(self):
5944            group, group_id, rank = self._init_global_test()
5945            world_size = dist.get_world_size()
5946            # DDP does not support replicating BN layers within a process, hence
5947            # testing with one module replica per process
5948            gpus = [rank]
5949
5950            local_bs = 2
5951            bs_offset = int(rank * 2)
5952            global_bs = int(world_size * 2)
5953
5954            self._test_DistributedDataParallel_SyncBatchNorm(
5955                gpu_subset=gpus,
5956                rank=rank,
5957                local_bs=local_bs,
5958                global_bs=global_bs,
5959                offset=bs_offset,
5960                affine=False,
5961            )
5962
5963        @skip_but_pass_in_sandcastle_if(
5964            BACKEND not in DistTestCases.backend_feature["ddp"],
5965            f"The {BACKEND} backend does not support DistributedDataParallel",
5966        )
5967        @skip_if_no_gpu
5968        def test_DistributedDataParallel_SyncBatchNorm_2D_Input(self):
5969            group, group_id, rank = self._init_global_test()
5970            # DDP does not support replicating BN layers within a process, hence
5971            # testing with one module replica per process
5972            gpus = [rank]
5973
5974            model = nn.BatchNorm1d(2)
5975
5976            # single gpu training setup
5977            model_gpu = copy.deepcopy(model)
5978            model_gpu.cuda(gpus[0])
5979
5980            # DDP training setup
5981            model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))
5982            model_DDP.cuda(gpus[0])
5983            model_DDP = nn.parallel.DistributedDataParallel(model_DDP, device_ids=gpus)
5984
5985            local_bs = len(gpus) * 2
5986            global_bs = dist.get_world_size() * local_bs
5987            input_cpu = torch.randn(global_bs, 2)
5988            target = torch.randn(global_bs, 2)
5989            loss = nn.MSELoss()
5990
5991            # Disable cuDNN so SyncBatchNorm goes through the native_batch_norm
5992            # kernel; this avoids the numerical issue created by the divergent
5993            # code path.
5994            with torch.backends.cudnn.flags(False):
5995                # check two model parameters over 5 iterations
5996                self._test_DDP_niter(
5997                    model_gpu,
5998                    model_DDP,
5999                    input_cpu.cuda(gpus[0]),
6000                    target.cuda(gpus[0]),
6001                    loss,
6002                    local_bs,
6003                    rank,
6004                    global_bs,
6005                    True,
6006                )
6007                self._barrier()
6008
6009        @skip_but_pass_in_sandcastle_if(
6010            BACKEND not in DistTestCases.backend_feature["ddp"],
6011            f"The {BACKEND} backend does not support DistributedDataParallel",
6012        )
6013        @skip_if_no_gpu
6014        @require_world_size(2)
6015        def test_DistributedDataParallel_SyncBatchNorm_Single_Input_Per_Process(self):
6016            group, group_id, rank = self._init_global_test()
6017            # DDP does not support replicating BN layers within a process, hence
6018            # testing with one module replica per process
6019            gpus = [rank]
6020
6021            model = nn.BatchNorm1d(2)
6022
6023            # single gpu training setup
6024            model_gpu = copy.deepcopy(model)
6025            model_gpu.cuda(gpus[0])
6026
6027            # DDP training setup
6028            model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))
6029            model_DDP.cuda(gpus[0])
6030            model_DDP = nn.parallel.DistributedDataParallel(model_DDP, device_ids=gpus)
6031
6032            local_bs = 1
6033            global_bs = dist.get_world_size()
6034            input_cpu = torch.randn(global_bs, 2)
6035            target = torch.randn(global_bs, 2)
6036            loss = nn.MSELoss()
6037
6038            # Disable cuDNN so SyncBatchNorm goes through the native_batch_norm
6039            # kernel; this avoids the numerical issue created by the divergent
6040            # code path.
6041            with torch.backends.cudnn.flags(False):
6042                # check two model parameters over 5 iterations
6043                self._test_DDP_niter(
6044                    model_gpu,
6045                    model_DDP,
6046                    input_cpu.cuda(gpus[0]),
6047                    target.cuda(gpus[0]),
6048                    loss,
6049                    local_bs,
6050                    rank,
6051                    global_bs,
6052                    True,
6053                )
6054                self._barrier()
6055
6056        @skip_but_pass_in_sandcastle_if(
6057            BACKEND not in DistTestCases.backend_feature["ddp"],
6058            f"The {BACKEND} backend does not support DistributedDataParallel",
6059        )
6060        @skip_if_no_gpu
6061        def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value(
6062            self,
6063        ):
6064            group, group_id, rank = self._init_global_test()
6065            model = nn.parallel.DistributedDataParallel(
6066                ONLY_SBN_NET.cuda(rank), device_ids=[rank]
6067            )
6068
6069            input_var = []
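            # Each rank feeds an input with a different length and magnitude so the
            # running stats must aggregate non-uniform per-rank contributions.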
6070            for i in range(dist.get_world_size()):
6071                input_var_rank = torch.cat(
6072                    [
6073                        torch.ones(2, 1, 10 ** (i + 1)) * (0.1 ** (i - 1)),
6074                        torch.ones(2, 1, 10 ** (i + 1)) * (0.3 ** (i - 1)),
6075                    ],
6076                    dim=1,
6077                )
6078                input_var.append(input_var_rank)
6079
6080            all_input_var = torch.cat(
6081                [
6082                    x.permute(1, 0, 2).contiguous().view(ONLY_SBN_NET.num_features, -1)
6083                    for x in input_var
6084                ],
6085                dim=1,
6086            ).cuda(rank)
6087
6088            for i in range(100):
6089                y = model(input_var[rank].cuda(rank))
6090                y.mean().backward()
6091
6092            running_mean, running_var = (
6093                model.module.running_mean,
6094                model.module.running_var,
6095            )
6096            torch.testing.assert_close(running_mean, all_input_var.mean(1))
6097            torch.testing.assert_close(running_var, all_input_var.var(1))
6098
6099        @skip_but_pass_in_sandcastle_if(
6100            BACKEND not in DistTestCases.backend_feature["ddp"],
6101            f"The {BACKEND} backend does not support DistributedDataParallel",
6102        )
6103        @skip_if_no_gpu
6104        def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self):
6105            group, group_id, rank = self._init_global_test()
6106            # only do single GPU per process
6107            gpus = [rank]
6108
6109            # cpu training setup
6110            model = BN_NET
6111
6112            num_processes = dist.get_world_size()
6113            local_bs = rank + 2
6114            bs_offset = int((rank + 3) * rank / 2)
6115            global_bs = int((num_processes + 3) * num_processes / 2)
6116
6117            self._test_DistributedDataParallel_SyncBatchNorm(
6118                gpu_subset=gpus,
6119                rank=rank,
6120                local_bs=local_bs,
6121                global_bs=global_bs,
6122                offset=bs_offset,
6123            )
6124
6125        @skip_but_pass_in_sandcastle_if(
6126            BACKEND not in DistTestCases.backend_feature["ddp"],
6127            f"The {BACKEND} backend does not support DistributedDataParallel",
6128        )
6129        @skip_if_no_gpu
6130        def test_DistributedDataParallel_SyncBatchNorm_half(self):
6131            group, group_id, rank = self._init_global_test()
6132
6133            model = copy.deepcopy(BN_NET)
6134            model = model.half()
6135            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
6136            model = nn.parallel.DistributedDataParallel(model.cuda(rank), device_ids=[rank])
6137            inp = torch.randn(2, 2, dtype=torch.float16, device=torch.device(rank))
6138            # Check that forward/backward do not error with dtype mismatch
6139            out = model(inp)
6140            self.assertEqual(out.dtype, torch.float16)
6141            out.sum().backward()
6142            for param in model.parameters():
6143                self.assertEqual(param.grad.dtype, torch.float16)
6144
6145        def _test_ddp_logging_data(self, is_gpu):
6146            rank = dist.get_rank()
6147            model_DDP = copy.deepcopy(DDP_NET)
6148            if is_gpu:
6149                model_DDP = nn.parallel.DistributedDataParallel(
6150                    model_DDP.cuda(rank), device_ids=[rank]
6151                )
6152            else:
6153                model_DDP = nn.parallel.DistributedDataParallel(model_DDP)
6154
6155            # dummy data initialization
6156            local_bs = 2
6157            batch_size, input, target, loss = self._prepare_dummy_data(local_bs)
6158            if is_gpu:
6159                input = input.cuda(rank)
6160                target = target.cuda(rank)
6161
6162            model_DDP._set_ddp_runtime_logging_sample_rate(2)
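            # With a sample rate of 2, runtime stats are recorded for the first 10
            # iterations and then on every other iteration (checked below).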
6163
6164            for idx in range(20):
6165                offset = rank * local_bs
6166
6167                # DDP training, DDP scatters subsets of input to nodes/GPUs
6168                self._test_DDP_helper(
6169                    model_DDP,
6170                    input[offset : offset + local_bs],
6171                    target[offset : offset + local_bs],
6172                    loss,
6173                    1,
6174                )
6175
6176                self._model_step_with_zero_grad(model_DDP)
6177
6178                # Verify DDP logging data is sampled as expected.
6179                # Runtime stats are collected for the first 10 iterations and
6180                # then only on sampled iterations; in those cases the runtime
6181                # stats for this idx-th iteration will not
6182                # be zeros.
6183                ddp_logging_data = model_DDP._get_ddp_logging_data()
6184                if idx > 0 and (idx < 10 or idx % 2 == 0):
6185                    self.assertGreaterEqual(
6186                        ddp_logging_data.get("forward_compute_time"), 1
6187                    )
6188                    self.assertGreaterEqual(
6189                        ddp_logging_data.get("backward_compute_time"), 1
6190                    )
6191                    self.assertGreaterEqual(
6192                        ddp_logging_data.get("backward_comm_time"), 1
6193                    )
6194                    self.assertGreaterEqual(
6195                        ddp_logging_data.get("backward_compute_time"),
6196                        ddp_logging_data.get("backward_compute_comm_overlap_time"),
6197                    )
6198                    self.assertGreaterEqual(
6199                        ddp_logging_data.get("backward_comm_time"),
6200                        ddp_logging_data.get("backward_compute_comm_overlap_time"),
6201                    )
6202                    self.assertEqual(ddp_logging_data.get("iteration"), idx)
6203                elif idx > 0:
6204                    # if the idx-th iteration is not sampled to set runtime stats,
6205                    # ddp_logging_data.iteration will not be updated to current
6206                    # iteration.
6207                    self.assertNotEqual(ddp_logging_data.get("iteration"), idx)
6208
6209                # Shuffle the input so that DDP input is different
6210                input = input[torch.randperm(batch_size)]
6211
6212            return model_DDP
6213
6214        @skip_but_pass_in_sandcastle_if(
6215            BACKEND == "nccl", "nccl does not support DDP on CPU models"
6216        )
6217        def test_ddp_logging_data_cpu(self):
6218            def parse_env(var):
6219                return os.environ[var] if var in os.environ else "N/A"
6220
6221            dist.set_debug_level(dist.DebugLevel.INFO)
6222            group, group_id, rank = self._init_global_test()
6223            model_DDP = self._test_ddp_logging_data(is_gpu=False)
6224
6225            ddp_logging_data = model_DDP._get_ddp_logging_data()
6226            self.assertEqual(ddp_logging_data.get("world_size"), dist.get_world_size())
6227            self.assertEqual(ddp_logging_data.get("rank"), dist.get_rank())
6228            self.assertEqual(ddp_logging_data.get("module_name"), "Net")
6229            self.assertEqual(ddp_logging_data.get("device_ids"), "")
6230            # output_device defaults to -1 if it is not set, e.g.
6231            # the output_device for CPU training is -1.
6232            self.assertEqual(ddp_logging_data.get("output_device"), -1)
6233            self.assertEqual(ddp_logging_data.get("broadcast_buffers"), 1)
6234            self.assertEqual(ddp_logging_data.get("bucket_cap_bytes"), 25 * 1024 * 1024)
6235            self.assertEqual(ddp_logging_data.get("find_unused_parameters"), 0)
6236            self.assertEqual(ddp_logging_data.get("gradient_as_bucket_view"), 0)
6237            self.assertEqual(
6238                ddp_logging_data.get("backend_name"), dist.get_backend(group_id)
6239            )
6240            self.assertEqual(ddp_logging_data.get("iteration"), 18)
6241            params = list(model_DDP.parameters())
6242            num_params = 0
6243            param_size = 0
6244            params = list(filter(lambda parameter: parameter.requires_grad, params))
6245            for p in params:
6246                num_params += 1
6247                param_size += p.numel() * p.element_size()
6248            self.assertEqual(ddp_logging_data.get("dtypes"), "float")
6249            self.assertEqual(
6250                ddp_logging_data.get("total_parameter_size_bytes"), param_size
6251            )
6252            self.assertEqual(ddp_logging_data.get("num_parameter_tensors"), num_params)
6253            self.assertEqual(ddp_logging_data.get("bucket_sizes"), str(param_size))
6254            self.assertEqual(
6255                ddp_logging_data.get("master_port"), parse_env("MASTER_PORT")
6256            )
6257            self.assertEqual(
6258                ddp_logging_data.get("master_addr"), parse_env("MASTER_ADDR")
6259            )
6260            self.assertEqual(
6261                ddp_logging_data.get("torch_distributed_debug"),
6262                parse_env("TORCH_DISTRIBUTED_DEBUG"),
6263            )
6264            self.assertEqual(
6265                ddp_logging_data.get("cuda_visible_devices"),
6266                parse_env("CUDA_VISIBLE_DEVICES"),
6267            )
6268            if ddp_logging_data.get("backend_name") == "gloo":
6269                self.assertEqual(
6270                    ddp_logging_data.get("gloo_socket_ifname"),
6271                    parse_env("GLOO_SOCKET_IFNAME"),
6272                )
6273                self.assertEqual(
6274                    ddp_logging_data.get("gloo_device_transport"),
6275                    parse_env("GLOO_DEVICE_TRANSPORT"),
6276                )
6277                default_gloo_threads = 2
6278                self.assertEqual(
6279                    ddp_logging_data.get("gloo_num_threads"),
6280                    default_gloo_threads,
6281                )
6282
6283            self.assertEqual(ddp_logging_data.get("nccl_socket_ifname"), None)
6284            self.assertEqual(ddp_logging_data.get("nccl_blocking_wait"), None)
6285            self.assertEqual(ddp_logging_data.get("nccl_async_error_handling"), None)
6286            self.assertEqual(ddp_logging_data.get("nccl_debug"), None)
6287            self.assertEqual(ddp_logging_data.get("nccl_nthreads"), None)
6288            self.assertEqual(ddp_logging_data.get("nccl_ib_timeout"), None)
6289            # test runtime logging fields
6290            # Note: DETAIL debug mode logs DDP logging data to stdout and
6291            # thus accesses std::map, which fills in a default value for the
6292            # type if it didn't exist.
6293            self.assertEqual(ddp_logging_data.get("unused_parameter_size", 0), 0)
6294            self.assertEqual(ddp_logging_data.get("has_rebuilt_buckets"), 1)
6295            self.assertEqual(
6296                ddp_logging_data.get("rebuilt_bucket_sizes"), str(param_size)
6297            )
6298            grad_ready_order = ddp_logging_data.get(
6299                "prev_iteration_grad_ready_order_indices"
6300            )
6301            expected_order = list(reversed([str(x) for x in range(3)]))
6302            self.assertEqual(grad_ready_order, ", ".join(expected_order))
6303            bucket_indices = ddp_logging_data.get("rebuilt_per_bucket_param_indices")
6304            self.assertEqual(bucket_indices, " ".join(expected_order))
6305            # It is hard to test exact latency, but we can check that the latency is
6306            # a valid value and in the expected range.
6307            self.assertGreaterEqual(ddp_logging_data.get("avg_forward_compute_time"), 1)
6308            self.assertGreaterEqual(
6309                ddp_logging_data.get("avg_backward_compute_time"), 1
6310            )
6311            self.assertGreaterEqual(ddp_logging_data.get("avg_backward_comm_time"), 1)
6312            self.assertGreaterEqual(
6313                ddp_logging_data.get("avg_backward_compute_time"),
6314                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
6315            )
6316            self.assertGreaterEqual(
6317                ddp_logging_data.get("avg_backward_comm_time"),
6318                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
6319            )
6320            # Test host-side times are roughly in the order that we expect
6321            fwd_host_side_time = ddp_logging_data.get("forward_compute_time_start")
6322            bwd_comp_start_host_side_time = ddp_logging_data.get(
6323                "backward_compute_time_start"
6324            )
6325            bwd_comp_end_host_side_time = ddp_logging_data.get(
6326                "backward_compute_time_end"
6327            )
6328            bwd_comm_start_host_side_time = ddp_logging_data.get(
6329                "backward_comm_time_start"
6330            )
6331            bwd_comm_end_host_side_time = ddp_logging_data.get("backward_comm_time_end")
6332            self.assertGreaterEqual(
6333                bwd_comm_end_host_side_time, bwd_comm_start_host_side_time
6334            )
6335            self.assertGreaterEqual(
6336                bwd_comm_start_host_side_time, bwd_comp_start_host_side_time
6337            )
6338            self.assertGreaterEqual(
6339                bwd_comp_end_host_side_time, bwd_comp_start_host_side_time
6340            )
6341            self.assertGreaterEqual(bwd_comp_start_host_side_time, fwd_host_side_time)
6342
6343            # Test a larger net with mixed data types and verify multiple bucket sizes
6344            model = LargeNet()
6345            model.float()
6346            model.fc1.double()
6347            model_DDP = nn.parallel.DistributedDataParallel(model, bucket_cap_mb=1.5)
6348            ddp_logging_data = model_DDP._get_ddp_logging_data()
6349            params = list(model_DDP.parameters())
6350            self.assertEqual(
6351                ddp_logging_data.get("bucket_cap_bytes"), int(1.5 * 1024 * 1024)
6352            )
6353            bucket_sizes = [
6354                params[1].numel() * params[1].element_size(),
6355                params[0].numel() * params[0].element_size(),
6356            ]
6357            self.assertEqual(
6358                ddp_logging_data.get("bucket_sizes"),
6359                ", ".join(str(x) for x in bucket_sizes),
6360            )
6361            self.assertEqual(ddp_logging_data.get("dtypes"), "double, float")
6362
6363        @skip_but_pass_in_sandcastle_if(
6364            BACKEND not in DistTestCases.backend_feature["ddp"],
6365            f"The {BACKEND} backend does not support DistributedDataParallel",
6366        )
6367        @skip_if_no_gpu
6368        def test_ddp_logging_data_gpu(self):
6369            group, group_id, rank = self._init_global_test()
6370            model_DDP = self._test_ddp_logging_data(is_gpu=True)
6371            ddp_logging_data = model_DDP._get_ddp_logging_data()
6372            self.assertEqual(ddp_logging_data.get("device_ids"), str(rank))
6373            self.assertEqual(ddp_logging_data.get("output_device"), rank)
6374            grad_ready_order = ddp_logging_data.get(
6375                "prev_iteration_grad_ready_order_indices"
6376            )
6377            expected_order = list(reversed([str(x) for x in range(3)]))
6378            self.assertEqual(grad_ready_order, ", ".join(expected_order))
6379            bucket_indices = ddp_logging_data.get("rebuilt_per_bucket_param_indices")
6380            self.assertEqual(bucket_indices, " ".join(expected_order))
6381            # test runtime logging fields
6382            # It is hard to test exact latency, but we can check that the latency is
6383            # a valid value and in the expected range.
6384            self.assertGreaterEqual(ddp_logging_data.get("avg_forward_compute_time"), 1)
6385            self.assertGreaterEqual(
6386                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"), 1
6387            )
6388            self.assertGreaterEqual(
6389                ddp_logging_data.get("avg_backward_compute_time"),
6390                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
6391            )
6392            self.assertGreaterEqual(
6393                ddp_logging_data.get("avg_backward_comm_time"),
6394                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
6395            )
6396            # Test host-side times are roughly in the order that we expect
6397            fwd_host_side_time = ddp_logging_data.get("forward_compute_time_start")
6398            bwd_comp_start_host_side_time = ddp_logging_data.get(
6399                "backward_compute_time_start"
6400            )
6401            bwd_comp_end_host_side_time = ddp_logging_data.get(
6402                "backward_compute_time_end"
6403            )
6404            bwd_comm_start_host_side_time = ddp_logging_data.get(
6405                "backward_comm_time_start"
6406            )
6407            bwd_comm_end_host_side_time = ddp_logging_data.get("backward_comm_time_end")
6408            self.assertGreaterEqual(
6409                bwd_comm_end_host_side_time, bwd_comm_start_host_side_time
6410            )
6411            self.assertGreaterEqual(
6412                bwd_comm_start_host_side_time, bwd_comp_start_host_side_time
6413            )
6414            self.assertGreaterEqual(
6415                bwd_comp_end_host_side_time, bwd_comp_start_host_side_time
6416            )
6417            self.assertGreaterEqual(bwd_comp_start_host_side_time, fwd_host_side_time)
6418
6419        @skip_but_pass_in_sandcastle_if(
6420            BACKEND == "nccl", "nccl does not support DDP on CPU models"
6421        )
6422        def test_static_graph_api_cpu(self):
6423            model_DDP = nn.parallel.DistributedDataParallel(DDP_NET)
6424            expected_err = "should be called before training loop starts"
6425            with self.assertRaisesRegex(RuntimeError, expected_err):
6426                local_bs = 2
6427                batch_size, input, target, loss = self._prepare_dummy_data(local_bs)
6428                offset = dist.get_rank() * local_bs
6429
6430                # DDP training, DDP scatters subsets of input to nodes/GPUs
6431                self._test_DDP_helper(
6432                    model_DDP,
6433                    input[offset : offset + local_bs],
6434                    target[offset : offset + local_bs],
6435                    loss,
6436                    1,
6437                )
6438                model_DDP._set_static_graph()
6439
6440            # Verify error was logged in ddp_logging_data.
6441            verify_ddp_error_logged(model_DDP, expected_err)
6442
6443        @skipIfNoTorchVision
6444        def test_SyncBatchNorm_process_group(self):
6445            # When using `convert_sync_batchnorm` to convert an `nn.Module`,
6446            # the `process_group` needs to be passed down recursively when the `SyncBatchNorm`
6447            # is nested in a sub-module or sub-sub-module (e.g. resnet50 in torchvision.models).
6448
6449            process_ids = 0
6450            process_group = torch.distributed.new_group([process_ids])
6451            res50_model = torchvision.models.resnet50()
6452            res50_model_sync = nn.SyncBatchNorm.convert_sync_batchnorm(
6453                copy.deepcopy(res50_model), process_group
6454            )
6455            process_group_sync = res50_model_sync.layer1[0].bn1.process_group
6456            self.assertEqual(process_group_sync, process_group)
6457
6458        def _run_reduction_test(
6459            self, tensor, expected_tensor, op, reduction_fn=dist.all_reduce, dst=None
6460        ):
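            # Helper: run `reduction_fn` on `tensor` and compare against
            # `expected_tensor`. For rooted reductions such as dist.reduce, only the
            # dst rank is guaranteed to hold the reduced value.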
6461            if reduction_fn != dist.all_reduce and dst is None:
6462                raise ValueError(f"Reduction fn {reduction_fn} must specify dst!")
6463            if dst is not None:
6464                reduction_fn(tensor, dst, op)
6465                # Only the destination rank's tensor is expected to hold the final result.
6466                if dist.get_rank() == dst:
6467                    self.assertEqual(tensor, expected_tensor)
6468            else:
6469                reduction_fn(tensor, op)
6470                self.assertEqual(tensor, expected_tensor)
6471
6472        @require_backend_is_available({"nccl"})
6473        @skip_if_lt_x_gpu(2)
6474        def test_nccl_backend_bool_allreduce(self):
6475            torch.cuda.set_device(self.rank)
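            # For bool tensors, PRODUCT/MIN act like a logical AND across ranks and
            # SUM/MAX like a logical OR, since bools are reduced as 0/1 values.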
6476            # Run all_reduce with PRODUCT
6477            element = self.rank % 2 == 0
6478            for op in [dist.ReduceOp.PRODUCT, dist.ReduceOp.MIN]:
6479                input_tensor = torch.tensor([element, element]).to(self.rank)
6480                self._run_reduction_test(
6481                    input_tensor, torch.tensor([False, False]).to(self.rank), op
6482                )
6483                # Ensure that when all ranks contribute True (cast to 1), the
6484                # reduction produces the correct result.
6485                input_tensor = torch.tensor([True, True]).to(self.rank)
6486                expected_tensor = input_tensor.clone()
6487                self._run_reduction_test(input_tensor, expected_tensor, op)
6488
6489            # Run all_reduce with SUM
6490            for op in [dist.ReduceOp.SUM, dist.ReduceOp.MAX]:
6491                input_tensor = torch.tensor([element, element]).to(self.rank)
6492                self._run_reduction_test(
6493                    input_tensor, torch.tensor([True, True]).to(self.rank), op
6494                )
6495            # TODO: NCCL backend does not work correctly for bitwise reduction ops
6496            # (see https://github.com/pytorch/pytorch/issues/41362). Add tests for
6497            # these once it is supported.
6498
6499        @require_backend_is_available({"nccl"})
6500        @skip_if_lt_x_gpu(2)
6501        def test_nccl_backend_bool_allgather(self):
6502            torch.cuda.set_device(self.rank)
6503            inp = {0: [True, True], 1: [False, True]}
6504            input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank)
6505            # Preserve a copy of the tensor to compare against after allgather.
6506            input_tensor_copy = input_tensor.clone()
6507            tensor_list = [
6508                torch.tensor([False, False]).to(self.rank)
6509                for _ in range(dist.get_world_size())
6510            ]
6511            dist.all_gather(tensor_list, input_tensor)
6512
6513            self.assertEqual(len(tensor_list), dist.get_world_size())
6514            for i, t in enumerate(tensor_list):
6515                expected = torch.tensor(inp[i % 2]).to(self.rank)
6516                self.assertEqual(t, expected)
6517            # Ensure that the input tensor is not modified, since this collective
6518            # does not modify its input.
6519            self.assertEqual(input_tensor_copy, input_tensor)
6520
6521        @require_backend_is_available({"nccl"})
6522        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
6523        def test_nccl_backend_bool_reduce(self):
6524            torch.cuda.set_device(self.rank)
6525            inp = {0: [True, True], 1: [False, False]}
6526            # Run reduce() with product op
6527            for op in [dist.ReduceOp.PRODUCT, dist.ReduceOp.MIN]:
6528                input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank)
6529                expected = torch.tensor([False, False]).to(self.rank)
6530                self._run_reduction_test(input_tensor, expected, op, dist.reduce, dst=0)
6531                # Ensure that when all ranks contribute True (cast to 1), the
6532                # reduction produces the correct result.
6533                input_tensor = torch.tensor([True, True]).to(self.rank)
6534                expected_tensor = input_tensor.clone()
6535                self._run_reduction_test(
6536                    input_tensor, expected_tensor, op, dist.reduce, dst=0
6537                )
6538
6539            for op in [dist.ReduceOp.SUM, dist.ReduceOp.MAX]:
6540                input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank)
6541                expected = (
6542                    torch.tensor([True, True]).to(self.rank)
6543                    if self.rank == 0
6544                    else input_tensor.clone()
6545                )
6546                self._run_reduction_test(input_tensor, expected, op, dist.reduce, dst=0)
6547
6548        @require_backend_is_available({"nccl"})
6549        @skip_if_lt_x_gpu(2)
6550        def test_nccl_backend_bool_broadcast(self):
6551            tensor_size = 10
6552            bcast_tensor = torch.tensor(
6553                [
6554                    (random.random() < 0.5 if self.rank == 0 else False)
6555                    for _ in range(tensor_size)
6556                ]
6557            ).to(self.rank)
6558            dist.broadcast(bcast_tensor, src=0)
6559            # Now allgather and ensure the tensors are equal.
6560            tensor_list = [
6561                torch.tensor([False for _ in range(tensor_size)]).to(self.rank)
6562                for _ in range(dist.get_world_size())
6563            ]
6564            dist.all_gather(tensor_list, bcast_tensor)
6565            expected = tensor_list[0]
6566            for tensor in tensor_list[1:]:
6567                self.assertEqual(tensor, expected)
6568
6569        @skip_but_pass_in_sandcastle_if(
6570            BACKEND not in DistTestCases.backend_feature["ddp"],
6571            f"The {BACKEND} backend does not support DistributedDataParallel",
6572        )
6573        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
6574        def test_DistributedSampler_padding(self):
6575            # Tests padding of distributed sampler.
6576            world_size = dist.get_world_size()
6577
6578            # Simulates a 'typical' dataset size
6579            dataset_size = 100 + world_size + 1
6580            dataset = [torch.ones(1).to(self.rank) * i for i in range(dataset_size)]
6581
6582            # Simulates the 'tiny' dataset size
6583            dataset_tiny_size = max(world_size // 2 - 1, 1)
6584            dataset_tiny = [
6585                torch.ones(1).to(self.rank) * i for i in range(dataset_tiny_size)
6586            ]
6587
6588            # Specifying drop_last=True will cause the tail of the data to be dropped.
6589            dist_sampler = DistributedSampler(dataset=dataset, drop_last=True)
6590            local_num_samples, local_dataset_size = (
6591                dist_sampler.num_samples,
6592                dist_sampler.total_size,
6593            )
6594            # The effective dataset size should be the greatest integer that is
6595            # <= dataset_size and divisible by the world_size. This ensures each
6596            # rank processes the same number of samples.
6597            effective_dataset_size = (
6598                math.ceil((dataset_size - world_size) / world_size)
6599                if dataset_size % world_size != 0
6600                else dataset_size / world_size
6601            )
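            # Illustrative example: with world_size=2 the dataset above has 103
            # elements, so drop_last=True yields 51 samples per rank (102 total, one
            # element dropped), while drop_last=False pads up to 52 per rank.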
6602            self.assertEqual(local_num_samples, effective_dataset_size)
6603            self.assertEqual(local_dataset_size, local_num_samples * world_size)
6604            indices_list = list(iter(dist_sampler))
6605            self.assertEqual(len(indices_list), local_num_samples)
6606
6607            def validate_global_samples(local_num_samples):
6608                # Ensure that each rank processes the same number of samples.
6609                world_samples = [
6610                    torch.LongTensor([0]).to(self.rank) for _ in range(world_size)
6611                ]
6612                dist.all_gather(
6613                    world_samples, torch.tensor([local_num_samples]).to(self.rank)
6614                )
6615                world_samples = [sample.item() for sample in world_samples]
6616                self.assertEqual(len(set(world_samples)), 1)
6617
6618            validate_global_samples(local_num_samples)
6619
6620            # drop_last=False is the default and will add additional indices to be sampled,
6621            # increasing the effective dataset size.
6622            dist_sampler_added_samples = DistributedSampler(dataset=dataset)
6623            local_num_samples, local_dataset_size = (
6624                dist_sampler_added_samples.num_samples,
6625                dist_sampler_added_samples.total_size,
6626            )
6627            # The effective dataset size is the smallest integer that is >= dataset_size
6628            # and divisible by the world size.
6629            self.assertEqual(local_num_samples, math.ceil(dataset_size / world_size))
6630            self.assertEqual(local_dataset_size, local_num_samples * world_size)
6631            indices_list = list(iter(dist_sampler_added_samples))
6632            self.assertEqual(len(indices_list), local_num_samples)
6633
6634            # Ensure that each rank processes the same number of samples.
6635            validate_global_samples(local_num_samples)
6636
6637            # Ensure additional samples are padded in even when an
6638            # extremely small dataset is given.
6639            dist_sampler_added_samples_tiny = DistributedSampler(dataset=dataset_tiny)
6640            local_num_samples, local_dataset_size = (
6641                dist_sampler_added_samples_tiny.num_samples,
6642                dist_sampler_added_samples_tiny.total_size,
6643            )
6644            self.assertEqual(
6645                local_num_samples, math.ceil(dataset_tiny_size / world_size)
6646            )
6647            self.assertEqual(local_dataset_size, local_num_samples * world_size)
6648            indices_list = list(iter(dist_sampler_added_samples_tiny))
6649            self.assertEqual(len(indices_list), local_num_samples)
6650            validate_global_samples(local_num_samples)
6651
6652        def _test_allgather_object(self, subgroup=None):
6653            # Only set device for NCCL backend since it must use GPUs.
6654
6655            gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
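            # all_gather_object pickles each object into a byte tensor under the
            # hood, which is why arbitrary picklable objects (including ones that
            # hold tensors, like Foo below) can be gathered.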
6656
6657            backend = os.environ["BACKEND"]
6658            if backend == "nccl":
6659                # Case where rank != GPU device.
6660                next_rank = (self.rank + 1) % int(self.world_size)
6661                torch.cuda.set_device(next_rank)
6662
6663            # If GPU test, add object with GPU tensor
6664            if backend == "nccl":
6665                gather_objects.append(Foo(torch.randn(3, 3, device=0)))
6666
6667            output_gathered = [None for _ in range(dist.get_world_size())]
6668            dist.all_gather_object(
6669                output_gathered,
6670                gather_objects[self.rank % len(gather_objects)],
6671                group=subgroup,
6672            )
6673
6674            for i, val in enumerate(output_gathered):
6675                expected = gather_objects[i % len(gather_objects)]
6676                self.assertEqual(val, expected)
6677
6678        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
6679        @require_n_gpus_for_nccl_backend(
6680            int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
6681        )
6682        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
6683        def test_all_gather_object_default_pg(self):
6684            return self._test_allgather_object()
6685
6686        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
6687        @require_n_gpus_for_nccl_backend(
6688            int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
6689        )
6690        @with_dist_debug_levels(levels=["DETAIL", "OFF", "INFO"])
6691        def test_all_gather_object_subgroup(self):
6692            default = _get_default_group()
6693            backend = dist.get_backend(default)
6694            subgroup = dist.new_group(backend=backend)
6695            return self._test_allgather_object(subgroup=subgroup)
6696
6697        def _test_gather_object(self, pg=None):
6698            # Ensure stateful objects can be gathered
6699            gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
6700            my_rank = dist.get_rank(pg)
6701
6702            backend = os.environ["BACKEND"]
6703            if backend == "nccl":
6704                # Case where rank != GPU device.
6705                next_rank = (self.rank + 1) % int(self.world_size)
6706                torch.cuda.set_device(next_rank)
6707
6708            # If GPU test, add object with GPU tensor
6709            if backend == "nccl":
6710                gather_objects.append(Foo(torch.randn(3, 3, device=my_rank)))
6711
6712            output_gathered = [None for _ in range(dist.get_world_size(pg))]
6713            gather_on_rank = 0
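            # gather_object only populates object_gather_list on the destination
            # rank; every other rank is expected to pass None for it.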
6714            dist.gather_object(
6715                gather_objects[self.rank % len(gather_objects)],
6716                object_gather_list=output_gathered
6717                if my_rank == gather_on_rank
6718                else None,
6719                dst=gather_on_rank,
6720                group=pg,
6721            )
6722            if my_rank != gather_on_rank:
6723                self.assertEqual(
6724                    output_gathered, [None for _ in range(dist.get_world_size(pg))]
6725                )
6726            else:
6727                for i, val in enumerate(output_gathered):
6728                    expected = gather_objects[i % len(gather_objects)]
6729                    self.assertEqual(val, expected)
6730
6731            # Validate errors when objects can't be pickled.
6732            class Bar:
6733                pass
6734
6735            b = Bar()
6736            gather_objects = [b for _ in range(dist.get_world_size())]
6737            with self.assertRaisesRegex(AttributeError, "Can't pickle local object"):
6738                dist.all_gather_object(
6739                    [None for _ in range(dist.get_world_size())],
6740                    gather_objects[self.rank],
6741                    group=pg,
6742                )
6743
6744        @skip_but_pass_in_sandcastle_if(
6745            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
6746        )
6747        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
6748        @with_dist_debug_levels(levels=["DETAIL", "OFF", "INFO"])
6749        def test_gather_object(self):
6750            return self._test_gather_object()
6751
6752        @skip_but_pass_in_sandcastle_if(
6753            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
6754        )
6755        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
6756        @with_dist_debug_levels(levels=["DETAIL", "OFF", "INFO"])
6757        def test_gather_object_subgroup(self):
6758            default = _get_default_group()
6759            backend = dist.get_backend(default)
6760            subgroup = dist.new_group(backend=backend)
6761            return self._test_gather_object(subgroup)
6762
6763        def validate_net_equivalence(self, net):
6764            # Helper to validate synchronization of nets across ranks.
6765            net_module_states = list(net.module.state_dict().values())
6766            # Check that all tensors in module's state_dict() are equal.
6767            for t in net_module_states:
6768                tensor_list = [
6769                    torch.zeros_like(t) for _ in range(dist.get_world_size())
6770                ]
6771                dist.all_gather(tensor_list, t)
6772                for tensor in tensor_list:
6773                    self.assertEqual(tensor, t)
6774
6775        @skip_if_lt_x_gpu(2)
6776        @skip_but_pass_in_sandcastle_if(
6777            BACKEND not in DistTestCases.backend_feature["ddp"],
6778            f"The {BACKEND} backend does not support DistributedDataParallel",
6779        )
6780        def test_ddp_sync_module_states(self):
6781            # Test that after calling _sync_module_states, models across ranks
6782            # are the same and are equal to the model on the input rank.
6783            dim = 2
6784            rank = self.rank
6785            rank_to_broadcast = 1
6786            # Seed to ensure that ranks are initialized with different initial models.
6787            torch.manual_seed(rank)
6788            model = nn.Linear(dim, dim, bias=False)
6789            net = torch.nn.parallel.DistributedDataParallel(
6790                model.cuda(rank), device_ids=[self.rank], bucket_cap_mb=1
6791            )
6792            new_model = nn.Linear(dim, dim, bias=False).cuda(rank)
6793            net.module = copy.deepcopy(new_model)
6794            # Assert params are different
6795            net_module_states = list(net.module.state_dict().values())
6796            for t in net_module_states:
6797                tensor_list = [
6798                    torch.zeros_like(t) for _ in range(dist.get_world_size())
6799                ]
6800                dist.all_gather(tensor_list, t)
6801                for i, tensor in enumerate(tensor_list):
6802                    if i == rank:
6803                        self.assertEqual(t, tensor)
6804                    else:
6805                        # tensor from another rank should be different.
6806                        self.assertNotEqual(t, tensor)
6807
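            # _sync_module_states broadcasts parameters and buffers from
            # rank_to_broadcast to all other ranks in coalesced broadcast buckets.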
6808            _sync_module_states(
6809                module=net.module,
6810                process_group=net.process_group,
6811                broadcast_bucket_size=net.broadcast_bucket_size,
6812                src=rank_to_broadcast,
6813                params_and_buffers_to_ignore=net.parameters_to_ignore,
6814            )
6815            # Now all model params should be the same.
6816            self.validate_net_equivalence(net)
6817            # Since the network params were broadcast from rank_to_broadcast, validate that
6818            # they are the same as new_model on rank_to_broadcast.
6819            if rank == rank_to_broadcast:
6820                expected_states = new_model.state_dict().values()
6821                for t, expected in zip(net_module_states, expected_states):
6822                    self.assertEqual(t, expected)
6823
6824        @skip_if_lt_x_gpu(2)
6825        @skip_but_pass_in_sandcastle_if(
6826            BACKEND not in DistTestCases.backend_feature["ddp"],
6827            f"The {BACKEND} backend does not support DistributedDataParallel",
6828        )
6829        def test_ddp_grad_div_uneven_inputs(self):
6830            # Test gradient division during training with join() API. If
6831            # divide_by_initial_world_size=False, we scale by the effective world
6832            # size when allreducing grads.
6833            dim = 5
6834            batch = 1
6835            grad_scale = 50
6836            rank = self.rank
6837            model = nn.Linear(dim, dim, bias=False)
6838            inp = torch.ones(batch, dim, device=self.rank) * grad_scale
6839            net = torch.nn.parallel.DistributedDataParallel(
6840                model.cuda(rank), device_ids=[self.rank], bucket_cap_mb=1
6841            )
6842            n_iters = 3
6843            if self.rank > 0:
6844                n_iters += 2
6845
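            # With loss = sum(linear(x)) and x = ones * grad_scale, each active
            # rank's local gradient is a (dim, dim) matrix filled with grad_scale,
            # which is what the expected-gradient checks below rely on.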
6846            with net.join(divide_by_initial_world_size=False):
6847                for _ in range(n_iters):
6848                    loss = net(inp).sum()
6849                    loss.backward()
6850                    # The grad is always expected_grad, since we divide by the number
6851                    # of currently active processes and inactive processes contribute
6852                    # zero gradient. If we kept dividing by static initial world
6853                    # size as processes leave, the grad would be smaller.
6854                    expected_grad = torch.ones(dim, dim, device=self.rank) * grad_scale
6855                    param = next(iter(net.parameters()))
6856                    self.assertEqual(expected_grad, param.grad)
6857                    # Avoid accumulating grads so that it's the same every iteration
6858                    net.zero_grad()
6859                    torch.cuda.synchronize(device=self.rank)
6860
6861            # If divide_by_initial_world_size=True (default), we always scale grads
6862            # by the initial world_size.
6863            with net.join(divide_by_initial_world_size=True):
6864                for i in range(n_iters):
6865                    loss = net(inp).sum()
6866                    loss.backward()
6867                    effective_ws = dist.get_world_size()
6868                    if i >= 3:
6869                        effective_ws -= 1
6870                    expected_grad = (
6871                        torch.ones(dim, dim, device=self.rank)
6872                        * grad_scale
6873                        * effective_ws
6874                    ) / dist.get_world_size()
6875                    param = next(iter(net.parameters()))
6876                    self.assertEqual(expected_grad, param.grad)
6877                    # Avoid accumulating grad so that it's the same every iteration.
6878                    net.zero_grad()
6879                    torch.cuda.synchronize(device=self.rank)
6880
6881        def _test_ddp_profiling(self, profiler_ctx, profiler_ctx2=None):
6882            """Runs DDP-based model training and captures profiles.
6883            This test does two profiler runs:
6884            1. An initial basic run to check that profiler events are correctly captured.
6885            2. A second profiling pass after running some iterations of DDP, to check robustness of thread-local state.
6886
6887            Args:
6888                profiler_ctx : Profiler context manager for pass 1.
6889                profiler_ctx2 : Profiler context manager for pass 2.
6890                    This can be left as None, in which case a deepcopy
6891                    of profiler_ctx is used.
6892            Returns:
6893                prof: Instantiated profiler object that can be used for post-analysis.
6894            """
6895            batch = 3
6896            dim = 10
6897            num_iters = 6
6898            torch.cuda.set_device(self.rank)
6899            model = nn.Linear(dim, dim, bias=False)
6900            inp = torch.rand(batch, dim, device=self.rank)
6901            net = torch.nn.parallel.DistributedDataParallel(
6902                model.cuda(self.rank),
6903                device_ids=[self.rank],
6904            )
6905            if profiler_ctx2 is None:
6906                profiler_ctx2 = copy.deepcopy(profiler_ctx)
6907
6908            with profiler_ctx as prof:
6909                for i in range(num_iters):
6910                    loss = net(inp).sum()
6911                    loss.backward()
6912
6913            all_reduce_event_name = f"{dist.get_backend()}:all_reduce"
6914            events = get_profiling_event(all_reduce_event_name, prof, dedup_gpu_user_annotation=True)
6915            event_count = sum(e.count for e in events)
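            # The single-parameter Linear model produces one gradient bucket, so we
            # expect exactly one all_reduce event per iteration.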
6916            self.assertEqual(event_count, num_iters)
6917            for event in events:
6918                self.assertTrue(event.is_async)
6919                self.assertEqual(event.name, all_reduce_event_name)
6920
6921            broadcast_event_name = f"{dist.get_backend()}:broadcast"
6922            broadcast_events = get_profiling_event(broadcast_event_name, prof, dedup_gpu_user_annotation=True)
6923            event_count = sum(e.count for e in broadcast_events)
6924            # Broadcast is called during rebuild_buckets
6925            self.assertGreaterEqual(event_count, 1)
6926            for event in broadcast_events:
6927                self.assertEqual(event.name, broadcast_event_name)
6928
6929            # Run DDP without profiling for a few iterations, then enable the profiler
6930            # for a single pass and ensure events are recorded. This tests that the
6931            # thread-local state is correctly updated.
6932            net = torch.nn.parallel.DistributedDataParallel(
6933                model.cuda(self.rank),
6934                device_ids=[self.rank],
6935                find_unused_parameters=True,
6936            )
6937            for i in range(3):
6938                loss = net(inp).sum()
6939                loss.backward()
6940            # Now enable the profiler.
6941            with profiler_ctx2 as prof:
6942                loss = net(inp).sum()
6943                loss.backward()
6944
6945            events = get_profiling_event(all_reduce_event_name, prof, dedup_gpu_user_annotation=True)
6946            self.assertGreaterEqual(len(events), 1)
6947            self.assertGreaterEqual(events[0].count, 1)
6948            self.assertEqual(events[0].name, all_reduce_event_name)
6949            for event in events:
6950                self.assertTrue(event.is_async)
6951            # Ensure searching unused parameters was profiled
6952            events = get_profiling_event("search_unused_parameters", prof)
6953            self.assertEqual(len(events), 1)
6954
6955            return prof
6956
6957        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
6958        @skip_if_lt_x_gpu(2)
6959        @skip_but_pass_in_sandcastle("Currently failing in NVIDIA internal CI")
6960        def test_ddp_profiling_autograd_profiler(self):
6961            autograd_profiler_ctx = torch.autograd.profiler.profile()
6962            return self._test_ddp_profiling(profiler_ctx=autograd_profiler_ctx)
6963
6964        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
6965        @skip_if_lt_x_gpu(2)
6966        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
6967        @skip_but_pass_in_sandcastle_if(
6968            IS_MACOS or IS_WINDOWS,
6969            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
6970        )
6971        def test_ddp_profiling_torch_profiler(self):
6972            cpu_act = torch.profiler.ProfilerActivity.CPU
6973            cuda_act = torch.profiler.ProfilerActivity.CUDA
6974            torch_profiler_ctx = torch.profiler.profile(activities=[cpu_act, cuda_act])
6975            prof = self._test_ddp_profiling(profiler_ctx=torch_profiler_ctx)
6976
6977            if dist.get_backend() != "nccl":
6978                return
6979
6980            # Note: comment out the "os.remove(trace_file)" in `get_profiler_nccl_meta()`
6981            # to debug any mismatches.
6982            nccl_meta_events = get_profiler_nccl_meta(prof)
6983            self.assertGreater(len(nccl_meta_events), 0)
6984
6985            nccl_meta = self._sanity_check_profiler_nccl_meta(nccl_meta_events)
6986
6987            # additionally check the specific collectives in this test case
6988            self.assertEqual(len(nccl_meta["allreduce"]), 2)
6989            self.assertEqual(len(nccl_meta["wait"]), 1)
6990
6991            # check allreduce message sizes
6992            a0 = nccl_meta["allreduce"][0]
6993            self.assertEqual(a0["Out msg nelems"], 100, msg=f"{a0}")
6994            self.assertEqual(a0["dtype"], "Float", msg=f"{a0}")
6995            a1 = nccl_meta["allreduce"][1]
6996            self.assertEqual(a1["Out msg nelems"], 1, msg=f"{a1}")
6997            self.assertEqual(a1["dtype"], "Int", msg=f"{a1}")
6998
6999        def _validate_execution_trace_nccl(self, et_file: str) -> None:
7000            """Torch profiler includes NCCL metadata in an inserted operator called "record_param_comms".
7001            We test for basic fields in these nodes in the Execution Trace.
7002            """
7003            with open(et_file) as f:
7004                et = json.load(f)
7005            pg_cfg_node = [n for n in et["nodes"] if n["name"] == "## process_group:init ##"]
7006            self.assertGreaterEqual(len(pg_cfg_node), 1)
7007            nccl_meta_nodes = [n for n in et["nodes"] if n["name"] == "record_param_comms"]
7008            self.assertEqual(len(nccl_meta_nodes), 3)
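            # Three comms nodes are expected from the single profiled iteration: the
            # Float gradient-bucket allreduce, an Int allreduce (presumably the
            # unused-parameter map, since find_unused_parameters=True), and a wait.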
7009            per_coll_meta = defaultdict(list)
7010
7011            # Sanity check NCCL metadata nodes
7012            for n in nccl_meta_nodes:
7013                attrs_list = n.get("attrs", [])
7014                self.assertGreater(len(attrs_list), 0)
7015                attrs = {a["name"]: a["value"] for a in attrs_list}
7016
7017                collname = attrs.get("collective_name", "")
7018                self.assertNotEqual(collname, "")
7019                self.assertNotEqual(attrs.get("dtype", ""), "")
7020
7021                per_coll_meta[collname].append(attrs)
7022                if collname in {"wait"}:
7023                    continue
7024
7025                self.assertEqual(attrs["pg_name"], "0")   # yes this is a string
7026                self.assertEqual(attrs["pg_desc"], "default_pg")
7027                self.assertEqual(attrs["pg_size"], 2)
7028
7029                self.assertGreaterEqual(attrs.get("in_msg_nelems", -1), 0)
7030                self.assertGreaterEqual(attrs.get("out_msg_nelems", -1), 0)
7031                self.assertTrue("in_split_size" in attrs.keys())
7032                self.assertTrue("out_split_size" in attrs.keys())
7033                self.assertEqual(attrs.get("global_rank_start", -1), 0)
7034                self.assertEqual(attrs.get("global_rank_stride", -1), 1)
7035
7036            # print(per_coll_meta)
7037            self.assertEqual(len(per_coll_meta["allreduce"]), 2)
7038            self.assertEqual(len(per_coll_meta["wait"]), 1)
7039
7040            # check allreduce message sizes
7041            a0 = per_coll_meta["allreduce"][0]
7042            self.assertEqual(a0["out_msg_nelems"], 100, msg=f"{a0}")
7043            self.assertEqual(a0["dtype"], "Float", msg=f"{a0}")
7044            a1 = per_coll_meta["allreduce"][1]
7045            self.assertEqual(a1["out_msg_nelems"], 1, msg=f"{a1}")
7046            self.assertEqual(a1["dtype"], "Int", msg=f"{a1}")
7047
7048
7049        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
7050        @skip_if_lt_x_gpu(2)
7051        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
7052        @skip_but_pass_in_sandcastle_if(
7053            IS_MACOS or IS_WINDOWS,
7054            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
7055        )
7056        @unittest.skipIf(BACKEND != "nccl", "Tests nccl metadata primarily.")
7057        def test_ddp_profiling_execution_trace(self):
7058            self.assertEqual(dist.get_backend(), "nccl")
7059            # Create a temp file to save execution trace data
7060            fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
7061            fp.close()
7062            et_file = fp.name
7063            et = ExecutionTraceObserver().register_callback(et_file)
7064
7065            # first profiler context need not have ET
7066            torch_profiler_ctx1 = torch.profiler.profile(
7067                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
7068            )
7069            # collect ET in second profiler pass
7070            torch_profiler_ctx2 = torch.profiler.profile(
7071                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
7072                execution_trace_observer=et
7073            )
7074            prof = self._test_ddp_profiling(
7075                profiler_ctx=torch_profiler_ctx1,
7076                profiler_ctx2=torch_profiler_ctx2,
7077            )
7078
7079            print(f"Execution trace saved at {fp.name}")
7080            self._validate_execution_trace_nccl(et_file)
7081
7082
7083        @skip_if_lt_x_gpu(2)
7084        @skip_but_pass_in_sandcastle_if(
7085            BACKEND not in DistTestCases.backend_feature["ddp"],
7086            f"The {BACKEND} backend does not support DistributedDataParallel",
7087        )
7088        def test_ddp_join_model_equivalence(self):
7089            # Verifies equivalence with model training locally and with DDP under
7090            # the join context manager.
7091            batch = 3
7092            dim = 10
7093            learning_rate = 0.03
7094            model = nn.Linear(dim, dim, bias=False)
7095            inp = torch.rand(batch, dim, device=self.rank)
7096            local_model = copy.deepcopy(model)
7097            local_model = local_model.cuda(self.rank)
7098            rank_to_iter_mapping = {
7099                rank: 2 * (rank + 1) for rank in range(dist.get_world_size())
7100            }
7101            # run local model
7102            local_iters = sum(rank_to_iter_mapping.values())
7103            local_optim = torch.optim.SGD(local_model.parameters(), lr=learning_rate)
7104            for _ in range(local_iters):
7105                local_optim.zero_grad()
7106                out = local_model(inp)
7107                loss = out.sum()
7108                loss.backward()
7109                local_optim.step()
7110
7111            # run DDP model with join API
7112            num_iters = rank_to_iter_mapping[self.rank]
7113            net = torch.nn.parallel.DistributedDataParallel(
7114                model.cuda(self.rank), device_ids=[self.rank]
7115            )
7116            ddp_optim = torch.optim.SGD(
7117                model.parameters(), lr=learning_rate * dist.get_world_size()
7118            )
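            # net.join() lets ranks that run out of inputs early shadow the
            # collective calls issued by ranks still training, so DDP does not hang
            # on uneven input counts.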
7119            with net.join():
7120                for i in range(num_iters):
7121                    ddp_optim.zero_grad()
7122                    out = net(inp)
7123                    loss = out.sum()
7124                    loss.backward()
7125                    torch.cuda.synchronize(device=self.rank)
7126                    ddp_optim.step()
7127
7128            # Validate model state dicts are equal
7129            for (_, local_tensor), (_, dist_tensor) in zip(
7130                local_model.state_dict().items(), net.module.state_dict().items()
7131            ):
7132                self.assertEqual(local_tensor, dist_tensor)
7133
7134        def _run_uneven_inputs_test(
7135            self,
7136            test_case,
7137            iteration_mapping,
7138            find_unused_params,
7139        ):
7140            model = test_case.model
7141            inp = test_case.inp
7142            rank = self.rank
7143            sync_interval = test_case.sync_interval
7144            torch.cuda.set_device(rank)
7145            # Ensure all outstanding GPU work is completed so this test runs independently.
7146            dist.barrier()
7147            # Bucket_cap_mb is intentionally low to test allreduce scheduling when
7148            # there are many buckets.
7149            net = torch.nn.parallel.DistributedDataParallel(
7150                model.cuda(rank),
7151                device_ids=[rank],
7152                bucket_cap_mb=1,
7153                find_unused_parameters=find_unused_params,
7154            )
7155            # Register hook if specified
7156            if test_case.hook is not None:
7157                net.register_comm_hook(test_case.state, test_case.hook)
7158                print(f"registered hook {test_case.hook}")
7159
7160            # Determine num iters for this rank via the passed in mapping.
7161            num_iters = iteration_mapping[rank]
7162            # If we throw when the earliest rank exhausts its inputs, we should
7163            # ensure that we iterate at least that minimum number of times.
7164            num_iters_tensor = torch.tensor(
7165                [num_iters], device=torch.cuda.current_device()
7166            )
7167            dist.all_reduce(num_iters_tensor, op=dist.ReduceOp.MIN)
7168            min_num_iters = num_iters_tensor.item()
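            # min_num_iters is the iteration count of the earliest-exhausting rank;
            # every rank is guaranteed to complete at least this many iterations.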
7169            total_iters = 0
7170            if test_case.throw_on_early_termination:
7171                if min_num_iters == num_iters:
7172                    # Early termination rank(s)
7173                    exception_ctx = self.assertRaisesRegex(
7174                        RuntimeError, f"Rank {self.rank} exhausted all inputs"
7175                    )
7176                else:
7177                    # Non early termination rank
7178                    exception_ctx = self.assertRaisesRegex(
7179                        RuntimeError,
7180                        "Detected at least one rank that exhausted inputs.",
7181                    )
7182            else:
7183                exception_ctx = nullcontext()
7184            with exception_ctx:
7185                with net.join(
7186                    throw_on_early_termination=test_case.throw_on_early_termination
7187                ):
7188                    for i in range(num_iters):
7189                        # Use net.no_sync() to disable grad synchronization except
7190                        # every sync_interval iterations.
7191                        if i % sync_interval != 0:
7192                            context = net.no_sync()
7193                        else:
7194                            context = nullcontext()
7195                        with context:
7196                            if isinstance(inp, tuple):
7197                                loss = net(*inp).sum()
7198                            else:
7199                                loss = net(inp).sum()
7200                            loss.backward()
7201                            self._model_step(net)
7202                            # Ensure completion of GPU kernels (including allreduce). If the
7203                            # join API is not properly implemented, then this should hang
7204                            # since the allreduce will hang.
7205                            torch.cuda.synchronize(device=rank)
7206                        total_iters += 1
7207            if test_case.throw_on_early_termination:
7208                # Ensure we iterated min_num_iters times.
7209                self.assertEqual(total_iters, min_num_iters)
7210            else:
7211                # Ensure we iterated at least min_num_iters times.
7212                self.assertGreaterEqual(total_iters, min_num_iters)
7213
7214            # Ensure completion of all GPU kernels.
7215            torch.cuda.synchronize(device=rank)
7216            # When throwing on early rank termination, we do not
7217            # broadcast model state from an authoritative rank. All models
7218            # should already be in sync.
7219            if not test_case.throw_on_early_termination:
7220                self.assertTrue(net._authoritative_rank)
7221                # All ranks should have agreed on the same authoritative_rank!
7222                final_rank_tensor = torch.tensor(
7223                    [net._authoritative_rank], device=self.rank
7224                )
7225                tensor_list = [
7226                    torch.zeros_like(final_rank_tensor)
7227                    for _ in range(dist.get_world_size())
7228                ]
7229                dist.all_gather(tensor_list, final_rank_tensor)
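                # In the iteration mappings used by this test the highest rank always
                # processes the most inputs, so it should be elected as the
                # authoritative rank whose final model state all ranks agree on.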
7230                max_rank = dist.get_world_size() - 1
7231                self.assertSetEqual(
7232                    {max_rank}, {tensor.item() for tensor in tensor_list}
7233                )
7234                # Ensure that all models are the same across ranks after all have joined.
7235                self.validate_net_equivalence(net)
7236                # Ensure that running with DDP uneven inputs was logged.
7237                ddp_logging_data = net._get_ddp_logging_data()
7238                self.assertTrue(ddp_logging_data.get("join_uneven_inputs"))
7239                dist.barrier()
7240
7241        @skip_if_lt_x_gpu(2)
7242        @skip_but_pass_in_sandcastle_if(
7243            BACKEND not in DistTestCases.backend_feature["ddp"],
7244            f"The {BACKEND} backend does not support DistributedDataParallel",
7245        )
7246        def test_ddp_uneven_inputs_stop_iteration_sync_bn(self):
7247            # Tests that uneven inputs join handler correctly throws StopIteration
7248            # for models with SyncBN or general collective comm when
7249            # throw_on_early_termination=True.
7250            class ModelWithComm(torch.nn.Module):
7251                def __init__(self) -> None:
7252                    super().__init__()
7253                    self.lin = nn.Linear(2, 40, bias=False)
7254
7255                def forward(self, x):
7256                    x = self.lin(x)
7257                    dist.all_reduce(x)
7258                    return x
7259
7260            torch.cuda.set_device(self.rank)
7261            model_bn = BN_NET
7262            model_bn = nn.SyncBatchNorm.convert_sync_batchnorm(
7263                copy.deepcopy(model_bn)
7264            ).cuda(self.rank)
7265            comm_model = ModelWithComm().cuda(self.rank)
7266            model_input = torch.randn(10, 2).cuda(torch.cuda.current_device())
7267
7268            for model in [model_bn, comm_model]:
7269                model = torch.nn.parallel.DistributedDataParallel(
7270                    model,
7271                    device_ids=[self.rank],
7272                )
7273                min_num_iters = 5
7274                if self.rank != 0:
7275                    # Early termination rank(s)
7276                    num_iters = min_num_iters
7277                    exception_ctx = self.assertRaisesRegex(
7278                        RuntimeError, f"Rank {self.rank} exhausted all inputs"
7279                    )
7280                else:
7281                    # Non early termination rank
7282                    num_iters = min_num_iters * 2
7283                    exception_ctx = self.assertRaisesRegex(
7284                        RuntimeError,
7285                        "Detected at least one rank that exhausted inputs.",
7286                    )
7287                n = 0
7288                with exception_ctx:
7289                    with model.join(throw_on_early_termination=True):
7290                        for i in range(num_iters):
7291                            loss = model(model_input).sum()
7292                            loss.backward()
7293                            self._model_step(model)
7294                            n += 1
7295
7296                self.assertEqual(n, min_num_iters)
7297                # Verify model equivalence
7298                self.validate_net_equivalence(model)
7299
7300        @skip_if_lt_x_gpu(2)
7301        @skip_but_pass_in_sandcastle_if(
7302            BACKEND not in DistTestCases.backend_feature["ddp"],
7303            f"The {BACKEND} backend does not support DistributedDataParallel",
7304        )
7305        def test_ddp_uneven_inputs(self):
7306            dim = 1000
7307            batch = 1
7308            # Create a variety of models to run uneven input tests on.
7309            large_model = nn.Sequential(
7310                nn.Conv2d(1, 20, 5),
7311                nn.ReLU(),
7312                nn.Conv2d(20, 32, 5),
7313                nn.ReLU(),
7314                nn.Conv2d(32, 256, 5),
7315                nn.ReLU(),
7316            )
7317            small_model = nn.Linear(dim, dim, bias=False)
7318            bn_net = BatchNormNet()
7319
7320            class UnusedParamModule(nn.Module):
7321                def __init__(self, unused_params_rank):
7322                    super().__init__()
7323                    self.t0 = Task()
7324                    self.t1 = Task()
7325                    self.unused_params_rank = unused_params_rank
7326
7327                def task_parameters(self):
7328                    return (self.t0.p, self.t1.p)
7329
7330                def forward(self, x, rank):
7331                    return (
7332                        self.t1(self.t0(x))
7333                        if rank != self.unused_params_rank
7334                        else self.t1(x)
7335                    )
7336
7337            unjoined_rank_with_unused_params_model = UnusedParamModule(1)
7338            joined_rank_with_unused_params_model = UnusedParamModule(0)
7339
7340            rank = self.rank
7341            models_to_test = [
7342                # Network with batchnorm
7343                DDPUnevenTestInput(
7344                    name="batch_norm_net",
7345                    model=bn_net,
7346                    inp=torch.ones(batch, 2, device=rank),
7347                    sync_interval=1,
7348                ),
7349                DDPUnevenTestInput(
7350                    name="large_conv_model",
7351                    model=large_model,
7352                    inp=torch.ones(batch, batch, dim, dim, device=rank),
7353                    sync_interval=1,
7354                ),
7355                DDPUnevenTestInput(
7356                    name="small_model",
7357                    model=small_model,
7358                    inp=torch.ones(batch, dim, device=rank),
7359                    sync_interval=1,
7360                ),
7361                # Unused parameter test where rank that does not join early has unused params
7362                DDPUnevenTestInput(
7363                    name="unjoined_rank_with_unused_params_model",
7364                    model=unjoined_rank_with_unused_params_model,
7365                    inp=(torch.ones(batch, 2, device=rank), rank),
7366                    sync_interval=1,
7367                ),
7368                # Unused parameter test where rank that does join early has unused params
7369                DDPUnevenTestInput(
7370                    name="joined_rank_with_unused_params_model",
7371                    model=joined_rank_with_unused_params_model,
7372                    inp=(torch.ones(batch, 2, device=rank), rank),
7373                    sync_interval=1,
7374                ),
7375            ]
7376
7377            # Test models that have hook installed.
7378            models_with_hook = [
7379                DDPUnevenTestInput(
7380                    name="small_model_allreduce_hook",
7381                    model=small_model,
7382                    hook=default.allreduce_hook,
7383                    state=None,
7384                    inp=torch.ones(batch, dim, device=rank),
7385                    sync_interval=1,
7386                ),
7387                DDPUnevenTestInput(
7388                    name="small_model_power_sgd_hook",
7389                    model=small_model,
7390                    hook=powerSGD.powerSGD_hook,
7391                    state=powerSGD.PowerSGDState(
7392                        process_group=None,
7393                        matrix_approximation_rank=1,
7394                        # Config so that powerSGD runs immediately instead of
7395                        # allreduce.
7396                        start_powerSGD_iter=1,
7397                        warm_start=False,
7398                        use_error_feedback=False,
7399                    ),
7400                    inp=torch.ones(batch, dim, device=rank),
7401                    sync_interval=1,
7402                ),
7403            ]
7404            models_to_test.extend(models_with_hook)
7405
7406            # Add resnet model if we have torchvision installed.
7407            if HAS_TORCHVISION:
7408                resnet_model = torchvision.models.resnet50()
7409                models_to_test.append(
7410                    DDPUnevenTestInput(
7411                        name="resnet_model",
7412                        model=resnet_model,
7413                        inp=torch.ones(1, 3, 1000, 1000),
7414                        sync_interval=1,
7415                    )
7416                )
7417
7418            # Test with no_sync every 2, 3, 4, ... iterations.
7419            models_with_sync = []
7420            for i, test_input in enumerate(models_to_test):
7421                models_with_sync.append(
7422                    DDPUnevenTestInput(
7423                        name=test_input.name,
7424                        model=test_input.model,
7425                        inp=test_input.inp,
7426                        sync_interval=i + 2,
7427                    )
7428                )
7429
7430            throw_on_early_term_tests = []
7431            for test_input in models_to_test:
7432                throw_on_early_term_tests.append(
7433                    DDPUnevenTestInput(
7434                        name=test_input.name,
7435                        model=test_input.model,
7436                        inp=test_input.inp,
7437                        sync_interval=test_input.sync_interval,
7438                        throw_on_early_termination=True,
7439                    )
7440                )
7441
7442            models_to_test.extend(models_with_sync)
7443            models_to_test.extend(throw_on_early_term_tests)
7444
7445            # Zero-iteration tests cover the case where one process does not train the
7446            # model at all, so we must shadow the broadcast calls made when rebuilding buckets.
7447            baseline_num_iters = [0, 5]
7448            iteration_offsets = [2, 3, 10]
7449            num_uneven_ranks = [1]
7450            if dist.get_world_size() > 2:
7451                num_uneven_ranks.append(2)
7452            iteration_mappings = []
7453            # Generate rank : num_iters mappings for various uneven input scenarios.
7454            # This includes cases where rank 0 joins early and all other ranks join
7455            # later, and scenarios where multiple ranks join early, but at different
7456            # iterations, and later ranks join later.
7457            for num_early_join_ranks in num_uneven_ranks:
7458                for baseline_iter in baseline_num_iters:
7459                    for offset in iteration_offsets:
7460                        mapping = dict.fromkeys(range(0, num_early_join_ranks), baseline_iter)
7461                        # if num_early_join_ranks > 1, ranks > 0 that will join early
7462                        # iterate offset//2 more times than rank 0, to test nodes
7463                        # depleting inputs at different times.
7464                        if num_early_join_ranks > 1:
7465                            for rank in mapping.keys():
7466                                if rank > 0:
7467                                    mapping[rank] += offset // 2
7468                        mapping.update(
7469                            dict.fromkeys(range(num_early_join_ranks, dist.get_world_size()), baseline_iter + offset)
7470                        )
7471                        iteration_mappings.append(mapping)
7472
7473            for (test_case, iteration_mapping) in itertools.product(
7474                models_to_test, iteration_mappings
7475            ):
7476                if self.rank == 0:
7477                    print(
7478                        f"""Running test: {test_case.name} sync interval
7479                        {test_case.sync_interval} with iteration mapping
7480                        {iteration_mapping}"""
7481                    )
7482                self._run_uneven_inputs_test(
7483                    test_case,
7484                    iteration_mapping,
7485                    find_unused_params=("unused_params_model" in test_case.name),
7486                )
7487
7488        @skip_if_lt_x_gpu(2)
7489        @skip_but_pass_in_sandcastle_if(
7490            BACKEND not in DistTestCases.backend_feature["ddp"],
7491            f"The {BACKEND} backend does not support DistributedDataParallel",
7492        )
7493        def test_ddp_uneven_input_join_disable(self):
7494            # Tests that if net.join(enable=False) is specified, DDP works as
7495            # expected with even inputs.
7496            torch.manual_seed(self.rank)
7497            net = torch.nn.parallel.DistributedDataParallel(
7498                torch.nn.Linear(1, 1).cuda(self.rank), device_ids=[self.rank]
7499            )
7500            inp = torch.ones(1) * self.rank
7501            n_iters = 5
7502            world_size = dist.get_world_size()
7503            with net.join(enable=False):
7504                for _ in range(n_iters):
7505                    # Clear grads
7506                    grad = net.module.weight.grad
7507                    if grad is not None:
7508                        grad.requires_grad_(False)
7509                        grad.zero_()
7510                    out = net(inp)
7511                    loss = out.sum()
7512                    loss.backward()
7513                    # Validate gradients to ensure that we divide by the correct
7514                    # world_size when join mode is disabled.
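                    # Each rank feeds input = rank, so the local gradient of the 1x1
                    # weight is rank, and the allreduce-averaged gradient is
                    # sum(range(world_size)) / world_size.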
7515                    expected_grad = sum(i for i in range(world_size)) / world_size
7516                    self.assertEqual(net.module.weight.grad.item(), expected_grad)
7517
7518            join_config = net._join_config
7519            self.assertFalse(join_config.enable)
7520            self.validate_net_equivalence(net)
7521
7522        @skip_if_lt_x_gpu(2)
7523        @skip_but_pass_in_sandcastle_if(
7524            BACKEND not in DistTestCases.backend_feature["ddp"],
7525            f"The {BACKEND} backend does not support DistributedDataParallel",
7526        )
7527        def test_ddp_uneven_input_exception(self):
7528            # Tests that exceptions during training are correctly propagated by the
7529            # context manager.
7530            error_str = "Intentional error"
7531
7532            class ExceptionModule(nn.Module):
7533                def __init__(self) -> None:
7534                    super().__init__()
7535                    self.param = nn.Parameter(torch.ones(1, requires_grad=True))
7536
7537                def forward(self, _):
7538                    raise ValueError(error_str)
7539
7540            exception_module = ExceptionModule()
7541            net = torch.nn.parallel.DistributedDataParallel(
7542                exception_module.cuda(self.rank), device_ids=[self.rank]
7543            )
7544            inp = torch.ones(1)
7545            with self.assertRaisesRegex(ValueError, error_str):
7546                with net.join():
7547                    out = net(inp)
7548                    loss = out.sum()
7549                    loss.backward()
7550
7551        def _test_broadcast_object_list(self, group=None):
7552            gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
7553
7554            # Only set device for NCCL backend since it must use GPUs.
7555            # Case where rank != GPU device.
7556            next_rank = (self.rank + 1) % int(self.world_size)
7557            backend = os.environ["BACKEND"]
7558            if backend == "nccl":
7559                torch.cuda.set_device(next_rank)
7560
7561            src_rank = 0
7562            # If GPU test, add object with GPU tensor
7563            if backend == "nccl":
7564                gather_objects.append(Foo(torch.randn(3, 3, device=0)))
7565
7566            if IS_FBCODE:
7567                # Create a Tensor with > 2^31 bytes of storage requirements.
7568                # Only on FBCODE, since this test OOMs in OSS.
7569                gather_objects.append(Foo(torch.randn(3, 178956971)))
7570            objects = (
7571                gather_objects
7572                if self.rank == src_rank
7573                else [None for _ in gather_objects]
7574            )
7575
7576            # Single object test with device specified. Backend="gloo", device=cpu
7577            if backend != "nccl":
7578                single_obj_list = [objects[0]]
7579                if self.rank != src_rank:
7580                    self.assertNotEqual(single_obj_list[0], gather_objects[0])
7581                dist.broadcast_object_list(
7582                    single_obj_list, src=0, group=group, device=torch.device("cpu")
7583                )
7584                self.assertEqual(single_obj_list[0], gather_objects[0])
7585
7586            # Single object test with device specified. Backend="gloo", device=current_device+1
7587            # The test is gated on the GPU count matching the world size, to avoid the
7588            # case where the backend is gloo but multiple GPU devices are not available.
7589            if backend != "nccl" and torch.cuda.device_count() == int(self.world_size):
7590                single_obj_list = [objects[0]]
7591                if self.rank != src_rank:
7592                    self.assertNotEqual(single_obj_list[0], gather_objects[0])
7593                dist.broadcast_object_list(
7594                    single_obj_list, src=0, group=group, device=torch.device(next_rank)
7595                )
7596                self.assertEqual(single_obj_list[0], gather_objects[0])
7597
7598            # Single object test with device specified. Backend="nccl", device=current_device+1
7599            if backend == "nccl" and torch.cuda.device_count() == int(self.world_size):
7600                single_obj_list = [objects[0]]
7601                if self.rank != src_rank:
7602                    self.assertNotEqual(single_obj_list[0], gather_objects[0])
7603                dist.broadcast_object_list(
7604                    single_obj_list, src=0, group=group, device=torch.device(next_rank)
7605                )
7606                self.assertEqual(single_obj_list[0], gather_objects[0])
7607
7608            # Single object test: backward compatibility with device unspecified
7609            single_obj_list = [objects[0]]
7610            if self.rank != src_rank:
7611                self.assertNotEqual(single_obj_list[0], gather_objects[0])
7612            dist.broadcast_object_list(single_obj_list, src=0, group=group)
7613            self.assertEqual(single_obj_list[0], gather_objects[0])
7614
7615            # Multiple input objects test
7616            if self.rank != src_rank:
7617                self.assertNotEqual(objects, gather_objects)
7618            dist.broadcast_object_list(objects, src=0, group=group)
7619            self.assertEqual(objects, gather_objects)
7620
7621        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
7622        @require_n_gpus_for_nccl_backend(
7623            int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
7624        )
7625        @with_dist_debug_levels(levels=["DETAIL"])
7626        @unittest.skip("Test is failing, see https://github.com/pytorch/pytorch/pull/113620")
7627        def test_broadcast_object_list(self):
7628            return self._test_broadcast_object_list()
7629
7630        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
7631        @require_n_gpus_for_nccl_backend(
7632            int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
7633        )
7634        @with_dist_debug_levels(levels=["DETAIL"])
7635        def _test_broadcast_object_list_subgroup(self):
7636            default = _get_default_group()
7637            backend = dist.get_backend(default)
7638            subgroup = dist.new_group(backend=backend)
7639            return self._test_broadcast_object_list(subgroup)
7640
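        # Illustrative sketch of the dist.broadcast_object_list API exercised above,
        # assuming the default process group is already initialized: the source rank
        # fills the list with picklable objects, every rank passes a list of the same
        # length, and after the call all ranks hold the source's objects. The payload
        # below is a hypothetical placeholder.
        def _example_broadcast_object_list(self, group=None):
            src_rank = 0
            objects = (
                ["foo", 123, {"k": "v"}]
                if self.rank == src_rank
                else [None, None, None]
            )
            dist.broadcast_object_list(objects, src=src_rank, group=group)
            # Every rank now sees ["foo", 123, {"k": "v"}].
            return objects
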
7641        def _test_ddp_ignore_params_arg(self, static_graph=False):
7642            class TestModel(nn.Module):
7643                def __init__(self, rank):
7644                    self.rank = rank
7645                    super().__init__()
7646                    self.fc1 = nn.Linear(1, 1, bias=False)
7647                    # Proxy that will be materialized to another architecture later.
7648                    # (after wrapping model with DDP)
7649                    if self.rank == 0:
7650                        self.fc2 = nn.Linear(1, 10, bias=False)
7651                    else:
7652                        self.fc2 = nn.Linear(10, 10, bias=False)
7653
7654                def forward(self, x):
7655                    x = self.fc1(x)
7656                    x = self.fc2(x)
7657                    return x
7658
7659            device_id = self.rank
7660            # Ensure the test works for both find_unused_parameters and broadcast_buffers settings.
7661            for (find_unused, broadcast_buffers) in itertools.product(
7662                [False, True], [False, True]
7663            ):
7664                model = TestModel(self.rank).float().to(device_id)
7665                # Note that the model can have buffers of different shapes across ranks
7666                # if we pass them in to be ignored as well.
7667                model.fc2.register_buffer(
7668                    "ignore_buffer", torch.zeros(5 + self.rank, device=self.rank)
7669                )
7670                proxy_params = list(model.fc2.parameters())
7671                proxy_buffers = list(model.fc2.buffers())
7672                model_fc2_name = next(
7673                    module_name
7674                    for module_name, module in model.named_modules()
7675                    if module is model.fc2
7676                )
7677                proxy_param_names = [
7678                    f"{model_fc2_name}.{param_name}"
7679                    for param_name, _ in model.fc2.named_parameters()
7680                ]
7681                proxy_buffer_names = [
7682                    f"{model_fc2_name}.{buf_name}"
7683                    for buf_name, _ in model.fc2.named_buffers()
7684                ]
7685                # Specify that we should ignore proxy_params since it will be
7686                # materialized later.
7687                torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
7688                    model, proxy_param_names + proxy_buffer_names
7689                )
7690                ddp = torch.nn.parallel.DistributedDataParallel(
7691                    model,
7692                    device_ids=[device_id],
7693                    find_unused_parameters=find_unused,
7694                    broadcast_buffers=broadcast_buffers,
7695                    static_graph=static_graph,
7696                )
7697                # Materialize new params. These are not registered in DDP and thus
7698                # don't have autograd hooks installed on them.
7699                ddp.module.fc2 = nn.Linear(1, 1, bias=False).to(device_id)
7700
7701                # local model with the new materialized parameters.
7702                local_model = copy.deepcopy(ddp.module).cuda(self.rank)
7703
7704                inp = torch.ones(1, dtype=torch.float).to(device_id) * (self.rank + 1)
7705                for i in range(6):
7706                    ddp(inp).sum().backward()
7707
7708                    local_model(inp).sum().backward()
7709                    # materialized param grad is not touched by DDP, so its grad should
7710                    # be the same as if running locally.
7711                    for materialized_param, local_param in zip(
7712                        ddp.module.fc2.parameters(), local_model.fc2.parameters()
7713                    ):
7714                        self.assertEqual(materialized_param.grad, local_param.grad)
7715
7716                    # fc1 parameter grad should still be different, due to allreduce.
7717                    for synced_param, local_param in zip(
7718                        ddp.module.fc1.parameters(), local_model.fc1.parameters()
7719                    ):
7720                        self.assertFalse(synced_param.grad == local_param.grad)
7721
7722                    # Proxy module grad should not be touched
7723                    for proxy_param in proxy_params:
7724                        self.assertTrue(proxy_param.grad is None)
7725
7726                # Synchronize since we run multiple iterations of this test, to
7727                # isolate any hangs caused by failures.
7728                torch.cuda.synchronize(device=self.rank)
7729
7730        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
7731        @skip_if_lt_x_gpu(2)
7732        def test_ddp_ignore_params_arg(self):
7733            self._test_ddp_ignore_params_arg(static_graph=False)
7734            self._test_ddp_ignore_params_arg(static_graph=True)
7735
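        # Illustrative sketch of the private
        # DistributedDataParallel._set_params_and_buffers_to_ignore_for_model hook used
        # in the test above: fully-qualified parameter/buffer names passed to it are
        # excluded from DDP's broadcast and gradient synchronization. `module` and the
        # "fc2.weight" name below are hypothetical placeholders.
        def _example_ignore_params(self, module):
            ignored_names = ["fc2.weight"]  # names as given by module.named_parameters()
            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
                module, ignored_names
            )
            return torch.nn.parallel.DistributedDataParallel(
                module.cuda(self.rank), device_ids=[self.rank]
            )
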
7736        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
7737        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
7738        @skip_if_lt_x_gpu(2)
7739        def test_ddp_unused_params_rebuild_buckets_exception(self):
7740            class ToyModel(nn.Module):
7741                def __init__(self) -> None:
7742                    super().__init__()
7743                    self.net1 = nn.Linear(10, 10, bias=False)
7744                    self.net2 = nn.Linear(10, 10, bias=False)
7745
7746                def forward(self, x):
7747                    return self.net1(x)
7748
7749            ddp = torch.nn.parallel.DistributedDataParallel(
7750                ToyModel().cuda(self.rank), device_ids=[self.rank]
7751            )
7752            for i in range(2):
7753                inp = torch.rand(1, 10)
7754                if i > 0:
7755                    # On 2nd iteration, this will fail during rebuild_buckets,
7756                    # but we should report an error regarding unused parameters
7757                    # since that is the underlying root cause.
7758                    try:
7759                        ddp(inp).sum().backward()
7760                    except RuntimeError as e:
7761                        msg = str(e)
7762                        verify_ddp_error_logged(ddp, msg)
7763                        expected_strs = [
7764                            ddp_prev_reduction_unfinished_str,
7765                            ddp_recommend_find_unused_params_str,
7766                            ddp_outputs_not_used_in_loss_str,
7767                        ]
7768                        # In debug mode, should show parameters that weren't reduced.
7769                        # Without debug mode, should show suggestion to use debug mode.
7770                        if dist.get_debug_level() == dist.DebugLevel.OFF:
7771                            expected_strs.append(ddp_suggest_debug_mode_str)
7772                        else:
7773                            unreduced_params = ", ".join(["net2.weight"])
7774                            expected_strs.append(
7775                                f"did not receive grad for rank {self.rank}: {unreduced_params}"
7776                            )
7777                        for s in expected_strs:
7778                            self.assertTrue(s in msg, f"Expected {s} to be in {msg}")
7779                        self.assertFalse(ddp_find_unused_params_enabled_str in msg)
7780                    else:
7781                        self.assertFalse(
7782                            True, "DDP unused parameters error not raised."
7783                        )
7784                else:
7785                    ddp(inp).sum().backward()
7786
7787            dist.barrier()
7788
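        # Illustrative sketch of the fix suggested by the error message verified above:
        # when some parameters legitimately receive no gradient in the forward pass,
        # construct DDP with find_unused_parameters=True so the reducer can mark them
        # as ready. `module_with_unused_params` and the input shape are hypothetical.
        def _example_find_unused_params(self, module_with_unused_params):
            ddp = torch.nn.parallel.DistributedDataParallel(
                module_with_unused_params.cuda(self.rank),
                device_ids=[self.rank],
                find_unused_parameters=True,
            )
            # With unused parameters marked, repeated iterations do not hit the
            # rebuild_buckets error exercised in the test above.
            for _ in range(2):
                ddp(torch.rand(1, 10, device=self.rank)).sum().backward()
            return ddp
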
7789        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
7790        @skip_if_lt_x_gpu(2)
7791        def test_ddp_shared_grad_acc_unused_params(self):
7792            # When find_unused_parameters=True, ensure we mark unused parameters
7793            # even if they share gradient accumulators.
7794            class ToyModel(nn.Module):
7795                def __init__(self) -> None:
7796                    super().__init__()
7797                    # net1, bias, and net1.bias are all unused params.
7798                    self.net1 = nn.Linear(10, 5, bias=False)
7799                    self.bias = nn.Parameter(torch.zeros(5))
7800                    # net1.bias and self.bias are names for the same underlying
7801                    # parameter, so they share the same grad acc. This caused
7802                    # the bug reported in https://github.com/pytorch/pytorch/issues/41324.
7803                    self.net1.bias = self.bias
7804                    self.net2 = nn.Linear(10, 5)
7805
7806                def forward(self, x):
7807                    return self.net2(x).sum()
7808
7809            torch.cuda.set_device(self.rank)
7810            model = ToyModel().to(torch.cuda.current_device())
7811            for static in [True, False]:
7812                ddp_model = torch.nn.parallel.DistributedDataParallel(
7813                    copy.deepcopy(model),
7814                    device_ids=[self.rank],
7815                    find_unused_parameters=True,
7816                    static_graph=static,
7817                )
7818                inp = torch.randn(20, 10, device=self.rank)
7819                for i in range(6):
7820                    loss = ddp_model(inp)
7821                    # To test https://github.com/pytorch/pytorch/issues/61982
7822                    loss /= 10
7823                    loss.backward()
7824
7825        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
7826        @skip_if_lt_x_gpu(2)
7827        def test_ddp_device(self):
7828            m = nn.Linear(10, 10).to(self.rank)
7829            expected_len = 2
7830
7831            class TensorWrapper:
7832                __slots__ = ["t", "moved_to_gpu"]
7833
7834                def __init__(self, t):
7835                    self.t = t
7836                    self.moved_to_gpu = False
7837
7838            # Handlers for specific types of validation we want to do based on
7839            # the input type.
7840
7841            def tuple_and_list_validator(x):
7842                self.assertEqual(len(x), expected_len)
7843                self.assertEqual(1, len({t.device for t in x}))
7844                self.assertEqual(x[0].device.index, self.rank)
7845                return x[0] + x[1]
7846
7847            def namedtuple_validator(x):
7848                self.assertEqual(x._fields, EXPECTED_FIELDS)
7849                self.assertEqual(x.a.device.index, x.b.device.index)
7850                self.assertEqual(x.a.device.index, self.rank)
7851                return x.a + x.b
7852
7853            def custom_type_validator(x):
7854                self.assertTrue(x.moved_to_gpu or (str(x.t.device) == "cpu"))
7855                x.t = x.t.to(self.rank)
7856                x.moved_to_gpu = True
7857                return x.t
7858
7859            def dict_validator(x):
7860                self.assertTrue(EXPECTED_FIELDS[0] in x.keys())
7861                self.assertTrue(EXPECTED_FIELDS[1] in x.keys())
7862                self.assertEqual(1, len({t.device for t in x.values()}))
7863                self.assertEqual(x[EXPECTED_FIELDS[0]].device.index, self.rank)
7864                return x[EXPECTED_FIELDS[0]] + x[EXPECTED_FIELDS[1]]
7865
7866            validators = {
7867                TensorWrapper: custom_type_validator,
7868                tuple: tuple_and_list_validator,
7869                list: tuple_and_list_validator,
7870                TestNamedTupleInput_0: namedtuple_validator,
7871                TestNamedTupleInput_1: namedtuple_validator,
7872                dict: dict_validator,
7873            }
7874
7875            class ToyModel(torch.nn.Module):
7876                def __init__(self_):  # noqa: B902
7877                    super().__init__()
7878                    self_.lin = nn.Linear(10, 10, bias=False)
7879
7880                def forward(self_, x, expected_type):  # noqa: B902
7881                    # Similar to scatter, the recursive .to() in the single-device
7882                    # case does not move tensors that are wrapped in a custom type.
7883                    self.assertTrue(isinstance(x, expected_type))
7884                    fwd_tensor = validators[expected_type](x)
7885                    return self_.lin(fwd_tensor)
7886
7887            model = torch.nn.parallel.DistributedDataParallel(
7888                ToyModel().to(self.rank), device_ids=[self.rank]
7889            )
7890
7891            def train_iter(inp, input_type):
7892                for _ in range(4):
7893                    out = model(inp, input_type)
7894                    out.sum().backward()
7895
7896            # CPU tuple input, should be moved to the proper device before call
7897            # to forward.
7898            inp = tuple(torch.randn(10, 10) for _ in range(expected_len))
7899            train_iter(inp, tuple)
7900
7901            # List CPU input, should be moved to proper device before call to
7902            # forward.
7903            inp = [torch.randn(10, 10) for _ in range(expected_len)]
7904            train_iter(inp, list)
7905            # Custom type containing tensor. The type is maintained, but the
7906            # device is not propagated (which is what happens with scatter too)
7907            inp = TensorWrapper(torch.randn(10, 10))
7908            train_iter(inp, TensorWrapper)
7909            # NamedTuple input. The type should be maintained and tensor inputs
7910            # should be moved to the correct device as in scatter.
7911            batch = 5
7912            dim = 10
7913            a = torch.rand(batch, dim)
7914            b = torch.rand(batch, dim)
7915
7916            inp = TestNamedTupleInput_0(a, b)
7917            train_iter(inp, type(inp))
7918
7919            inp = TestNamedTupleInput_1(a, b)
7920            train_iter(inp, type(inp))
7921
7922            # dictionary input.
7923            inp = {
7924                EXPECTED_FIELDS[0]: a,
7925                EXPECTED_FIELDS[1]: b,
7926            }
7927            train_iter(inp, type(inp))
7928
7929        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
7930        @skip_if_lt_x_gpu(2)
7931        def test_ddp_namedtuple(self):
7932            batch = 5
7933            dim = 10
7934
7935            a = torch.rand(batch, dim, device=self.rank)
7936            b = torch.rand(batch, dim, device=self.rank)
7937
7938            class NamedTupleModule(torch.nn.Module):
7939                def __init__(self_):  # noqa: B902
7940                    super().__init__()
7941                    self_.lin = nn.Linear(10, 1)
7942
7943                def forward(self_, input, expected_type):  # noqa: B902
7944                    # Without NamedTuple support, this would be of type tuple.
7945                    self.assertTrue(
7946                        isinstance(input, expected_type),
7947                        f"Expected type {expected_type} but got {type(input)}",
7948                    )
7949                    self.assertEqual(input._fields, EXPECTED_FIELDS)
7950                    self.assertEqual(a, input.a)
7951                    self.assertEqual(b, input.b)
7952                    return self_.lin(torch.mul(input.a, input.b))
7953
7954            model = torch.nn.parallel.DistributedDataParallel(
7955                NamedTupleModule().cuda(self.rank), device_ids=[self.rank]
7956            )
7957            inp = TestNamedTupleInput_0(a, b)
7958            # The following would fail if DDP does not propagate NamedTuples correctly.
7959            model(inp, type(inp))
7960
7961            inp = TestNamedTupleInput_1(a, b)
7962            model(inp, type(inp))
7963
7964        @require_backend_is_available({"gloo"})
7965        def test_grads_same_across_ranks_with_no_sync(self):
7966            group, group_id, rank = self._init_global_test()
7967            world_size = dist.get_world_size()
7968            if world_size < 2:
7969                self.skipTest("This test requires at least two ranks.")
7970
7971            class SimpleConditionalModel(nn.Module):
7972                # if rank is 0, uses nn1 on the first pass and nn2 on the second pass.
7973                # else, uses nn3 on the first pass and nn4 on the second pass.
7974
7975                def __init__(self, rank):
7976                    super().__init__()
7977
7978                    self.rank = rank
7979                    self.nn1 = nn.Linear(1, 1)
7980                    self.nn2 = nn.Linear(1, 1)
7981                    self.nn3 = nn.Linear(1, 1)
7982                    self.nn4 = nn.Linear(1, 1)
7983                    self.state = 0
7984
7985                def forward(self, input):
7986                    if self.state == 0:
7987                        self.state = 1
7988                        if self.rank == 0:
7989                            return self.nn1(input)
7990                        else:
7991                            return self.nn3(input)
7992                    else:
7993                        self.state = 0
7994                        if self.rank == 0:
7995                            return self.nn2(input)
7996                        else:
7997                            return self.nn4(input)
7998
7999            model = torch.nn.parallel.DistributedDataParallel(
8000                SimpleConditionalModel(rank), find_unused_parameters=True
8001            )
8002            mse_loss = nn.MSELoss()
8003            grad_accumulation = 2
8004
8005            for microbatch_idx in range(grad_accumulation):
8006                if microbatch_idx < grad_accumulation - 1:
8007                    context = model.no_sync
8008                else:
8009                    context = nullcontext
8010
8011                with context():
8012                    input = torch.rand((1, ))
8013                    output = model.forward(input)
8014                    target = torch.rand((1, ))
8015
8016                    loss = mse_loss(output, target)
8017                    loss.backward()
8018
8019            self.assertTrue(
8020                not any(p.grad is None for p in model.parameters()),
8021                "Gradients can't be None for any model parameter."
8022            )
8023            grads = torch.cat([p.grad.view(-1) for p in model.parameters()])
8024
8025            # Gather all gradients to rank 0.
8026            if rank == 0:
8027                gathered_grads = [torch.zeros_like(grads) for _ in range(world_size)]
8028            else:
8029                gathered_grads = []
8030
8031            dist.gather(grads, gather_list=gathered_grads, dst=0)
8032            if rank == 0:
8033                for g in gathered_grads[1:]:
8034                    self.assertTrue(
8035                        torch.allclose(gathered_grads[0], g),
8036                        "Gradients are not the same for all ranks."
8037                    )
8038
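        # Illustrative sketch of the gradient-accumulation pattern the test above
        # verifies, assuming `microbatches` is a hypothetical list of (input, target)
        # pairs: allreduce is skipped via model.no_sync() on all but the last
        # microbatch, and the final backward() synchronizes the accumulated gradients.
        def _example_no_sync_grad_accumulation(self, model, microbatches):
            loss_fn = nn.MSELoss()
            for idx, (inp, target) in enumerate(microbatches):
                is_last = idx == len(microbatches) - 1
                ctx = nullcontext() if is_last else model.no_sync()
                with ctx:
                    loss_fn(model(inp), target).backward()
            # Gradients are now identical across ranks, as asserted in the test above.
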
8039        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
8040        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
8041        @skip_if_lt_x_gpu(2)
8042        def test_ddp_control_flow_same_across_ranks(self):
8043            # Control flow that is the same across ranks.
8044            batch = 20
8045            dim = 10
8046
8047            world_size = dist.get_world_size()
8048            torch.cuda.set_device(self.rank)
8049            model = torch.nn.parallel.DistributedDataParallel(
8050                ControlFlowToyModel().cuda(self.rank),
8051                device_ids=[self.rank],
8052                find_unused_parameters=True,
8053            )
8054            random_input = torch.randn(batch, dim, device=self.rank)
8055            ones_input = torch.ones(batch, dim, device=self.rank)
8056            for i in range(6):
8057                if i % 2 == 0:
8058                    out = model(random_input)
8059                else:
8060                    out = model(ones_input)
8061                loss = out.sum()
8062                loss.backward()
8063                # On even iterations, 2nd param goes unused, on odd iterations,
8064                # it is used.
8065                local_used_map = model.reducer._get_local_used_map()
8066                if i % 2 == 0:
8067                    expected = torch.tensor(
8068                        [world_size, 0], device=self.rank, dtype=torch.int32
8069                    )
8070                else:
8071                    expected = torch.tensor(
8072                        [world_size, world_size], device=self.rank, dtype=torch.int32
8073                    )
8074
8075                # Validate parameter usage.
8076                variable_usage_tensor = local_used_map
8077                self.assertEqual(variable_usage_tensor, expected)
8078
8079            # Validate appropriate error message when DDP is used with
8080            # find_unused_parameters=False.
8081            model = torch.nn.parallel.DistributedDataParallel(
8082                ControlFlowToyModel().cuda(self.rank),
8083                device_ids=[self.rank],
8084                find_unused_parameters=False,
8085            )
8086            for i in range(2):
8087                if i == 0:
8088                    loss = model(random_input).sum()
8089                    loss.backward()
8090                else:
8091                    try:
8092                        loss = model(random_input).sum()
8093                        loss.backward()
8094                    except RuntimeError as e:
8095                        msg = str(e)
8096                        verify_ddp_error_logged(model, msg)
8097                        # 2nd linear layer is unused
8098                        unused_param_index = 1
8099                        expected_strs = [
8100                            ddp_prev_reduction_unfinished_str,
8101                            ddp_recommend_find_unused_params_str,
8102                            ddp_outputs_not_used_in_loss_str,
8103                            f"Parameter indices which did not receive grad for rank {self.rank}: {unused_param_index}",
8104                        ]
8105                        # In debug mode, should show parameters that weren't reduced.
8106                        # Without debug mode, should show suggestion to use debug mode.
8107                        if dist.get_debug_level() == dist.DebugLevel.OFF:
8108                            expected_strs.append(ddp_suggest_debug_mode_str)
8109                        else:
8110                            unreduced_params = ", ".join(["lin2.weight"])
8111                            expected_strs.append(
8112                                f"did not receive grad for rank {self.rank}: {unreduced_params}"
8113                            )
8114                        for s in expected_strs:
8115                            self.assertTrue(s in msg, f"Expected {s} to be in {msg}")
8116                        self.assertFalse(ddp_find_unused_params_enabled_str in msg)
8117                    else:
8118                        self.assertFalse(True, "DDP error not raised")
8119
8120            dist.barrier()
8121
8122        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
8123        @skip_if_lt_x_gpu(2)
8124        def test_invalid_static_graph(self):
8125            world_size = dist.get_world_size()
8126            torch.cuda.set_device(self.rank)
8127            model = torch.nn.parallel.DistributedDataParallel(
8128                ControlFlowToyModel().cuda(self.rank),
8129                device_ids=[self.rank],
8130                static_graph=True,
8131            )
8132            random_input = torch.randn(20, 10, device=self.rank)
8133            ones_input = torch.ones(20, 10, device=self.rank)
8134            # A parameter that is unused in the first iteration gets used
8135            # in the second iteration.
8136            expected_err = "Your training graph has changed in this iteration"
8137            with self.assertRaisesRegex(RuntimeError, expected_err):
8138                for i in range(2):
8139                    if i % 2 == 0:
8140                        out = model(random_input)
8141                    else:
8142                        out = model(ones_input)
8143                    loss = out.sum()
8144                    loss.backward()
8145
8146            verify_ddp_error_logged(model, expected_err)
8147
8148            # A parameter that is used in the first iteration goes unused
8149            # in the second iteration.
8150            with self.assertRaisesRegex(
8151                RuntimeError,
8152                "Expected to have finished reduction in the prior iteration "
8153                "before starting a new one. This error indicates that your "
8154                "training graph has changed in this iteration, "
8155                "e.g., one parameter is used in first iteration, "
8156                "but then got unused in the second iteration. "
8157                "this is not compatible with static_graph set to True.\n"
8158                "Parameter indices which did not receive grad for",
8159            ):
8160                for i in range(2):
8161                    if i % 2 != 0:
8162                        out = model(random_input)
8163                    else:
8164                        out = model(ones_input)
8165                    loss = out.sum()
8166                    loss.backward()
8167
8168            verify_ddp_error_logged(model, "Expected to have finished reduction")
8169
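        # Illustrative sketch of valid static_graph usage, in contrast with the error
        # cases above: with static_graph=True, the set of parameters used (and hence
        # the structure of the autograd graph) must be identical on every iteration.
        # `module_with_fixed_graph` and the input shape are hypothetical placeholders.
        def _example_valid_static_graph(self, module_with_fixed_graph):
            ddp = torch.nn.parallel.DistributedDataParallel(
                module_with_fixed_graph.cuda(self.rank),
                device_ids=[self.rank],
                static_graph=True,
            )
            inp = torch.randn(20, 10, device=self.rank)
            for _ in range(3):
                # Same graph every iteration, so no "training graph has changed" error.
                ddp(inp).sum().backward()
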
8170        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
8171        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
8172        @skip_if_lt_x_gpu(2)
8173        def test_ddp_control_flow_different_across_ranks(self):
8174            # Control flow that is different across ranks.
8175            batch = 20
8176            dim = 10
8177
8178            class ToyModel(nn.Module):
8179                def __init__(self, rank):
8180                    super().__init__()
8181                    self.lin1 = nn.Linear(10, 10, bias=False)
8182                    self.lin2 = nn.Linear(10, 10, bias=False)
8183                    self.rank = rank
8184
8185                def forward(self, x):
8186                    # Control-flow that is rank and input dependent for the
8187                    # model.
8188                    use_second_layer = (
8189                        torch.equal(x, torch.ones(batch, dim, device=x.device))
8190                        and self.rank == 1
8191                    )
8192
8193                    if use_second_layer:
8194                        return self.lin2(F.relu(self.lin1(x)))
8195                    else:
8196                        return F.relu(self.lin1(x))
8197
8198            world_size = dist.get_world_size()
8199            torch.cuda.set_device(self.rank)
8200            model = torch.nn.parallel.DistributedDataParallel(
8201                ToyModel(self.rank).cuda(self.rank),
8202                device_ids=[self.rank],
8203                find_unused_parameters=True,
8204            )
8205            random_input = torch.randn(batch, dim, device=self.rank)
8206            ones_input = torch.ones(batch, dim, device=self.rank)
8207            for i in range(6):
8208                if i % 2 == 0:
8209                    out = model(random_input)
8210                else:
8211                    out = model(ones_input)
8212                loss = out.sum()
8213                loss.backward()
8214                # On even iterations, 2nd param goes unused, on odd iterations,
8215                # it is used only on rank 1.
8216                local_used_map = model.reducer._get_local_used_map()
8217
8218                if i % 2 == 0:
8219                    expected = torch.tensor(
8220                        [world_size, 0], device=self.rank, dtype=torch.int32
8221                    )
8222                else:
8223                    expected = torch.tensor(
8224                        [world_size, 1], device=self.rank, dtype=torch.int32
8225                    )
8226
8227                variable_usage_tensor = local_used_map
8228                # Validate parameter usage. On odd iterations, 2nd param is only
8229                # used on rank 1.
8230                self.assertEqual(variable_usage_tensor, expected)
8231
8232            # Validate appropriate error message when DDP is used with
8233            # find_unused_parameters=False.
8234            model = torch.nn.parallel.DistributedDataParallel(
8235                ToyModel(self.rank).cuda(self.rank),
8236                device_ids=[self.rank],
8237                find_unused_parameters=False,
8238            )
8239            for i in range(2):
8240                if i == 0:
8241                    loss = model(random_input).sum()
8242                    loss.backward()
8243                else:
8244                    try:
8245                        loss = model(random_input).sum()
8246                        loss.backward()
8247                    except RuntimeError as e:
8248                        msg = str(e)
8249                        verify_ddp_error_logged(model, msg)
8250                        unused_param_index = 1
8251                        expected_strs = [
8252                            ddp_prev_reduction_unfinished_str,
8253                            ddp_recommend_find_unused_params_str,
8254                            ddp_outputs_not_used_in_loss_str,
8255                            f"Parameter indices which did not receive grad for rank {self.rank}: {unused_param_index}",
8256                        ]
8257                        # In debug mode, should show parameters that weren't reduced.
8258                        # Without debug mode, should show suggestion to use debug mode.
8259                        if dist.get_debug_level() == dist.DebugLevel.OFF:
8260                            expected_strs.append(ddp_suggest_debug_mode_str)
8261                        else:
8262                            unreduced_params = ", ".join(["lin2.weight"])
8263                            expected_strs.append(
8264                                f"did not receive grad for rank {self.rank}: {unreduced_params}"
8265                            )
8266                        for s in expected_strs:
8267                            self.assertTrue(s in msg, f"Expected {s} to be in {msg}")
8268                        self.assertFalse(ddp_find_unused_params_enabled_str in msg)
8269                    else:
8270                        self.assertFalse(True, "DDP error not raised")
8271
8272            dist.barrier()
8273
8274        @require_backend_is_available({"gloo"})
8275        def test_scatter_object_list(self):
8276            src_rank = 0
8277            scatter_list = (
8278                COLLECTIVES_OBJECT_TEST_LIST
8279                if self.rank == src_rank
8280                else [None for _ in COLLECTIVES_OBJECT_TEST_LIST]
8281            )
8282            world_size = dist.get_world_size()
8283            scatter_list = scatter_list[:world_size]
8284            i = 0
8285            while len(scatter_list) < world_size:
8286                scatter_list.append(scatter_list[i])
8287                i += 1
8288
8289            output_obj_list = [None]
8290            dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank)
8291            self.assertEqual(
8292                output_obj_list[0],
8293                COLLECTIVES_OBJECT_TEST_LIST[
8294                    self.rank % len(COLLECTIVES_OBJECT_TEST_LIST)
8295                ],
8296            )
8297            # Ensure errors are raised upon incorrect arguments.
8298            with self.assertRaisesRegex(
8299                ValueError,
8300                "Expected argument scatter_object_output_list to be a list of size at least 1.",
8301            ):
8302                dist.scatter_object_list([], scatter_list, src=src_rank)
8303
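        # Illustrative sketch of the dist.scatter_object_list API tested above: the
        # source rank supplies world_size picklable objects, and each rank receives the
        # object at its own index through a single-element output list. The payload
        # strings below are hypothetical.
        def _example_scatter_object_list(self):
            src_rank = 0
            world_size = dist.get_world_size()
            inputs = (
                [f"payload_{r}" for r in range(world_size)]
                if self.rank == src_rank
                else [None for _ in range(world_size)]
            )
            output = [None]
            dist.scatter_object_list(output, inputs, src=src_rank)
            # Rank r now holds "payload_r" in output[0].
            return output[0]
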
8304        def _generate_sparse_tensors_for_bucket_assignment_test(self):
8305            tensors = [
8306                torch.empty([50], dtype=torch.float),
8307                torch.empty([25], dtype=torch.double),
8308                torch.empty([50], dtype=torch.float),
8309                torch.empty([25], dtype=torch.double),
8310                torch.empty([50], dtype=torch.float),
8311                torch.empty([25], dtype=torch.double),
8312            ]
8313
8314            tensors_sparse = [t.to_sparse() for t in tensors]
8315            return tensors_sparse
8316
8317        def _test_compute_bucket_assignment_by_size(self, use_logger):
8318            group_gloo = dist.new_group(
8319                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
8320            )
8321            # Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
8322            # determinism.
8323            os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
8324            group_to_use = dist.new_group(
8325                backend=dist.get_backend(), timeout=timedelta(seconds=5)
8326            )
8327            torch.cuda.set_device(self.rank)
8328
8329            # Create a valid model. The constructor initializes the logger that we use later.
8330            # We never actually use the rest of the model - we only need its logger.
8331            net = EmbeddingNetDifferentParams(0)
8332            net = torch.nn.parallel.DistributedDataParallel(
8333                net.to(self.rank),
8334                device_ids=[self.rank],
8335                process_group=group_to_use,
8336            )
8337
8338            # if we don't pass a logger then we can only check that an exception was thrown.
8339            expected_err = "No support for sparse tensors."
8340            with self.assertRaisesRegex(RuntimeError, expected_err):
8341                tensors_sparse = (
8342                    self._generate_sparse_tensors_for_bucket_assignment_test()
8343                )
8344                if use_logger:
8345                    result = dist._compute_bucket_assignment_by_size(
8346                        tensors_sparse, [400], logger=net.logger
8347                    )
8348                else:
8349                    result = dist._compute_bucket_assignment_by_size(
8350                        tensors_sparse, [400]
8351                    )
8352            if use_logger:
8353                verify_ddp_error_logged(net, expected_err)
8354
8355            # Perform gloo-based barrier to ensure one rank doesn't exit test
8356            # early which causes failure with Barrier.sync.
8357            dist.barrier(group_gloo)
8358
8359        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
8360        @skip_if_lt_x_gpu(2)
8361        def test_compute_bucket_assignment_by_size_sparse_error_without_logger(self):
8362            self._test_compute_bucket_assignment_by_size(use_logger=False)
8363
8364        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
8365        @skip_if_lt_x_gpu(2)
8366        def test_compute_bucket_assignment_by_size_sparse_error_with_logger(self):
8367            self._test_compute_bucket_assignment_by_size(use_logger=True)
8368
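        # Illustrative sketch of the private dist._compute_bucket_assignment_by_size
        # helper exercised above, but with dense tensors, where it succeeds instead of
        # raising. The bucket cap is in bytes; 400 mirrors the value used in the test
        # above and is otherwise arbitrary.
        def _example_bucket_assignment_dense(self):
            tensors = [
                torch.empty([50], dtype=torch.float),
                torch.empty([25], dtype=torch.double),
                torch.empty([50], dtype=torch.float),
            ]
            # Groups tensor indices into buckets whose byte sizes stay near the cap.
            return dist._compute_bucket_assignment_by_size(tensors, [400])
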
8369        def _determine_expected_error_verify_model_across_rank(
8370            self, group_to_use, diff_num_params=False
8371        ):
8372            # When running with NCCL backend, we don't expect an error on rank 0,
8373            # rather, it will be taken down by TORCH_NCCL_ASYNC_ERROR_HANDLING. When
8374            # running with Gloo or with debug mode wrapper, we expect the error
8375            # to be caught inline.
8376            # All ranks report the same error when there is a mismatch in the number
8377            # of parameters, since we use allgather in the implementation.
8378            if diff_num_params:
8379                expected_err = "DDP expects same model across all ranks"
8380                ctx = self.assertRaisesRegex(RuntimeError, expected_err)
8381                return ctx, expected_err
8382
8383            is_detail_dbg_mode = dist.get_debug_level() == dist.DebugLevel.DETAIL
8384            if self.rank == 0:
8385                if (
8386                    dist.get_backend(group_to_use) == dist.Backend.NCCL
8387                    and not is_detail_dbg_mode
8388                ):
8389                    expected_err = "caught collective operation timeout"
8390                    ctx = self.assertRaisesRegex(RuntimeError, expected_err)
8391                else:
8392                    expected_err = None
8393                    ctx = self.assertRaises(RuntimeError)
8394            else:
8395                expected_err = "appears not to match"
8396                ctx = self.assertRaisesRegex(RuntimeError, expected_err)
8397            return ctx, expected_err
8398
8399        def _test_verify_model_across_rank(self, use_logger):
8400            group_gloo = dist.new_group(
8401                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
8402            )
8403            # Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
8404            # determinism.
8405            os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
8406            group_to_use = dist.new_group(
8407                backend=dist.get_backend(), timeout=timedelta(seconds=5)
8408            )
8409            torch.cuda.set_device(self.rank)
8410            ctx, expected_err = self._determine_expected_error_verify_model_across_rank(
8411                group_to_use
8412            )
8413
8414            # Create a valid model. The constructor initializes the logger that we use later.
8415            net = EmbeddingNetDifferentParams(0)
8416            net = torch.nn.parallel.DistributedDataParallel(
8417                net.to(self.rank),
8418                device_ids=[self.rank],
8419                process_group=group_to_use,
8420            )
8421
8422            # Modify the model so that the number of parameters is different for each rank.
8423            # This will cause a RuntimeError to be thrown below in _verify_param_shape_across_processes,
8424            # so we can check if the correct error is thrown and is logged.
8425            # We can't do this in the constructor above otherwise the logger will
8426            # not be properly initialized.
8427            net.module.lin = nn.Linear(100 if self.rank == 0 else 10, 1)
8428
8429            # if we pass a logger we can verify that it was logged
8430            with ctx:
8431                if use_logger:
8432                    _verify_param_shape_across_processes(
8433                        net.process_group, list(net.parameters()), net.logger
8434                    )
8435                else:
8436                    _verify_param_shape_across_processes(
8437                        net.process_group, list(net.parameters())
8438                    )
8439                # Should only be run by rank 0, and blocking_wait catches and
8440                # reports exception.
8441                dist.barrier(group_to_use)
8442
8443            # We don't check when self.rank != 0 because the logger doesn't log
8444            # the error "Caught collective operation" as that is not thrown in the reducer.
8445            if use_logger and self.rank != 0:
8446                verify_ddp_error_logged(net, expected_err)
8447
8448            # Perform gloo-based barrier to ensure one rank doesn't exit test
8449            # early which causes failure with Barrier.sync.
8450            dist.barrier(group_gloo)
8451
8452        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
8453        @skip_but_pass_in_sandcastle_if(
8454            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
8455        )
8456        @skip_if_lt_x_gpu(2)
8457        def test_verify_model_across_rank_with_logger(self):
8458            self._test_verify_model_across_rank(use_logger=True)
8459
8460        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
8461        @skip_but_pass_in_sandcastle_if(
8462            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
8463        )
8464        @skip_if_lt_x_gpu(2)
8465        def test_verify_model_across_rank_without_logger(self):
8466            self._test_verify_model_across_rank(use_logger=False)
8467
8468        def _run_test_ddp_model_with_diff_params(self, ctx, net, ddp_group, group_gloo):
8469            with ctx:
8470                net = torch.nn.parallel.DistributedDataParallel(
8471                    net.to(self.rank), device_ids=[self.rank], process_group=ddp_group
8472                )
8473                # Should only be run by rank 0, and blocking_wait catches and
8474                # reports exception.
8475                dist.barrier(ddp_group)
8476
8477            # can't use verify_ddp_error_logged here because net was never properly constructed
8478
8479            # Perform gloo-based barrier to ensure one rank doesn't exit test
8480            # early which causes failure with Barrier.sync.
8481            dist.barrier(group_gloo)
8482
8483        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
8484        @skip_but_pass_in_sandcastle_if(
8485            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
8486        )
8487        @skip_if_lt_x_gpu(2)
8488        def test_ddp_model_diff_shape_across_ranks(self):
8489            group_gloo = dist.new_group(
8490                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
8491            )
8492            # Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
8493            # determinism.
8494            os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
8495            group_to_use = dist.new_group(
8496                backend=dist.get_backend(), timeout=timedelta(seconds=10)
8497            )
8498            torch.cuda.set_device(self.rank)
8499            ctx, expected_err = self._determine_expected_error_verify_model_across_rank(
8500                group_to_use
8501            )
8502            # Creates a network with differently sized embedding tables on different
8503            # ranks. This should throw an error during DDP init.
8504            net = EmbeddingNetDifferentParams(self.rank)
8505            self._run_test_ddp_model_with_diff_params(
8506                ctx, net, group_to_use, group_gloo
8507            )
8508
8509        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
8510        @skip_but_pass_in_sandcastle_if(
8511            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
8512        )
8513        @skip_if_lt_x_gpu(2)
8514        def test_ddp_model_diff_num_params_across_ranks(self):
8515            group_gloo = dist.new_group(
8516                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
8517            )
8518            # Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
8519            # determinism.
8520            os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
8521            group_to_use = dist.new_group(
8522                backend=dist.get_backend(), timeout=timedelta(seconds=10)
8523            )
8524            torch.cuda.set_device(self.rank)
8525            ctx, expected_err = self._determine_expected_error_verify_model_across_rank(
8526                group_to_use, diff_num_params=True
8527            )
8528
8529            # Creates a network with a different number of parameters across ranks;
8530            # the reducer should recognize this and throw an appropriate error.
8531            net = EmbeddingNetDifferentParams(
8532                self.rank, diff_num_params=(self.rank == 1)
8533            )
8534
8535            self._run_test_ddp_model_with_diff_params(
8536                ctx,
8537                net,
8538                group_to_use,
8539                group_gloo,
8540            )
8541
8542        def _test_output_unused_in_loss(self, module_cls, gradient_as_bucket_view):
8543            model = module_cls()
8544            local_net = copy.deepcopy(model)
8545            net = torch.nn.parallel.DistributedDataParallel(
8546                copy.deepcopy(model).cuda(self.rank),
8547                device_ids=[self.rank],
8548                find_unused_parameters=True,
                gradient_as_bucket_view=gradient_as_bucket_view,
8549            )
8550
8551            # Tests that it is supported for certain parameters to not receive a
8552            # gradient because the corresponding output is unused in the loss
8553            # computation. Specifically, checks that the grads remain unchanged
8554            # and are the same as in local training.
8555            inp = torch.randn(10, 10)
8556
8557            # Ensure that if a param is not used in the loss computation, its
8558            # gradient is untouched, i.e. if it was None before the backward pass
8559            # it remains None afterwards rather than becoming zero.
8560            if module_cls == DictOutputModule:
8561                a, b = local_net(inp)["predictions"]
8562                a_dist, b_dist = net(inp)["predictions"]
8563            else:
8564                a, b = local_net(inp)
8565                a_dist, b_dist = net(inp)
8566
8567            loss_dist = b_dist.sum()
8568            loss_dist.backward()
8569
8570            # Ensure that gradient corresponding to parameter "a" was not
8571            # touched, i.e. it is None and matches the local grad.
8572            if module_cls == DictOutputModule:
8573                self.assertTrue(net.module.module.a.weight.grad is None)
8574                self.assertEqual(
8575                    net.module.module.a.weight.grad, local_net.module.a.weight.grad
8576                )
8577            else:
8578                self.assertTrue(net.module.a.weight.grad is None)
8579                self.assertEqual(net.module.a.weight.grad, local_net.a.weight.grad)
8580
8581            saved_a_local_grad = None
8582            saved_a_dist_grad = None
8583            net.zero_grad()
8584            local_net.zero_grad()
8585            for i in range(6):
8586                if module_cls == DictOutputModule:
8587                    a, b = local_net(inp)["predictions"]
8588                    a_dist, b_dist = net(inp)["predictions"]
8589                else:
8590                    a, b = local_net(inp)
8591                    a_dist, b_dist = net(inp)
8592                if i < 2:
8593                    # Use both params in loss computation. Later, "a" will go
8594                    # unused and we check to ensure DDP supports this and
8595                    # gradients remain the same as local training.
8596                    t = a @ b
8597                    t_dist = a_dist @ b_dist
8598                    loss = t.sum()
8599                    loss_dist = t_dist.sum()
8600                else:
8601                    # Model output "a" unused in loss.
8602                    loss = b.sum()
8603                    loss_dist = b_dist.sum()
8604                loss.backward()
8605                loss_dist.backward()
8606                if i == 1:
8607                    # Save grads to compare with them in next iterations.
8608                    if module_cls == DictOutputModule:
8609                        saved_a_local_grad = local_net.module.a.weight.grad
8610                        saved_a_dist_grad = net.module.module.a.weight.grad
8611                    else:
8612                        saved_a_local_grad = local_net.a.weight.grad
8613                        saved_a_dist_grad = net.module.a.weight.grad
8614                    self.assertEqual(saved_a_local_grad, saved_a_dist_grad)
8615                elif i >= 2:
8616                    # parameter "a" of both models should be the same and not change
8617                    if module_cls == DictOutputModule:
8618                        self.assertEqual(
8619                            net.module.module.a.weight.grad, saved_a_dist_grad
8620                        )
8621                        self.assertEqual(
8622                            local_net.module.a.weight.grad, saved_a_local_grad
8623                        )
8624                    else:
8625                        self.assertEqual(net.module.a.weight.grad, saved_a_dist_grad)
8626                        self.assertEqual(local_net.a.weight.grad, saved_a_local_grad)
8627
8628                # Verify grads are the same
8629                for (local_param, dist_param) in zip(
8630                    local_net.parameters(), net.parameters()
8631                ):
8632                    local_grad = local_param.grad
8633                    dist_grad = dist_param.grad
8634                    self.assertEqual(local_grad, dist_grad)
8635
8636            dist.barrier()
8637
8638        @skip_but_pass_in_sandcastle_if(
8639            BACKEND not in DistTestCases.backend_feature["ddp"],
8640            f"The {BACKEND} backend does not support DistributedDataParallel",
8641        )
8642        @skip_if_lt_x_gpu(2)
8643        def test_output_unused_in_loss_tuple_module(self):
8644            module_cls = UnusedParamTwoLinLayerNet
8645            for grad_as_bucket_view in [True, False]:
8646                self._test_output_unused_in_loss(module_cls, grad_as_bucket_view)
8647
8648        @skip_but_pass_in_sandcastle_if(
8649            BACKEND not in DistTestCases.backend_feature["ddp"],
8650            f"The {BACKEND} backend does not support DistributedDataParallel",
8651        )
8652        @skip_if_lt_x_gpu(2)
8653        def test_output_unused_in_loss_dict_module(self):
8654            module_cls = DictOutputModule
8655            for grad_as_bucket_view in [True, False]:
8656                self._test_output_unused_in_loss(module_cls, grad_as_bucket_view)
8657
8658        @skip_but_pass_in_sandcastle_if(
8659            BACKEND not in DistTestCases.backend_feature["ddp"],
8660            f"The {BACKEND} backend does not support DistributedDataParallel",
8661        )
8662        @skip_if_lt_x_gpu(2)
8663        def test_undefined_grad_parity_unused_parameters(self):
8664            # TODO: enable this for general training use cases:
8665            # https://github.com/pytorch/pytorch/issues/58511.
8666            x = torch.ones(1, 2).to(self.rank)
8667            net = Net().to(self.rank)
8668            local_net = copy.deepcopy(net)
8669            net = torch.nn.parallel.DistributedDataParallel(
8670                net,
8671                device_ids=[self.rank],
8672                find_unused_parameters=True,
8673            )
8674            out = net(x).sum()
8675            local_out = local_net(x).sum()
8676            # Simulates undefined gradients.
8677            torch._C._functions.UndefinedGrad()(out).backward()
8678            torch._C._functions.UndefinedGrad()(local_out).backward()
8679            for (dist_param_name, dist_param), (local_param_name, local_param) in zip(
8680                net.named_parameters(), local_net.named_parameters()
8681            ):
8682                dist_grad = dist_param.grad
8683                local_grad = local_param.grad
8684                self.assertEqual(
8685                    dist_grad,
8686                    local_grad,
8687                    f"DDP param {dist_param_name} with grad {dist_grad} "
8688                    f"does not match local param {local_param_name} with grad "
8689                    f"{local_grad}",
8690                )
8691
8692        def _test_different_graph_across_ranks(
8693            self, find_unused_parameters=False, static_graph=False
8694        ):
8695            class ToyModel(nn.Module):
8696                def __init__(self, rank):
8697                    super().__init__()
8698                    self.lin1 = nn.Linear(10, 10, bias=False)
8699                    self.lin2 = nn.Linear(10, 10, bias=False)
8700                    self.rank = rank
8701
8702                def forward(self, x):
8703                    if self.rank == 0:
8704                        return self.lin2(F.relu(self.lin1(x)))
8705                    else:
8706                        return F.relu(self.lin1(x))
8707
8708            torch.manual_seed(31415)
8709            world_size = dist.get_world_size()
8710            torch.cuda.set_device(self.rank)
8711            model = ToyModel(self.rank).cuda(self.rank)
8712            ddp_model = torch.nn.parallel.DistributedDataParallel(
8713                model,
8714                device_ids=[self.rank],
8715                find_unused_parameters=find_unused_parameters,
8716                gradient_as_bucket_view=True,
8717                static_graph=static_graph,
8718            )
8719            random_input = torch.randn(20, 10, device=self.rank)
8720            for i in range(10):
8721                out = ddp_model(random_input)
8722                loss = out.sum()
8723                loss.backward()
8724            return ddp_model
8725
8726        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
8727        @skip_if_lt_x_gpu(2)
8728        def test_different_graph_across_ranks(self):
8729            base_model = self._test_different_graph_across_ranks(
8730                find_unused_parameters=True
8731            )
8732            self.assertFalse(
8733                base_model._get_ddp_logging_data().get("has_rebuilt_buckets", 0)
8734            )
8735            static_model = self._test_different_graph_across_ranks(static_graph=True)
8736            self.assertTrue(
8737                static_model._get_ddp_logging_data().get("has_rebuilt_buckets", 0)
8738            )
8739            for i, j in zip(base_model.parameters(), static_model.parameters()):
8740                self.assertEqual(i, j)
8741
8742        @require_backend_is_available({"gloo"})
8743        @skip_but_pass_in_sandcastle_if(
8744            IS_MACOS or IS_WINDOWS,
8745            "MacOS uses uv transport which does not have as robust error handling as tcp transport",
8746        )
8747        def test_monitored_barrier_gloo(self):
8748            tensors = [torch.ones(10) * self.rank]
8749            # Kick off some allreduce work on all ranks
8750            for _ in range(10):
8751                dist.all_reduce(torch.cat(tensors))
8752            # Run monitored barrier and ensure it passes
8753            timeout = timedelta(seconds=2)
8754            dist.monitored_barrier(timeout=timeout)
8755            # Check monitored_barrier success with wait_all_ranks=True
8756            for _ in range(10):
8757                dist.all_reduce(torch.cat(tensors))
8758            dist.monitored_barrier(timeout=timeout, wait_all_ranks=True)
8759            # All ranks besides 1 call into barrier, rank 0 should report failure
8760            # All ranks besides 1 call into the barrier; rank 0 should report a failure,
8761            # while the others report a gloo error.
8762            src_rank = 0
8763            if self.rank == src_rank:
8764                with self.assertRaisesRegex(
8765                    RuntimeError, f"Rank {failed_rank} failed to pass monitoredBarrier"
8766                ):
8767                    dist.monitored_barrier(timeout=timeout)
8768            elif self.rank != failed_rank:
8769                # Other ranks should not pass barrier since rank 0 failed.
8770                err_regex = (
8771                    f"Rank {self.rank} successfully reached monitoredBarrier,"
8772                    f" but received errors while waiting for send/recv from rank"
8773                    f" {src_rank}"
8774                )
8775                with self.assertRaisesRegex(RuntimeError, err_regex):
8776                    dist.monitored_barrier(timeout=timeout)
8777
8778            # We need a barrier since otherwise failed_rank exits too early
8779            # and causes a timeout.
8780            self._barrier(timeout=30)
8781
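        # Illustrative sketch (hypothetical helper, not invoked by any test):
        # the minimal monitored_barrier pattern the test above exercises.
        # Assumes an already-initialized gloo process group; every healthy rank
        # calls in, and rank 0 raises naming any rank that failed to respond.
        def _example_monitored_barrier_usage(self):
            # wait_all_ranks=True makes rank 0 report every missing rank
            # instead of only the first one.
            dist.monitored_barrier(
                timeout=timedelta(seconds=2), wait_all_ranks=True
            )
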
8782        @require_backend_is_available({"gloo"})
8783        def test_monitored_barrier_gloo_subgroup(self):
8784            # Tests that monitored_barrier works as expected on non-default
8785            # process groups.
8786            failed_rank = 1
8787            timeout = 0.1
8788            subgroup = dist.new_group(ranks=[0, 1])
8789
8790            if self.rank == failed_rank:
8791                return
8792
8793            if self.rank == 0:
8794                with self.assertRaisesRegex(
8795                    RuntimeError, f"Rank {failed_rank} failed to pass monitoredBarrier"
8796                ):
8797                    dist.monitored_barrier(subgroup, timeout)
8798            else:
8799                # Other ranks call into monitored_barrier, but this should be a
8800                # noop because they are not part of the subgroup. Verify that
8801                # there are no errors here.
8802                dist.monitored_barrier(subgroup, timeout)
8803
8804        def _test_monitored_barrier_allreduce_hang(self, wait_all_ranks):
8805            # Tests expected behavior when a nonzero rank hangs.
8806            nccl_pg = dist.new_group(
8807                ranks=list(range(int(self.world_size))),
8808                # provide sufficient timeout so communicators
8809                # can be initialized in ctor.
8810                timeout=timedelta(seconds=15),
8811                backend=dist.Backend.NCCL,
8812            )
8813            gloo_pg = dist.new_group(
8814                ranks=list(range(int(self.world_size))),
8815                backend=dist.Backend.GLOO,
8816            )
8817            tensors = [torch.ones(10, device=self.rank) * self.rank]
8818            # Let all ranks call allreduce first to set up communicators etc.
8819            # Directly simulating error here will run into store issue described
8820            # Directly simulating an error here would run into the store issue described
8821            nccl_pg.allreduce(tensors).wait(timedelta(seconds=5))
8822            # All ranks besides 0 call into allreduce. This is to simulate a
8823            # desync across the world, where some ranks call into
8824            # monitored_barrier() and others are stuck in collective comm. In
8825            # practice, we don't need TORCH_NCCL_BLOCKING_WAIT, but we use it in this
8826            # test to ensure it exits cleanly.
8827            if self.rank != 0:
8828                # Can get different errors here depending on whether gloo-based
8829                # wrapper PG is enabled or not, since with wrapper pg, it will
8830                # fail in a collective synchronization check and not actually
8831                # call into the nccl pg.
8832                if dist.get_debug_level() == dist.DebugLevel.DETAIL:
8833                    err_regex = "Timed out waiting"
8834                else:
8835                    err_regex = "caught collective operation timeout"
8836                with self.assertRaisesRegex(RuntimeError, err_regex):
8837                    nccl_pg.allreduce(tensors).wait(timedelta(seconds=0.1))
8838            else:
8839                # Rank 0 should report the first (in order) timed-out rank or all ranks,
8840                # depending on the wait_all_ranks flag passed into monitored_barrier.
8841                if wait_all_ranks:
8842                    rank_str = ", ".join(
8843                        [str(i) for i in range(1, int(self.world_size))]
8844                    )
8845                    err_regex = f"Ranks {rank_str} failed to pass monitoredBarrier"
8846                else:
8847                    expected_first_fail_rank = 1
8848                    err_regex = f"Rank {expected_first_fail_rank} failed to pass monitoredBarrier"
8849                monitored_barrier_timeout_seconds = timedelta(seconds=0.1)
8850                with self.assertRaisesRegex(RuntimeError, err_regex):
8851                    gloo_pg.monitored_barrier(
8852                        monitored_barrier_timeout_seconds, wait_all_ranks=wait_all_ranks
8853                    )
8854
8855            self._barrier(timeout=30)
8856
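        # Illustrative sketch (hypothetical helper, not invoked by any test):
        # the side-channel pattern exercised above -- keep a gloo group
        # alongside the NCCL group and run monitored_barrier on it to detect
        # desynced ranks, since NCCL itself cannot report which rank is stuck.
        def _example_gloo_watchdog_barrier(self):
            gloo_pg = dist.new_group(backend=dist.Backend.GLOO)
            # Raises on rank 0 with the list of ranks that never arrived.
            gloo_pg.monitored_barrier(
                timedelta(seconds=1), wait_all_ranks=True
            )
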
8857        @with_nccl_blocking_wait
8858        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
8859        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
8860        def test_monitored_barrier_allreduce_hang(self):
8861            # Tests expected behavior when a nonzero rank hangs and we want to
8862            # report the first timed-out rank.
8863            self._test_monitored_barrier_allreduce_hang(wait_all_ranks=False)
8864
8865        @with_nccl_blocking_wait
8866        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
8867        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
8868        def test_monitored_barrier_allreduce_hang_wait_all_ranks(self):
8869            # Need to disable TORCH_NCCL_DUMP_ON_TIMEOUT otherwise this test times out
8870            os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "0"
8871            # Tests expected behavior when a nonzero rank hangs and we want to
8872            # report all timed-out ranks.
8873            self._test_monitored_barrier_allreduce_hang(wait_all_ranks=True)
8874
8875        @require_backend_is_available({"gloo"})
8876        def test_monitored_barrier_gloo_rank_0_timeout(self):
8877            # Tests the error raised when rank 0 exhausts its given timeout.
8878            process_group = dist.new_group(ranks=list(range(int(self.world_size))))
8879            timeout = timedelta(seconds=0)
8880            if self.rank == 0:
8881                with self.assertRaisesRegex(
8882                    RuntimeError, f"Rank {self.rank} timed out in monitoredBarrier"
8883                ):
8884                    process_group.monitored_barrier(timeout)
8885
8886        @require_backend_is_available({"gloo"})
8887        @skip_if_small_worldsize
8888        @skip_but_pass_in_sandcastle_if(
8889            IS_MACOS or IS_WINDOWS,
8890            "MacOS uses uv transport which does not have as robust error handling as tcp transport",
8891        )
8892        def test_monitored_barrier_failure_order(self):
8893            # Ensure that the first (in sorted order) rank is reported when
8894            # multiple ranks fail to pass the monitored_barrier.
8895            # TODO(#54879): Provide ability to wait and report all failed ranks
8896            expected_first_failed_rank = 2
8897            timeout = timedelta(seconds=2)
8898            src_rank = 0
8899            if self.rank == src_rank:
8900                with self.assertRaisesRegex(
8901                    RuntimeError, f"Rank {expected_first_failed_rank}"
8902                ):
8903                    dist.monitored_barrier(timeout=timeout)
8904            elif self.rank == 1:
8905                err_regex = (
8906                    f"Rank {self.rank} successfully reached monitoredBarrier,"
8907                    f" but received errors while waiting for send/recv from rank"
8908                    f" {src_rank}"
8909                )
8910                with self.assertRaisesRegex(RuntimeError, err_regex):
8911                    dist.monitored_barrier(timeout=timeout)
8912
8913        @require_backend_is_available({"gloo"})
8914        @skip_if_small_worldsize
8915        def test_monitored_barrier_wait_all_ranks(self):
8916            # Tests the simple case where more than one rank does not call into
8917            # monitored_barrier and verifies that all such ranks are reported by rank 0.
8918            if self.rank == 0:
8919                timeout = timedelta(seconds=0.1)
8920                rank_str = ", ".join([str(i) for i in range(1, int(self.world_size))])
8921                err_regex = f"Ranks {rank_str} failed to pass monitoredBarrier"
8922                with self.assertRaisesRegex(RuntimeError, err_regex):
8923                    dist.monitored_barrier(timeout=timeout, wait_all_ranks=True)
8924
8925        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
8926        @with_dist_debug_levels(levels=["INFO"])
8927        @skip_if_lt_x_gpu(2)
8928        def test_ddp_build_debug_param_to_name_mapping(self):
8929            model = TwoLinLayerNet()
8930            net = torch.nn.parallel.DistributedDataParallel(
8931                model.cuda(self.rank),
8932                device_ids=[self.rank],
8933            )
8934            expected_mapping = {0: "a.weight", 1: "b.weight"}
8935            net_params, _ = net._build_params_for_reducer()
8936            param_to_name_mapping = net._build_debug_param_to_name_mapping(net_params)
8937            self.assertDictEqual(expected_mapping, param_to_name_mapping)
8938
8939            # Test when DDP is used with ignored parameters.
8940            model = TwoLinLayerNet()
8941            # Parameters to ignore are in the format {module_name}.{param_name}
8942            params_to_ignore = ["a.weight"]
8943            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
8944                model, params_to_ignore
8945            )
8946            net = torch.nn.parallel.DistributedDataParallel(
8947                model.cuda(self.rank),
8948                device_ids=[self.rank],
8949            )
8950            expected_mapping = {0: "b.weight"}
8951            net_params, _ = net._build_params_for_reducer()
8952            param_to_name_mapping = net._build_debug_param_to_name_mapping(net_params)
8953            self.assertDictEqual(expected_mapping, param_to_name_mapping)
8954
8955            # Test errors are raised when DDP and module parameters mismatch.
8956            # This generally indicates a bug with DDP and is not expected to
8957            # happen in user applications.
8958            model = TwoLinLayerNet()
8959            net = torch.nn.parallel.DistributedDataParallel(
8960                model.cuda(self.rank),
8961                device_ids=[self.rank],
8962            )
8963            net_params, _ = net._build_params_for_reducer()
8964            if self.rank == 0:
8965                print(type(net_params[0]))
8966
8967            net_params.extend(
8968                [
8969                    torch.nn.Parameter(torch.ones(1)),
8970                    torch.nn.Parameter(torch.ones(1)),
8971                ]
8972            )
8973
8974            with self.assertRaisesRegex(ValueError, "Expected param to name mapping"):
8975                net._build_debug_param_to_name_mapping(net_params)
8976
8977            net_params = net_params[:-3]
8978            with self.assertRaisesRegex(ValueError, "Param with name"):
8979                net._build_debug_param_to_name_mapping(net_params)
8980
8981            net_params.extend(
8982                [
8983                    torch.nn.Parameter(torch.ones(1)),
8984                    torch.nn.Parameter(torch.ones(1)),
8985                ]
8986            )
8987
8988        @skip_but_pass_in_sandcastle_if(
8989            BACKEND not in DistTestCases.backend_feature["ddp"],
8990            f"The {BACKEND} backend does not support DistributedDataParallel",
8991        )
8992        @with_dist_debug_levels(levels=["INFO"])
8993        @skip_if_lt_x_gpu(2)
8994        def test_ddp_build_debug_param_to_name_mapping_requires_grad(self):
8995            class Net(nn.Module):
8996                def __init__(self) -> None:
8997                    super().__init__()
8998                    self.lin = nn.Linear(10, 10)
8999                    # Is not tracked by DDP and should not show up in param to
9000                    # The bias is not tracked by DDP and should not show up in the
9001                    # param-to-name mapping.
9002
9003                def forward(self, x):
9004                    return self.lin(x)
9005
9006            model = Net()
9007            net = torch.nn.parallel.DistributedDataParallel(
9008                model.cuda(self.rank), device_ids=[self.rank]
9009            )
9010            expected_mapping = {
9011                0: "lin.weight",
9012            }
9013            net_params, _ = net._build_params_for_reducer()
9014            param_to_name_mapping = net._build_debug_param_to_name_mapping(net_params)
9015            self.assertEqual(param_to_name_mapping, expected_mapping)
9016
9017        def _test_ddp_multiple_nested_unused_params_error(self, ignore_sparse):
9018            debug_mode_off = dist.get_debug_level() == dist.DebugLevel.OFF
9019
9020            class SubModule(nn.Module):
9021                def __init__(self) -> None:
9022                    super().__init__()
9023                    self.embedding_net = EmbeddingNetDifferentParams(0)
9024                    self.lin = TwoLinLayerNet()
9025                    self.bn = BatchNormNet()
9026                    self.lin_layer = nn.Linear(4, 10, bias=False)
9027
9028                def forward(self, x):
9029                    x = self.bn(x)
9030                    x = self.lin_layer(x)
9031                    x = self.lin.a(x)  # self.lin.b param unused
9032                    # EmbeddingNetDifferentParams entirely unused: self.embedding_net.embedding and
9033                    # EmbeddingNetDifferentParams is entirely unused: both
9034                    # self.embedding_net.embedding and self.embedding_net.lin are unused.
9035
9036            class MyModel(nn.Module):
9037                def __init__(self) -> None:
9038                    super().__init__()
9039                    self.sub_module = SubModule()
9040
9041                def forward(self, x):
9042                    return self.sub_module(x)
9043
9044            model = MyModel()
9045            sparse_embedding_fqns = []
9046            if ignore_sparse:
9047                for module_name, module in model.named_modules():
9048                    if module == model.sub_module.embedding_net.embedding:
9049                        for parameter_name, param in module.named_parameters(
9050                            recurse=False
9051                        ):
9052                            fqn = f"{module_name}.{parameter_name}"
9053                            sparse_embedding_fqns.append(fqn)
9054
9055                torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
9056                    model, sparse_embedding_fqns
9057                )
9058                unused_modules = [
9059                    model.sub_module.embedding_net.lin,
9060                    model.sub_module.lin.b,
9061                ]
9062            else:
9063                unused_modules = list(model.sub_module.embedding_net.modules()) + [
9064                    model.sub_module.lin.b,
9065                ]
9066
9067            expected_unused_param_fqns = []
9068            used_param_fqns = []  # Validate that these don't mistakenly show up.
9069            fqn_to_param_index = {}
9070            index = 0
9071            for module_name, module in model.named_modules():
9072                for parameter_name, param in module.named_parameters(recurse=False):
9073                    fqn = f"{module_name}.{parameter_name}"
9074                    fqn_to_param_index[fqn] = index
9075                    if fqn not in sparse_embedding_fqns:
9076                        index += 1
9077                    if module in unused_modules:
9078                        expected_unused_param_fqns.append(fqn)
9079                    else:
9080                        if (
9081                            not ignore_sparse
9082                            or module != model.sub_module.embedding_net.embedding
9083                        ):
9084                            used_param_fqns.append(fqn)
9085
9086            net = torch.nn.parallel.DistributedDataParallel(
9087                model.cuda(self.rank),
9088                device_ids=[self.rank],
9089            )
9090            batch, dim = 10, 2
9091            inp = torch.ones(batch, dim)
9092            for i in range(2):
9093                if i == 0:
9094                    out = net(inp)
9095                    loss = out.sum()
9096                    loss.backward()
9097                else:
9098                    try:
9099                        out = net(inp)
9100                        loss = out.sum()
9101                        loss.backward()
9102                    except RuntimeError as e:
9103                        e = str(e)
9104
9105                        unused_param_substr = e[e.find("did not receive grad") :]
9106                        # Validate that each unused param's fully qualified name
9107                        # shows up in the error logs. We do this instead of
9108                        # constructing a joined string since the order of parameters
9109                        # can differ in the Reducer. In addition, validate that the
9110                        # param indices show up as well.
9111                        for unused_param_fqn in expected_unused_param_fqns:
9112                            self.assertTrue(
9113                                unused_param_fqn in unused_param_substr
9114                                or debug_mode_off
9115                            )
9116                            self.assertTrue(
9117                                str(fqn_to_param_index[unused_param_fqn])
9118                                in unused_param_substr,
9119                                f"Did not find index {fqn_to_param_index[unused_param_fqn]} for {unused_param_fqn}",
9120                            )
9121
9122                        # Validate that used param fqns don't show up in error
9123                        # logs.
9124                        for used_param_fqn in used_param_fqns:
9125                            self.assertFalse(used_param_fqn in unused_param_substr)
9126                        # Validate that ignored param fqns don't show up as unused
9127                        # (since DDP does not track them)
9128                        for sparse_param_fqn in sparse_embedding_fqns:
9129                            self.assertFalse(sparse_param_fqn in unused_param_substr)
9130                    else:
9131                        self.assertTrue(False, "Expected error was not raised!")
9132
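        # Illustrative sketch (hypothetical helper, not invoked by any test):
        # building the "{module_name}.{param_name}" fully qualified names used
        # above and asking DDP to ignore them before wrapping the model.
        def _example_ignore_params_by_fqn(self, model, submodule_to_ignore):
            fqns = [
                f"{module_name}.{param_name}"
                for module_name, module in model.named_modules()
                if module is submodule_to_ignore
                for param_name, _ in module.named_parameters(recurse=False)
            ]
            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
                model, fqns
            )
            return fqns
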
9133        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
9134        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
9135        @skip_if_lt_x_gpu(2)
9136        def test_ddp_multiple_nested_unused_params_error(self):
9137            self._test_ddp_multiple_nested_unused_params_error(ignore_sparse=False)
9138
9139        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
9140        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
9141        @skip_if_lt_x_gpu(2)
9142        def test_ddp_multiple_nested_unused_params_err_ignore_params(self):
9143            # Tests unused parameter reporting when DDP is configured to ignore
9144            # certain parameters.
9145            self._test_ddp_multiple_nested_unused_params_error(ignore_sparse=True)
9146
9147        @skip_but_pass_in_sandcastle_if(
9148            BACKEND not in DistTestCases.backend_feature["ddp"],
9149            f"The {BACKEND} backend does not support DistributedDataParallel",
9150        )
9151        @skip_if_lt_x_gpu(2)
9152        def test_ddp_inference(self):
9153            # Tests that a DDP module can be run on a single node in no_grad
9154            # or eval mode without hanging.
9155            rank = self.rank
9156            torch.cuda.set_device(rank)
9157            model = Net().cuda()
9158            local_model = copy.deepcopy(model)
9159            model = torch.nn.parallel.DistributedDataParallel(
9160                model,
9161                device_ids=[rank],
9162            )
9163            syncbn_model = nn.SyncBatchNorm(
9164                2, momentum=0.99, track_running_stats=False
9165            ).cuda()
9166            local_syncbn_model = copy.deepcopy(syncbn_model)
9167            syncbn_model = torch.nn.parallel.DistributedDataParallel(
9168                syncbn_model, device_ids=[rank]
9169            )
9170            inp = torch.randn(10, 2, device=rank)
9171            inp_syncbn = torch.randn(10, 2, 4, 4, device=rank)
9172            tests = [
9173                (model, local_model, inp),
9174                (syncbn_model, local_syncbn_model, inp_syncbn),
9175            ]
9176            for test in tests:
9177                test_model, test_local_model, test_inp = test
9178                if self.rank == 0:
9179                    test_model.eval()
9180                    test_local_model.eval()
9181                    for _ in range(6):
9182                        self.assertEqual(
9183                            test_model(test_inp), test_local_model(test_inp)
9184                        )
9185
9186            # Barrier since only rank 0 runs inference. Test should be
9187            # much faster than 30s, but this is to avoid flakiness.
9188            self._barrier(timeout=30)
9189
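        # Illustrative sketch (hypothetical helper, not invoked by any test):
        # the no_grad variant of the single-rank inference pattern above, which
        # avoids building autograd state and therefore never triggers DDP's
        # gradient synchronization.
        def _example_ddp_no_grad_inference(self, ddp_model, inp):
            ddp_model.eval()
            with torch.no_grad():
                return ddp_model(inp)
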
9190        @skip_but_pass_in_sandcastle_if(
9191            BACKEND not in DistTestCases.backend_feature["ddp"],
9192            f"The {BACKEND} backend does not support DistributedDataParallel",
9193        )
9194        @skip_if_lt_x_gpu(2)
9195        @unittest.skip("Test is failing, see https://github.com/pytorch/pytorch/pull/113620")
9196        def test_ddp_sync_bn_training_vs_eval(self):
9197            rank = self.rank
9198            torch.cuda.set_device(rank)
9199            # Need to set track_running_stats=False; when track_running_stats=True,
9200            # bn_training is False and sync would not occur in eval mode.
9201            model = nn.SyncBatchNorm(2, momentum=0.99, track_running_stats=False).cuda(
9202                rank
9203            )
9204            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
9205            # Test sync occurs in training mode.
9206            with torch.autograd.profiler.profile() as prof:
9207                for i in range(6):
9208                    inp = torch.randn(10, 2, 4, 4).cuda(rank)
9209                    out = model(inp)
9210                    loss = out.sum()
9211                    loss.backward()
9212
9213            # SyncBN allgathers stats across all ranks, so verify call to
9214            # all_gather in profiler.
9215            if BACKEND == "nccl":
9216                all_gather_calls = get_profiling_event("_all_gather_base", prof)
9217            else:
9218                all_gather_calls = get_profiling_event("all_gather", prof)
9219            self.assertNotEqual([], all_gather_calls)
9220
9221            # Only do inference on one rank. If SyncBN did collective stats sync,
9222            # this would hang/error.
9223            model_inference = model.module
9224            if self.rank == 0:
9225                model_inference.eval()
9226                with torch.autograd.profiler.profile() as prof:
9227                    for i in range(6):
9228                        inp = torch.randn(10, 2, 4, 4).cuda(rank)
9229                        out = model_inference(inp)
9230                        loss = out.sum()
9231                        loss.backward()
9232
9233                # Ensure sync does not occur in eval() mode.
9234                if BACKEND == "nccl":
9235                    all_gather_calls = get_profiling_event("_all_gather_base", prof)
9236                else:
9237                    all_gather_calls = get_profiling_event("all_gather", prof)
9238                self.assertEqual([], all_gather_calls)
9239
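        # Illustrative sketch (hypothetical helper, not invoked by any test):
        # the usual way to obtain SyncBatchNorm layers like the one tested
        # above is to convert an existing model's BatchNorm layers before
        # wrapping it with DDP.
        def _example_convert_syncbn(self, model):
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
            return torch.nn.parallel.DistributedDataParallel(
                model.cuda(self.rank), device_ids=[self.rank]
            )
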
9240        @skip_if_lt_x_gpu(2)
9241        @skip_but_pass_in_sandcastle_if(
9242            BACKEND not in DistTestCases.backend_feature["ddp"],
9243            f"The {BACKEND} backend does not support DistributedDataParallel",
9244        )
9245        def test_ddp_python_error_logged(self):
9246            # Most Python exceptions in DDP are raised during init before the
9247            # reducer is constructed, so we don't have a logger in those cases.
9248            # However, the below is one example where a Python error is thrown
9249            # after the reducer is constructed.
9250            model = TwoLinLayerNet().cuda(self.rank)
9251            model = torch.nn.parallel.DistributedDataParallel(
9252                model,
9253                device_ids=[self.rank],
9254            )
9255            expected_err = "must be callable"
9256            with self.assertRaisesRegex(TypeError, expected_err):
9257                model.register_comm_hook({}, {})
9258
9259            verify_ddp_error_logged(model, expected_err)
9260
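        # Illustrative sketch (hypothetical helper, not invoked by any test):
        # the comm-hook registration that the error case above gets wrong,
        # done correctly with the built-in allreduce hook imported at the top
        # of this file. Passing state=None falls back to the default group.
        def _example_valid_comm_hook(self, ddp_model):
            ddp_model.register_comm_hook(state=None, hook=default.allreduce_hook)
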
9261        @skip_if_lt_x_gpu(2)
9262        @skip_but_pass_in_sandcastle_if(
9263            BACKEND not in DistTestCases.backend_feature["ddp"],
9264            f"The {BACKEND} backend does not support DistributedDataParallel",
9265        )
9266        def test_ddp_static_graph_nested_types(self):
9267            # Tests for static graph training when outputs are not just tensors
9268            # but can be (nested) tuple, list, dict, etc.
9269            rank = self.rank
9270            torch.cuda.set_device(rank)
9271
9272            class NestedOutputModule(torch.nn.Module):
9273                def __init__(self) -> None:
9274                    super().__init__()
9275                    self.lin = nn.Linear(100, 1, bias=False)
9276
9277                def forward(self, inp, output_type):
9278                    if output_type == "tuple":
9279                        return (
9280                            self.lin(inp),
9281                            (
9282                                self.lin(inp),
9283                                self.lin(inp),
9284                            ),
9285                        )
9286                    elif output_type == "list":
9287                        return [
9288                            self.lin(inp),
9289                            [
9290                                self.lin(inp),
9291                                self.lin(inp),
9292                            ],
9293                        ]
9294                    elif output_type == "dict":
9295                        return {
9296                            "a": self.lin(inp),
9297                            "b": {
9298                                "c": self.lin(inp),
9299                            },
9300                        }
9301
9302            def get_loss(model_output):
9303                loss = 0.0
9304                if isinstance(model_output, torch.Tensor):
9305                    return model_output.sum()
9306                elif isinstance(model_output, dict):
9307                    for value in model_output.values():
9308                        loss += get_loss(value)
9309                elif isinstance(model_output, (tuple, list)):
9310                    for x in model_output:
9311                        loss += get_loss(x)
9312                else:
9313                    raise ValueError(f"Unknown model output type {type(model_output)}")
9314                return loss
9315
9316            model = NestedOutputModule().cuda(rank)
9317            model_static_graph = copy.deepcopy(model)
9318            model = torch.nn.parallel.DistributedDataParallel(
9319                model,
9320                device_ids=[rank],
9321            )
9322            model_static_graph = torch.nn.parallel.DistributedDataParallel(
9323                model_static_graph,
9324                device_ids=[rank],
9325                static_graph=True,
9326            )
9327            inp = torch.randn(10, 100)
9328            type_mapping = {
9329                "list": list,
9330                "tuple": tuple,
9331                "dict": dict,
9332            }
9333            for output_type in type_mapping.keys():
9334                for i in range(6):
9335                    out = model(inp, output_type=output_type)
9336                    loss = get_loss(out)
9337                    loss.backward()
9338                    self._model_step(model)
9339                    out_static = model_static_graph(inp, output_type=output_type)
9340                    self.assertTrue(isinstance(out_static, type_mapping[output_type]))
9341                    loss_static = get_loss(out_static)
9342                    loss_static.backward()
9343                    self._model_step(model_static_graph)
9344                    for (p, p_static) in zip(
9345                        model.parameters(), model_static_graph.parameters()
9346                    ):
9347                        self.assertEqual(p, p_static)
9348
9349        @skip_if_lt_x_gpu(2)
9350        @skip_but_pass_in_sandcastle_if(
9351            BACKEND not in DistTestCases.backend_feature["ddp"],
9352            f"The {BACKEND} backend does not support DistributedDataParallel",
9353        )
9354        def test_ddp_returns_tensor_with_no_grad(self):
9355            # Tests the case where the module returns a tensor that does not require grad.
9356            torch.cuda.set_device(self.rank)
9357
9358            class MyModel(nn.Module):
9359                def __init__(self) -> None:
9360                    super().__init__()
9361                    self.fc1 = nn.Linear(10, 10, bias=False)
9362                    self.fc2 = nn.Linear(10, 10, bias=False)
9363
9364                def forward(self, x):
9365                    x = self.fc2(F.relu(self.fc1(x)))
9366                    y = x.clone()
9367                    x = x.detach()
9368                    assert not x.requires_grad
9369                    return (x, y)
9370
9371            model = MyModel().to(self.rank)
9372            inp = torch.randn(1, 10, device=self.rank)
9373            for (find_unused, static_graph) in itertools.product(
9374                [True, False], [True, False]
9375            ):
9376                ddp = DistributedDataParallel(
9377                    model,
9378                    device_ids=[self.rank],
9379                    output_device=self.rank,
9380                    find_unused_parameters=find_unused,
9381                    static_graph=static_graph,
9382                )
9383                for i in range(6):
9384                    out = ddp(inp)
9385                    self.assertFalse(out[0].requires_grad)
9386                    o = (out[0] + out[1]).sum()
9387                    o.backward()
9388
9389        @skip_if_lt_x_gpu(2)
9390        @skip_but_pass_in_sandcastle_if(
9391            BACKEND not in DistTestCases.backend_feature["ddp"],
9392            f"The {BACKEND} backend does not support DistributedDataParallel",
9393        )
9394        def test_detect_ddp_is_actually_static(self):
9395            class ToyModel(nn.Module):
9396                def __init__(self) -> None:
9397                    super().__init__()
9398                    self.net1 = nn.Linear(10, 10, bias=False)
9399                    self.net2 = nn.Linear(10, 10)
9400
9401                def forward(self, x, find_unused, dynamic):
9402                    if find_unused:
9403                        if dynamic:
9404                            return self.net2(self.net1(x))
9405                        else:
9406                            return self.net2(x)
9407                    else:
9408                        return self.net2(self.net1(x))
9409
9410            # The set of unused parameters does not change across iterations.
9411            torch.cuda.set_device(self.rank)
9412            model = ToyModel().cuda()
9413            for find_unused in [True, False]:
9414                ddp = torch.nn.parallel.DistributedDataParallel(
9415                    model,
9416                    device_ids=[self.rank],
9417                    find_unused_parameters=find_unused,
9418                )
9419                inp = torch.randn(1, 10, device="cuda")
9420                for _ in range(6):
9421                    out = ddp(inp, find_unused=find_unused, dynamic=False)
9422                    loss = out.sum()
9423                    loss.backward()
9424                    self.assertTrue(ddp.reducer._ddp_graph_static())
9425
9426            # The set of unused parameters changes dynamically across iterations.
9427            ddp = torch.nn.parallel.DistributedDataParallel(
9428                model,
9429                device_ids=[self.rank],
9430                find_unused_parameters=True,
9431            )
9432            inp = torch.randn(1, 10, device="cuda")
9433            for i in range(6):
9434                out = ddp(inp, find_unused=True, dynamic=i % 2 == 0)
9435                loss = out.sum()
9436                loss.backward()
9437            self.assertFalse(ddp.reducer._ddp_graph_static())
9438
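        # Illustrative sketch (hypothetical helper, not invoked by any test):
        # when the reducer reports a static graph as asserted above, the usual
        # follow-up is to construct DDP with static_graph=True so it can skip
        # the per-iteration unused-parameter search.
        def _example_static_graph_ddp(self, model):
            return torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.rank],
                static_graph=True,
            )
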
9439        def _test_ddp_new_tensor_in_fwd(self, static_graph):
9440            # Test from https://github.com/pytorch/pytorch/issues/60733
9441            class MyModel(nn.Module):
9442                def __init__(self) -> None:
9443                    super().__init__()
9444                    self.fc1 = nn.Linear(10, 10, bias=False)
9445                    self.fc2 = nn.Linear(10, 10, bias=False)
9446                    self.device = self.fc1.weight.device
9447
9448                def __init_opt(self):
9449                    opt = torch.randn(1, 10, device=self.device)
9450                    return opt
9451
9452                def forward(self, x, opt_1, opt_2, opt_nested):
9453                    x = F.relu(self.fc1(x))
9454                    x = self.fc2(x)
9455                    if opt_1 is None:
9456                        opt_1 = self.__init_opt()
9457                    if opt_2 is None:
9458                        opt_2 = self.__init_opt()
9459                    if opt_nested is None or not torch.is_tensor(opt_nested):
9460                        opt_nested = self.__init_opt()
9461                    # Test multiple tensors as well as newly created tensors
9462                    # within a struct.
9463                    return x, opt_1, opt_2, {"tensor": opt_nested}
9464
9465            model = MyModel().to(self.rank)
9466            for find_unused in [True, False]:
9467                ddp = DistributedDataParallel(
9468                    model,
9469                    device_ids=[self.rank],
9470                    output_device=self.rank,
9471                    broadcast_buffers=False,
9472                    find_unused_parameters=find_unused,
9473                    static_graph=static_graph,
9474                )
9475
9476                opt = [None for _ in range(3)]
9477                for i in range(2):
9478                    ddp.zero_grad()
9479                    x = torch.randn(1, 10, device=self.rank)
9480                    out, opt[0], opt[1], opt[2] = ddp(
9481                        x, opt_1=opt[0], opt_2=opt[1], opt_nested=opt[2]
9482                    )
9483                    for i in range(len(opt)):
9484                        if torch.is_tensor(opt[i]):
9485                            self.assertEqual(opt[i].grad_fn, None)
9486                        else:
9487                            self.assertEqual(opt[i]["tensor"].grad_fn, None)
9488                    out.mean().backward()
9489
9490        @skip_if_lt_x_gpu(2)
9491        @skip_but_pass_in_sandcastle_if(
9492            BACKEND not in DistTestCases.backend_feature["ddp"],
9493            f"The {BACKEND} backend does not support DistributedDataParallel",
9494        )
9495        def test_ddp_new_tensor_in_fwd(self):
9496            return self._test_ddp_new_tensor_in_fwd(static_graph=False)
9497
9498        @skip_if_lt_x_gpu(2)
9499        @skip_but_pass_in_sandcastle_if(
9500            BACKEND not in DistTestCases.backend_feature["ddp"],
9501            f"The {BACKEND} backend does not support DistributedDataParallel",
9502        )
9503        def test_ddp_new_tensor_in_fwd_static_graph(self):
9504            return self._test_ddp_new_tensor_in_fwd(static_graph=True)
9505
9506        def _test_ddp_buffer_hook_allreduce(self, return_futures):
9507            rank = self.rank
9508            torch.cuda.set_device(rank)
9509            torch.manual_seed(rank)
9510            torch.cuda.manual_seed(rank)
9511
9512            def buffer_comm_hook(ddp, named_buffers):
9513                buffers = [buffer for (_, buffer) in named_buffers.items()]
9514                futs = [
9515                    dist.all_reduce(
9516                        buffer, group=ddp.process_group, async_op=True
9517                    ).get_future()
9518                    for buffer in buffers
9519                ]
9520                if return_futures:
9521                    return futs
9522                else:
9523                    torch.futures.collect_all(futs).wait()
9524
9525            hook_pre_fwd = (
9526                torch.nn.parallel.distributed._BufferCommHookLocation.PRE_FORWARD
9527            )
9528            hook_post_fwd = (
9529                torch.nn.parallel.distributed._BufferCommHookLocation.POST_FORWARD
9530            )
9531            for hook_run_location in [
9532                hook_pre_fwd,
9533                hook_post_fwd,
9534            ]:
9535                model = NetWithBuffers().cuda(rank)
9536                model_ddp = torch.nn.parallel.DistributedDataParallel(
9537                    model,
9538                    device_ids=[self.rank],
9539                )
9540                model_ddp._register_buffer_comm_hook(
9541                    model_ddp, buffer_comm_hook, hook_run_location
9542                )
9543                model_ddp_no_hook = torch.nn.parallel.DistributedDataParallel(
9544                    copy.deepcopy(model),
9545                    device_ids=[self.rank],
9546                    broadcast_buffers=False,
9547                )
9548                inp = torch.randn(2, 10, device=rank)
9549                for i in range(2):
9550                    loss_hook = model_ddp(inp).sum()
9551                    # Simulate the buffer allreduce for the no-hook model at the point
9552                    # matching the hook's run location (pre- or post-forward), so both
9553                    # models see equivalent buffer updates.
9554                    if hook_run_location == hook_pre_fwd:
9555                        model_no_hook_buffers = list(model_ddp_no_hook.module.buffers())
9556                        for tensor in model_no_hook_buffers:
9557                            dist.all_reduce(tensor)
9558
9559                    loss_no_hook = model_ddp_no_hook(inp).sum()
9560                    if hook_run_location == hook_post_fwd:
9561                        model_no_hook_buffers = list(model_ddp_no_hook.module.buffers())
9562                        for tensor in model_no_hook_buffers:
9563                            dist.all_reduce(tensor)
9564                    torch.cuda.synchronize()
9565
9566                    # If return_futures, the futures are only awaited by DDP
9567                    # at the end of the backward pass for maximum overlap.
9568                    if not return_futures:
9569                        self._verify_buffers_equal(model_ddp, model_ddp_no_hook)
9570                    loss_hook.backward()
9571                    loss_no_hook.backward()
9572                    # Note that when custom hooks return futures, this
9573                    # comparison is not expected to work when the hook run location
9574                    # is the pre-forward pass. This is because the hook does async
9575                    # communication and the forward pass modifies the buffers without
9576                    # appropriate synchronization. Therefore, when returning
9577                    # futures from custom buffer hooks, it is advised to set the
9578                    # hook run location to post-forward.
9579                    if return_futures and hook_run_location == hook_post_fwd:
9580                        self._verify_buffers_equal(model_ddp, model_ddp_no_hook)
9581                dist.barrier()
9582
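        # Illustrative sketch (hypothetical helper, not invoked by any test):
        # registering a future-returning buffer hook at the POST_FORWARD
        # location, which the note above recommends when the hook communicates
        # asynchronously. Uses the same private _register_buffer_comm_hook API
        # exercised by the test.
        def _example_post_forward_buffer_hook(self, model_ddp):
            def async_buffer_allreduce(ddp, named_buffers):
                # Return futures; DDP awaits them at the end of backward.
                return [
                    dist.all_reduce(
                        buf, group=ddp.process_group, async_op=True
                    ).get_future()
                    for buf in named_buffers.values()
                ]

            model_ddp._register_buffer_comm_hook(
                model_ddp,
                async_buffer_allreduce,
                torch.nn.parallel.distributed._BufferCommHookLocation.POST_FORWARD,
            )
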
9583        @skip_if_lt_x_gpu(2)
9584        @skip_but_pass_in_sandcastle_if(
9585            BACKEND not in DistTestCases.backend_feature["ddp"],
9586            f"The {BACKEND} backend does not support DistributedDataParallel",
9587        )
9588        def test_ddp_buffer_hook_allreduce_return_future(self):
9589            self._test_ddp_buffer_hook_allreduce(return_futures=True)
9590
9591        @skip_if_lt_x_gpu(2)
9592        @skip_but_pass_in_sandcastle_if(
9593            BACKEND not in DistTestCases.backend_feature["ddp"],
9594            f"The {BACKEND} backend does not support DistributedDataParallel",
9595        )
9596        def test_ddp_buffer_hook_allreduce(self):
9597            self._test_ddp_buffer_hook_allreduce(return_futures=False)
9598
9599        @skip_if_lt_x_gpu(2)
9600        @skip_but_pass_in_sandcastle_if(
9601            BACKEND not in DistTestCases.backend_feature["ddp"],
9602            f"The {BACKEND} backend does not support DistributedDataParallel",
9603        )
9604        def test_ddp_broadcast_buffer_via_hook(self):
9605            # Tests that _distributed_broadcast_coalesced via a registered hook is
9606            # equivalent to DDP's default coalesced broadcast.
9607            rank = self.rank
9608            torch.cuda.set_device(rank)
9609            torch.manual_seed(rank)
9610            torch.cuda.manual_seed(rank)
9611
9612            def buffer_comm_hook(ddp, named_buffers):
9613                # named_buffers is a Dict[str, Tensor] representing a mapping
9614                # from buffer name to buffer.
9615                buffers = [buffer for (_, buffer) in named_buffers.items()]
9616                ddp._default_broadcast_coalesced(buffers)
9617
9618            model = NetWithBuffers().cuda(rank)
9619            model_ddp = torch.nn.parallel.DistributedDataParallel(
9620                model,
9621                device_ids=[self.rank],
9622            )
9623            model_ddp._register_buffer_comm_hook(model_ddp, buffer_comm_hook)
9624            model_ddp_no_hook = torch.nn.parallel.DistributedDataParallel(
9625                copy.deepcopy(model),
9626                device_ids=[self.rank],
9627            )
9628            inp = torch.randn(2, 10, device=rank)
9629            for i in range(2):
9630                loss_hook = model_ddp(inp).sum()
9631                loss_no_hook = model_ddp_no_hook(inp).sum()
9632                self._verify_buffers_equal(model_ddp, model_ddp_no_hook)
9633                loss_hook.backward()
9634                loss_no_hook.backward()
9635
9636        @skip_if_lt_x_gpu(2)
9637        @skip_but_pass_in_sandcastle_if(
9638            BACKEND not in DistTestCases.backend_feature["ddp"],
9639            f"The {BACKEND} backend does not support DistributedDataParallel",
9640        )
9641        def test_ddp_remove_autograd_hooks(self):
9642
9643            class SimulateError(torch.autograd.Function):
9644                @staticmethod
9645                def forward(ctx, input):
9646                    return input
9647
9648                @staticmethod
9649                def backward(ctx, grad_output):
9650                    raise RuntimeError
9651
9652            class MyModel(nn.Module):
9653                def __init__(self, device):
9654                    super().__init__()
9655                    self.error = True
9656                    self.fc1 = nn.Linear(10, 10).cuda(device)
9657
9658                def forward(self, inp):
9659                    if self.error:
9660                        return self.fc1(SimulateError.apply(inp))
9661                    else:
9662                        return self.fc1(inp)
9663
9664
9665            # Run with an error to trigger a backward pass that marks fc1 as ready.
9666            # If we don't remove the autograd hooks before constructing the second DDP
9667            # instance below, it would fail on the old autograd hook.
9668            model = MyModel(self.rank)
9669            input = torch.rand(10, 10, requires_grad=True).cuda(self.rank)
9670            model_ddp1 = torch.nn.parallel.DistributedDataParallel(
9671                model,
9672                device_ids=[self.rank],
9673            )
9674
9675            with self.assertRaises(RuntimeError):
9676                model_ddp1(input).sum().backward()
9677
9678            # Remove autograd hooks on old instance.
9679            model_ddp1._remove_autograd_hooks()
9680
9681            # Try another DDP instance without error now.
9682            model.error = False
9683            model_ddp2 = torch.nn.parallel.DistributedDataParallel(
9684                model,
9685                device_ids=[self.rank],
9686            )
9687            model_ddp2(input).sum().backward()
9688
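        # Illustrative sketch (hypothetical helper, not invoked by any test):
        # the re-wrap pattern the test above relies on -- drop the old DDP
        # instance's autograd hooks before wrapping the same module again.
        def _example_rewrap_ddp(self, old_ddp, module):
            old_ddp._remove_autograd_hooks()
            return torch.nn.parallel.DistributedDataParallel(
                module, device_ids=[self.rank]
            )
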
9689        @skip_if_lt_x_gpu(2)
9690        @skip_but_pass_in_sandcastle_if(
9691            BACKEND not in DistTestCases.backend_feature["ddp"],
9692            f"The {BACKEND} backend does not support DistributedDataParallel",
9693        )
9694        @unittest.skip("Test is failing, tracking issue at https://github.com/pytorch/pytorch/issues/102751")
9695        def test_ddp_has_finalized(self):
9696
9697            @dataclass
9698            class MyClass:
9699                obj: torch.Tensor
9700
9701            class MyModel(nn.Module):
9702                def __init__(self, rank):
9703                    super().__init__()
9704                    self.rank = rank
9705                    self.fc1 = nn.Linear(1024, 1024).cuda(rank)
9706                    self.fc2 = nn.Linear(1024, 2 * 1024).cuda(rank)
9707
9708                def forward(self, inp):
9709                    if self.rank == 0:
9710                        return self.fc1(inp), MyClass(self.fc2(inp))
9711                    else:
9712                        return self.fc1(inp), self.fc2(inp)
9713
9714            model = MyModel(self.rank)
9715            input = torch.rand(10, 1024, requires_grad=True).cuda(self.rank)
9716            ddp = torch.nn.parallel.DistributedDataParallel(
9717                model,
9718                device_ids=[self.rank],
9719                find_unused_parameters=True,
9720                bucket_cap_mb=(1024 * 4 / 1024 / 1024),  # One bucket per parameter.
9721            )
9722
9723            if self.rank == 0:
9724                out1, _ = ddp(input)
9725                out1.sum().backward()
9726            else:
9727                out1, out2 = ddp(input)
9728                (out1.sum() + out2.sum()).backward()
9729
9730            if self.rank == 0:
9731                with self.assertRaisesRegex(RuntimeError, "Expected to have finished reduction in the prior iteration"):
9732                    ddp._check_reducer_finalized()
9733
9734                with self.assertRaisesRegex(RuntimeError, "Expected to have finished reduction in the prior iteration"):
9735                    ddp(input)
9736            else:
9737                ddp._check_reducer_finalized()
9738                ddp(input)
9739
9740        @skip_if_lt_x_gpu(2)
9741        @skip_but_pass_in_sandcastle_if(
9742            BACKEND != "nccl",
9743            "TORCH_NCCL_USE_COMM_NONBLOCKING only applies to NCCL"
9744        )
9745        def test_nccl_init_abort(self):
9746            """
9747            Tests that we can abort a NCCL communicator during initialization and
9748            recover appropriately.
9749            """
9750            # Reinitialize global process group with TORCH_NCCL_USE_COMM_NONBLOCKING=1
9751            os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
9752            dist.destroy_process_group()
9753            timeout = timedelta(seconds=1)
9754            dist.init_process_group(
9755                init_method=INIT_METHOD,
9756                backend=BACKEND,
9757                world_size=int(os.environ["WORLD_SIZE"]),
9758                rank=self.rank,
9759                timeout=timeout,
9760            )
9761
9762            # Abort pg in background thread.
9763            running = True
9764
9765            def abort(device):
9766                pg = _get_default_group()
9767                while running:
9768                    pg._get_backend(torch.device(device))._shutdown()
9769                    time.sleep(1)
9770
9771            if self.rank != 1:
9772                import threading
9773                t = threading.Thread(target=abort, args=(self.rank,))
9774                t.start()
9775                with self.assertRaises(RuntimeError):
9776                    # First collective triggers initialization via ncclCommInitRank.
9777                    torch.distributed.barrier()
9778                running = False
9779                t.join()
9780
9781        def _run_ddp_update_process_group(self, new_pg):
9782            def get_num_torch_recompiles():
9783                guard_failures = torch._dynamo.utils.guard_failures
9784                num_recompiles = [len(guard_failures[code]) for code in guard_failures]
9785                return 0 if len(num_recompiles) == 0 else max(num_recompiles)
9786
9787            class SimulateError(torch.autograd.Function):
9788                @staticmethod
9789                def forward(ctx, input):
9790                    return input
9791
9792                @staticmethod
9793                def backward(ctx, grad_output):
9794                    raise RuntimeError
9795
9796            class MyModel(torch.nn.Module):
9797                def __init__(self, device):
9798                    super().__init__()
9799                    # 4MB for multiple buckets.
9800                    # Each 1024x1024 fp32 layer is 4 MB, so bucket_cap_mb=1 below forces multiple buckets.
9801                    self.fc2 = torch.nn.Linear(1024, 1024).cuda(device)
9802                    self.fc3 = torch.nn.Linear(1024, 1024).cuda(device)
9803
9804                def forward(self, inp, error):
9805                    if error:
9806                        return self.fc3(self.fc2(self.fc1(SimulateError.apply(inp))))
9807                    else:
9808                        return self.fc3(self.fc2(self.fc1(inp)))
9809
9810
9811            input = torch.rand(10, 1024, requires_grad=True).cuda(self.rank)
9812            ddp = torch.nn.parallel.DistributedDataParallel(
9813                MyModel(self.rank),
9814                device_ids=[self.rank],
9815                find_unused_parameters=True,
9816                bucket_cap_mb=1,
9817            )
9818            model = torch.compile(ddp)
9819
9820            def run_iteration():
9821                # Run regular iteration.
9822                out = model(input, error=False)
9823                out.sum().backward()
9824                torch.cuda.synchronize()
9825
9826                # Run with error.
9827                with self.assertRaises(RuntimeError):
9828                    out = model(input, error=True)
9829                    out.sum().backward()
9830                torch.cuda.synchronize()
9831
9832            run_iteration()
9833            assert 0 == get_num_torch_recompiles()
9834
9835            if new_pg:
9836                # Now reduce world_size and run iteration.
9837                group_size_2 = dist.new_group(ranks=[0, 1])
9838                ddp._update_process_group(group_size_2)
9839                if self.rank in [0, 1]:
9840                    run_iteration()
9841
9842                # Increase the world size and run iteration.
9843                group_size_3 = dist.new_group(ranks=[1, 2, 3])
9844                ddp._update_process_group(group_size_3)
9845                if self.rank in [1, 2, 3]:
9846                    run_iteration()
9847
9848                # Back to default size.
9849                ddp._update_process_group(_get_default_group())
9850                run_iteration()
9851            else:
9852                # Create default pg of smaller size.
9853                dist.destroy_process_group()
9854
9855                if self.rank in [1, 2, 3]:
9856                    dist.init_process_group(
9857                        init_method=self.init_method,
9858                        backend=BACKEND,
9859                        world_size=3,
9860                        rank=self.rank - 1,
9861                        timeout=timedelta(seconds=default_pg_timeout),
9862                    )
9863                    ddp._update_process_group(_get_default_group())
9864                    run_iteration()
9865                    dist.destroy_process_group()
9866
9867                # Need a barrier here to ensure ranks 1, 2 and 3 are done.
9868                self._barrier(wait_for=4)
9869
9870                # Need to init pg again for "_barrier" to succeed.
9871                dist.init_process_group(
9872                    init_method=self.init_method,
9873                    backend=BACKEND,
9874                    world_size=4,
9875                    rank=self.rank,
9876                    timeout=timedelta(seconds=default_pg_timeout),
9877                )
9878
9879            # Validate no more recompiles.
9880            assert 0 == get_num_torch_recompiles()
9881
9882        @skip_if_lt_x_gpu(4)
9883        @require_world_size(4)
9884        @skip_but_pass_in_sandcastle_if(
9885            BACKEND not in DistTestCases.backend_feature["ddp"],
9886            f"The {BACKEND} backend does not support DistributedDataParallel",
9887        )
9888        def test_ddp_update_process_group_new_group(self):
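            """Run _run_ddp_update_process_group with freshly created subgroups of varying sizes."""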
9889            self._run_ddp_update_process_group(new_pg=True)
9890
9891        @skip_if_lt_x_gpu(4)
9892        @require_world_size(4)
9893        @skip_but_pass_in_sandcastle_if(
9894            BACKEND not in DistTestCases.backend_feature["ddp"],
9895            f"The {BACKEND} backend does not support DistributedDataParallel",
9896        )
9897        def test_ddp_update_process_group_default_group(self):
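            """Run _run_ddp_update_process_group by recreating the default group with a smaller world size."""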
9898            self._run_ddp_update_process_group(new_pg=False)
9899
9900        @skip_if_lt_x_gpu(4)
9901        @require_world_size(4)
9902        @skip_but_pass_in_sandcastle_if(
9903            BACKEND not in DistTestCases.backend_feature["ddp"],
9904            f"The {BACKEND} backend does not support DistributedDataParallel",
9905        )
9906        def test_ddp_update_process_group_grad_undefined(self):
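            """Check that _update_process_group works when some parameter gradients stay undefined."""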
9907            class SimulateError(torch.autograd.Function):
9908                @staticmethod
9909                def forward(ctx, input):
9910                    return input
9911
9912                @staticmethod
9913                def backward(ctx, grad_output):
9914                    raise RuntimeError
9915
9916            class MyModel(torch.nn.Module):
9917                def __init__(self, device):
9918                    super().__init__()
9919                    self.fc1 = torch.nn.Linear(10, 10).cuda(device)
9920                    self.fc2 = torch.nn.Linear(10, 10).cuda(device)
9921                    self.fc3 = torch.nn.Linear(10, 10).cuda(device)
9922
9923                def forward(self, inp, error):
9924                    if error:
9925                        return self.fc3(self.fc2(self.fc1(SimulateError.apply(inp))))
9926                    else:
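                        # fc3 is intentionally unused on this path, so its gradient stays undefined.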
9927                        return self.fc2(self.fc1(inp))
9928
9929
9930            input = torch.rand(10, 10, requires_grad=True).cuda(self.rank)
9931            ddp = torch.nn.parallel.DistributedDataParallel(
9932                MyModel(self.rank),
9933                device_ids=[self.rank],
9934                find_unused_parameters=True,
9935                bucket_cap_mb=1,
9936            )
9937
9938            try:
9939                ddp(input, True).sum().backward()
9940            except RuntimeError:
9941                ddp._update_process_group(_get_default_group())
9942
9943            # Reset grads.
9944            for param in ddp.parameters():
9945                param.grad = None
9946
9947            # Run ddp again.
9948            ddp(input, False).sum().backward()
9949
9950        @skip_if_lt_x_gpu(4)
9951        @require_world_size(4)
9952        @skip_but_pass_in_sandcastle_if(
9953            BACKEND not in DistTestCases.backend_feature["ddp"],
9954            f"The {BACKEND} backend does not support DistributedDataParallel",
9955        )
9956        def test_ddp_update_process_group_no_find_unused(self):
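            """Check that _update_process_group works with find_unused_parameters=False."""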
9957            ddp = torch.nn.parallel.DistributedDataParallel(
9958                torch.nn.Linear(10, 10).cuda(self.rank),
9959                device_ids=[self.rank],
9960                find_unused_parameters=False,
9961            )
9962            ddp._update_process_group(_get_default_group())
9963
9964
9965        @skip_if_lt_x_gpu(2)
9966        @skip_but_pass_in_sandcastle_if(
9967            BACKEND not in DistTestCases.backend_feature["ddp"],
9968            f"The {BACKEND} backend does not support DistributedDataParallel",
9969        )
9970        def test_ddp_broadcast_buffer(self):
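            """Check that buffers mutated on rank 0 are broadcast to all ranks on each forward pass."""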
9971            rank = self.rank
9972            torch.cuda.set_device(rank)
9973            torch.manual_seed(rank)
9974            torch.cuda.manual_seed(rank)
9975
9976            class NetWithBuffers(nn.Module):
9977                def __init__(self) -> None:
9978                    super().__init__()
9979                    self.a = nn.Linear(10, 10, bias=False)
9980                    self.b = nn.Linear(10, 1, bias=False)
9981                    self.register_buffer("buffer", torch.randn(1, 2))
9982
9983                def forward(self, x):
9984                    return self.b(self.a(x))
9985
9986            model = NetWithBuffers().cuda(rank)
9987            model_ddp = torch.nn.parallel.DistributedDataParallel(
9988                model,
9989                device_ids=[self.rank],
9990            )
9991            inp = torch.randn(2, 10, device=rank)
9992            for i in range(2):
9993                if rank == 0:
9994                    model_ddp.module.buffer = model_ddp.module.buffer + 1
9995                loss = model_ddp(inp).sum()
9996                loss.backward()
9997                # Ensure all buffers are synchronized.
9998                bufs = [
9999                    torch.empty_like(model_ddp.module.buffer)
10000                    for _ in range(dist.get_world_size())
10001                ]
10002                dist.all_gather(bufs, model_ddp.module.buffer)
10003                rank_0_buf = bufs[0]
10004                for buf in bufs[1:]:
10005                    self.assertEqual(rank_0_buf, buf)
10006
10007        @skip_if_lt_x_gpu(2)
10008        @skip_but_pass_in_sandcastle_if(
10009            BACKEND != "nccl" and BACKEND != "gloo",
10010            "Only NCCL and Gloo backends support DistributedDataParallel",
10011        )
10012        def test_static_graph_multi_forward(self):
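            """Compare static_graph=True DDP grads against a local model when running two forwards per backward."""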
10013            class Net(nn.Module):
10014                def __init__(self) -> None:
10015                    super().__init__()
10016                    self.lin = nn.Linear(10, 10)
10017                    self.relu = nn.ReLU()
10018
10019                def forward(self, x):
10020                    return self.relu(self.lin(x))
10021
10022            torch.cuda.set_device(self.rank)
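            # Seeds intentionally differ per rank; DDP syncs module states from rank 0 at
            # construction, so the replicas still start out identical.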
10023            torch.manual_seed(42 << 1337 % (self.rank + 1))
10024            model = Net().cuda(self.rank)
10025            local_model = copy.deepcopy(model)
10026            model = torch.nn.parallel.DistributedDataParallel(
10027                model, device_ids=[self.rank], static_graph=True
10028            )
10029            inp = torch.ones(2, 10, device="cuda")
10030            for _ in range(3):
10031                model.zero_grad()
10032                local_model.zero_grad()
10033                a = model(inp)
10034                b = model(inp)
10035                loss = a.sum() + b.sum()
10036                loss.backward()
10037                # Grads should match those of a local model that ran inp through twice, with accumulated grads averaged over the world size.
10038                if self.rank == 0:
10039                    inp_clone = inp.clone()
10040                    for _ in range(2):
10041                        a = local_model(inp_clone)
10042                        b = local_model(inp_clone)
10043                        loss = a.sum() + b.sum()
10044                        loss.backward()
10045
10046                    ws = dist.get_world_size()
10047                    for p in local_model.parameters():
10048                        p.grad.data = p.grad / ws
10049
10050                    for p_ddp, p_local in zip(
10051                        model.parameters(),
10052                        local_model.parameters()
10053                    ):
10054                        self.assertTrue(
10055                            torch.allclose(
10056                                p_ddp.grad, p_local.grad
10057                            ),
10058                            f"{p_ddp.grad} vs {p_local.grad}"
10059                        )
10060
10061            dist.barrier()
10062
10063        @skip_if_lt_x_gpu(2)
10064        @skip_but_pass_in_sandcastle_if(
10065            BACKEND != "nccl" and BACKEND != "gloo",
10066            "Only NCCL and Gloo backends support DistributedDataParallel",
10067        )
10068        def test_sync_bn_logged(self):
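            """Verify that DDP logging reports has_sync_bn only after SyncBatchNorm conversion."""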
10069            model = BN_NET
10070            rank = self.rank
10071            # single gpu training setup
10072            model_gpu = model.cuda(rank)
10073            no_sync_bn = torch.nn.parallel.DistributedDataParallel(
10074                copy.deepcopy(model_gpu),
10075                device_ids=[self.rank],
10076            )
10077            ddp_logging_data = no_sync_bn._get_ddp_logging_data()
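            # The fallback defaults passed to .get() are chosen so that a missing
            # "has_sync_bn" key fails the assertions below.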
10078            sync_bn_logged = ddp_logging_data.get("has_sync_bn", True)
10079            self.assertFalse(sync_bn_logged)
10080            model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(model_gpu)
10081            model_DDP = torch.nn.parallel.DistributedDataParallel(
10082                model_DDP,
10083                device_ids=[self.rank],
10084            )
10085            ddp_logging_data = model_DDP._get_ddp_logging_data()
10086            sync_bn_logged = ddp_logging_data.get("has_sync_bn", False)
10087            self.assertTrue(sync_bn_logged)
10088
10089        @skip_if_lt_x_gpu(2)
10090        @skip_but_pass_in_sandcastle_if(
10091            BACKEND not in DistTestCases.backend_feature["ddp"],
10092            f"The {BACKEND} backend does not support DistributedDataParallel",
10093        )
10094        def test_stateless_api_with_ddp(self):
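            """Exercise torch.func.functional_call on a DDP module with swapped-in parameters and buffers."""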
10095            class MockModule(torch.nn.Module):
10096                def __init__(self) -> None:
10097                    super().__init__()
10098                    self.l1 = torch.nn.Linear(1, 1)
10099                    buffer = torch.ones(1)
10100                    self.register_buffer("buffer", buffer)
10101
10102                def forward(self, x):
10103                    return self.l1(x) + self.buffer
10104
10105            device = self.rank
10106            module = MockModule().to(device)
10107            module = torch.nn.parallel.DistributedDataParallel(
10108                module, device_ids=[device]
10109            )
10110            x = torch.rand((1, 1)).to(device)
10111            weight = torch.tensor([[1.0]], device=device, requires_grad=True)
10112            bias = torch.tensor([0.0], device=device, requires_grad=True)
10113            buffer = torch.tensor([0.0], device=device)
10114            parameters = {
10115                "module.l1.weight": weight,
10116                "module.l1.bias": bias,
10117                "module.buffer": buffer,
10118            }
10119            prev_weight = module.module.l1.weight.clone()
10120            prev_buffer = module.module.buffer.clone()
10121
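            # With the swapped-in weight=1, bias=0, and buffer=0, the functional call reduces to the identity.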
10122            res = torch.func.functional_call(module, parameters, x)
10123            self.assertEqual(x, res)
10124            # Check that the original weight and buffer remain unmodified.
10125            cur_weight = module.module.l1.weight
10126            cur_buffer = module.module.buffer
10127            self.assertEqual(cur_weight, prev_weight)
10128            self.assertEqual(cur_buffer, prev_buffer)
10129            # run a backward pass and check the gradients
10130            res.backward()
10131            self.assertIsNotNone(weight.grad)
10132            self.assertIsNotNone(bias.grad)
10133            # No gradients were computed for the swapped-in buffer or the module's own parameters and buffers.
10134            self.assertIsNone(buffer.grad)
10135            self.assertIsNone(module.module.l1.weight.grad)
10136            self.assertIsNone(module.module.l1.bias.grad)
10137            self.assertIsNone(module.module.buffer.grad)
10138
10139        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
10140        @skip_if_lt_x_gpu(2)
10141        def test_ddp_forward_backward_hook(self):
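            """Check that forward/backward module hooks behave the same under DDP as on a local model."""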
10142            class DummyTestModel(nn.Module):
10143                def __init__(self) -> None:
10144                    super().__init__()
10145                    torch.manual_seed(0)
10146                    self.fc = nn.Linear(2, 2)
10147
10148                def forward(self, x):
10149                    return self.fc(x)
10150
10151            def relu_hook(module, input):
10152                return nn.functional.relu(input[0])
10153
10154            def gelu_hook(module, _input, output):
10155                return nn.functional.gelu(output)
10156
10157            def celu_hook(module, _input, output):
10158                return (nn.functional.celu(output[0]),)
10159
10160            local_model = DummyTestModel()
10161            ddp_model = DummyTestModel()
10162            local_model.fc.register_forward_pre_hook(relu_hook)
10163            local_model.fc.register_forward_hook(gelu_hook)
10164            ddp_model.fc.register_forward_pre_hook(relu_hook)
10165            ddp_model.fc.register_forward_hook(gelu_hook)
10166            local_model.fc.register_backward_hook(celu_hook)
10167            ddp_model.fc.register_backward_hook(celu_hook)
10168            ddp_model = DistributedDataParallel(
10169                ddp_model.to(self.rank), device_ids=[self.rank]
10170            )
10171            input_data = torch.rand(5, 2)
10172            output_local = local_model(input_data)
10173            output_ddp = ddp_model(input_data.to(self.rank))
10174            self.assertEqual(output_local, output_ddp)
10175            output_local.sum().backward()
10176            output_ddp.sum().backward()
10177            ddp_grads = [p.grad for p in ddp_model.parameters()]
10178            self.assertEqual(ddp_grads[0], local_model.fc.weight.grad)
10179            self.assertEqual(ddp_grads[1], local_model.fc.bias.grad)
10180
10181        def _test_hook_pickling(self, hook, hook_state):
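            """Save a DDP comm hook and its state with torch.save, reload them, and verify expected log messages and training parity."""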
10182            torch.manual_seed(0)
10183            learning_rate = 0.01
10184            chkpt_file = tempfile.gettempdir() + "/checkpoint.pt"
10185            rank = self.rank
10186
10187            input = torch.randn(7, 1, device=rank)
10188            target = torch.randn(7, 5, device=rank)
10189            net = torch.nn.Linear(1, 5).to(rank)
10190            ddp_model = DistributedDataParallel(copy.deepcopy(net), device_ids=[rank])
10191            dummy_ddp_model = DistributedDataParallel(
10192                copy.deepcopy(net), device_ids=[rank]
10193            )
10194            optimizer = torch.optim.SGD(ddp_model.parameters(), lr=learning_rate)
10195            ddp_model.register_comm_hook(hook_state, hook)
10196            ddp_model.train()
10197
10198            for _ in range(10):
10199                optimizer.zero_grad()
10200                out = ddp_model(input)
10201                loss = F.mse_loss(out, target)
10202                loss.backward()
10203                optimizer.step()
10204
10205            state = {
10206                "state_dict": ddp_model.state_dict(),
10207                "comm_hook": hook,
10208                "comm_hook_state": hook_state,
10209            }
10210
10211            if rank == 0:
10212                with self.assertLogs("torch.distributed") as captured:
10213                    torch.save(state, chkpt_file)
10214
10215                # Check that the logger has only one entry
10216                self.assertEqual(len(captured.records), 1)
10217                # Check that the logger has an expected entry
10218                self.assertEqual(
10219                    captured.records[0].getMessage(),
10220                    "NOTE: Process group is not serializable and excluded from a saved state.",
10221                )
10222
10223            dist.barrier()
10224            map_location = {"cuda:%d" % 0: "cuda:%d" % rank}
10225            with self.assertLogs("torch.distributed") as captured:
10226                checkpoint = torch.load(chkpt_file, map_location=map_location)
10227
10228            # Check that the logger has only one entry
10229            self.assertEqual(len(captured.records), 1)
10230            # Check that the logger has an expected entry
10231            self.assertEqual(
10232                captured.records[0].getMessage(),
10233                "NOTE: Process group will be set to a default group (i.e. the world size).\
10234                If a different group is desired, please set `self.process_group` after PowerSGD state is loaded.",
10235            )
10236
10237            dummy_ddp_model.load_state_dict(checkpoint["state_dict"])
10238            dummy_hook = checkpoint["comm_hook"]
10239            dummy_hook_state = checkpoint["comm_hook_state"]
10240            dummy_optimizer = torch.optim.SGD(
10241                dummy_ddp_model.parameters(), lr=learning_rate
10242            )
10243
10244            # Check that loaded function is correct
10245            self.assertEqual(dummy_hook.__qualname__, hook.__qualname__)
10246
10247            # Check that all slots' keys were restored correctly
10248            self.assertEqual(hook_state.__slots__, dummy_hook_state.__slots__)
10249
10250            # Check that all slots' attributes are restored correctly
10251            # Excluding ``process_group`` and ``rng``.
10252            for entry in dummy_hook_state.__slots__:
10253                if entry != "process_group" and entry != "rng":
10254                    self.assertEqual(
10255                        getattr(dummy_hook_state, entry), getattr(hook_state, entry)
10256                    )
10257
10258            # Check that ``process_group`` was set to default
10259            self.assertEqual(dummy_hook_state.process_group, _get_default_group())
10260
10261            # Check that a random state was restored properly:
10262            # ``np.random.RandomState.get_state`` returns a tuple with entries:
10263            # ``bit_generator`` - str,
10264            # ``state.key`` - ndarray dtype[uint32],
10265            # ``state.pos`` - int,
10266            # ``has_gauss`` - int,
10267            # ``gauss`` - float
10268            #  (refer to https://github.com/numpy/numpy/blob/266aad7478bc7fbcc55eea7f942a0d373b838396/numpy/random/mtrand.pyi)
10269            # To make sure random state was restored properly, all entries should equal the original
10270            for entry1, entry2 in zip(
10271                hook_state.rng.get_state(), dummy_hook_state.rng.get_state()
10272            ):
10273                np.testing.assert_array_equal(entry1, entry2)
10274
10275            dummy_ddp_model.register_comm_hook(dummy_hook_state, dummy_hook)
10276            dummy_ddp_model.train()
10277
10278            for _ in range(10):
10279                optimizer.zero_grad()
10280                dummy_optimizer.zero_grad()
10281                out_origin = ddp_model(input)
10282                out_dummy = dummy_ddp_model(input)
10283                loss_origin = F.mse_loss(out_origin, target)
10284                loss_dummy = F.mse_loss(out_dummy, target)
10285                loss_origin.backward()
10286                loss_dummy.backward()
10287                optimizer.step()
10288                dummy_optimizer.step()
10289
10290            # Check that gradients after 10 iterations are the same
10291            for orig_param, dummy_param in zip(
10292                ddp_model.parameters(), dummy_ddp_model.parameters()
10293            ):
10294                self.assertEqual(orig_param.grad, dummy_param.grad)
10295
10296            dist.barrier()
10297            if rank == 0:
10298                os.remove(chkpt_file)
10299
10300        @skip_but_pass_in_sandcastle_if(
10301            BACKEND not in DistTestCases.backend_feature["cuda"],
10302            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
10303        )
10304        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
10305        @skip_but_pass_in_sandcastle_if(
10306            True, "Skipped due to flakiness"
10307        )
10308        def test_ddp_hook_pickling_powerSGD(self):
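            """Run the hook pickling check with the PowerSGD comm hook and state."""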
10309
10310            hook = powerSGD.powerSGD_hook
10311            powersgd_state = powerSGD.PowerSGDState(
10312                process_group=None,
10313                matrix_approximation_rank=1,
10314                start_powerSGD_iter=4,
10315            )
10316            self._test_hook_pickling(hook, powersgd_state)
10317
10318        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
10319        @skip_if_lt_x_gpu(2)
10320        def test_ddp_device_mesh_initialization(self):
10321            """
10322            Test DDP with device_mesh initialization.
10323            """
10324            world_size = int(os.environ["WORLD_SIZE"])
10325
10326            from torch.distributed.device_mesh import init_device_mesh
10327            device_mesh = init_device_mesh("cuda", (world_size,))
10328
10329            pg = _get_default_group()
10330
10331            torch.cuda.set_device(self.rank)
10332            model = TwoLinLayerNet().cuda()
10333            ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_mesh=device_mesh)
10334            self.assertEqual(ddp_model.device_mesh, device_mesh)
10335
10336            with self.assertRaisesRegex(
10337                RuntimeError, "Cannot specify both process_group and device_mesh arguments."
10338            ):
10339                ddp_model = torch.nn.parallel.DistributedDataParallel(
10340                    model, process_group=pg, device_mesh=device_mesh
10341                )
10342
10343            with self.assertRaisesRegex(
10344                RuntimeError, "Only 1D device mesh is supported,"
10345            ):
10346                device_mesh = init_device_mesh("cuda", (2, world_size // 2))
10347                ddp_model = torch.nn.parallel.DistributedDataParallel(
10348                    model, device_mesh=device_mesh
10349                )
10350
10351
10352        @skip_if_lt_x_gpu(2)
10353        @require_world_size(2)
10354        @skip_but_pass_in_sandcastle_if(
10355            BACKEND not in DistTestCases.backend_feature["ddp"],
10356            f"The {BACKEND} backend does not support DistributedDataParallel",
10357        )
10358        def test_ddp_compile_static_graph(self):
10359            "Tests that DDP works with torch compile when static_graph=True"
10360            model = torch.nn.Linear(10, 10).cuda(self.rank)
10361            model_clone = copy.deepcopy(model)
10362            ddp = torch.nn.parallel.DistributedDataParallel(
10363                model,
10364                device_ids=[self.rank],
10365            )
10366            ddp_static = torch.nn.parallel.DistributedDataParallel(
10367                model_clone,
10368                device_ids=[self.rank],
10369                static_graph=True
10370            )
10371            ddp = torch.compile(ddp)
10372            ddp_static = torch.compile(ddp_static)
10373            input = torch.rand(10, 10).cuda(self.rank)
10374            # Verify output and gradient parity between the two compiled DDP models.
10375            for _ in range(6):
10376                out_ddp = ddp(input).sum()
10377                out_ddp_static = ddp_static(input).sum()
10378                self.assertEqual(out_ddp, out_ddp_static)
10379                out_ddp.backward()
10380                out_ddp_static.backward()
10381                for p1, p2 in zip(ddp.parameters(), ddp_static.parameters()):
10382                    self.assertEqual(p1.grad, p2.grad)
10383
10384        @skip_if_lt_x_gpu(2)
10385        @require_world_size(2)
10386        @skip_but_pass_in_sandcastle_if(
10387            BACKEND not in DistTestCases.backend_feature["ddp"],
10388            f"The {BACKEND} backend does not support DistributedDataParallel",
10389        )
10390        def test_ddp_sink_noclone(self):
10391            "Tests that we can configure DDP to avoid clone"
10392
10393            class OpPatcher(TorchDispatchMode):
10394                def __torch_dispatch__(self, func, types, args=(), kwargs=None):
10395                    func_packet = func._overloadpacket
10396                    if func_packet == torch.ops.aten.clone:
10397                        raise RuntimeError("clone encountered!")
10398                    kwargs = kwargs if kwargs else {}
10399                    return func(*args, **kwargs)
10400
10401            class MyModel(torch.nn.Module):
10402                def __init__(self) -> None:
10403                    super().__init__()
10404                    self.fc = torch.nn.Linear(10, 10)
10405
10406                def forward(self, input):
10407                    return self.fc(input)
10408
10409            model = MyModel().cuda(self.rank)
10410            ddp = torch.nn.parallel.DistributedDataParallel(
10411                model,
10412                device_ids=[self.rank],
10413                find_unused_parameters=True,
10414            )
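            # DDP's sink autograd node (used with find_unused_parameters=True) clones incoming
            # grads by default; disabling the clone here lets the OpPatcher below assert that
            # no aten.clone op is dispatched during forward or backward.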
10415            ddp._set_ddp_sink_clone(False)
10416            input = torch.rand(10, 10).cuda(self.rank)
10417
10418            with OpPatcher():
10419                ddp(input).sum().backward()
10420
10421
10422
10423instantiate_parametrized_tests(DistributedTest._DistTestBase)
10424