# Owner(s): ["oncall: distributed"]

import os
import sys
from functools import partial, wraps

import torch
import torch.distributed as dist


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))

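# `with_comms` wraps each test so it runs inside an initialized process group:
# it skips NCCL runs when there are fewer GPUs than ranks, calls dist_init()
# before the test body, and destroy_comms() afterwards.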
def with_comms(func=None):
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        self.dist_init()
        # Forward any extra arguments so decorated tests keep their full signature.
        func(self, *args, **kwargs)
        self.destroy_comms()

    return wrapper


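# MultiProcessTestCase spawns WORLD_SIZE processes and re-runs each test
# method once per rank; self.rank identifies the current process.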
class TestObjectCollectives(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = BACKEND
        self._spawn_processes()

    @property
    def device(self):
        return (
            torch.device(self.rank)
            if BACKEND == dist.Backend.NCCL
            else torch.device("cpu")
        )

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # Set the device for the NCCL pg so collectives run on the right GPU.
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    @with_comms()
    def test_all_gather_object(self):
        output = [None] * dist.get_world_size()
        dist.all_gather_object(object_list=output, obj=self.rank)

        for i, v in enumerate(output):
            self.assertEqual(i, v, f"rank: {self.rank}")

    @with_comms()
    def test_gather_object(self):
        output = [None] * dist.get_world_size() if self.rank == 0 else None
        dist.gather_object(obj=self.rank, object_gather_list=output)

        if self.rank == 0:
            for i, v in enumerate(output):
                self.assertEqual(i, v, f"rank: {self.rank}")

    @with_comms()
    def test_send_recv_object_list(self):
        # Only rank 0 sends and rank 1 receives; every other rank keeps None.
        val = 99 if self.rank == 0 else None
        object_list = [val] * dist.get_world_size()
        if self.rank == 0:
            dist.send_object_list(object_list, 1)
        if self.rank == 1:
            dist.recv_object_list(object_list, 0)

        if self.rank < 2:
            self.assertEqual(99, object_list[0])
        else:
            self.assertEqual(None, object_list[0])

    @with_comms()
    def test_broadcast_object_list(self):
        val = 99 if self.rank == 0 else None
        object_list = [val] * dist.get_world_size()
        # TODO test with broadcast_object_list's device argument
        dist.broadcast_object_list(object_list=object_list)

        self.assertEqual(99, object_list[0])

    @with_comms()
    def test_scatter_object_list(self):
        input_list = list(range(dist.get_world_size())) if self.rank == 0 else None
        output_list = [None]
        dist.scatter_object_list(
            scatter_object_output_list=output_list, scatter_object_input_list=input_list
        )

        self.assertEqual(self.rank, output_list[0])

    # Test object collectives with a sub process group.

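    # setup_sub_pg pairs each even rank with the next odd rank, so the tests
    # below exercise the object collectives on 2-rank subgroups.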
    def setup_sub_pg(self):
        rank = dist.get_rank()
        base_rank = rank - (rank % 2)
        ranks = [base_rank, base_rank + 1]
        my_pg = dist.new_group(ranks, use_local_synchronization=True)
        return rank, ranks, my_pg

    @with_comms()
    def test_subpg_scatter_object(self):
        rank, ranks, my_pg = self.setup_sub_pg()
        out_list = [None]
        dist.scatter_object_list(out_list, ranks, src=ranks[0], group=my_pg)
        self.assertEqual(rank, out_list[0])

    @with_comms()
    def test_subpg_all_gather_object(self):
        rank, ranks, my_pg = self.setup_sub_pg()
        out_list = [None] * len(ranks)
        dist.all_gather_object(out_list, rank, group=my_pg)
        self.assertEqual(ranks, out_list)

    @with_comms()
    def test_subpg_gather_object(self):
        rank, ranks, my_pg = self.setup_sub_pg()
        out_list = [None] * len(ranks) if rank == ranks[0] else None
        dist.gather_object(rank, out_list, dst=ranks[0], group=my_pg)
        if rank == ranks[0]:
            self.assertEqual(ranks, out_list)

    @with_comms()
    def test_subpg_broadcast_object(self):
        rank, ranks, my_pg = self.setup_sub_pg()
        out_list = [None]
        if rank == ranks[0]:
            out_list[0] = rank
        dist.broadcast_object_list(out_list, src=ranks[0], group=my_pg)
        self.assertEqual(ranks[0], out_list[0])


if __name__ == "__main__":
    run_tests()