# Owner(s): ["oncall: distributed"]

from datetime import timedelta
from multiprocessing.pool import ThreadPool

import torch
import torch.distributed as dist
from torch.testing._internal.common_utils import run_tests, TestCase


# simple example of user code that takes the base class ControlCollectives
# and executes multiple different collectives
def simple_user_func(collectives: dist._ControlCollectives, rank: int) -> int:
    timeout = timedelta(seconds=10)
    # first a barrier
    collectives.barrier("1", timeout, True)
    # then an all_sum
    out = collectives.all_sum("2", rank, timeout)
    return out


class TestCollectives(TestCase):
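    # HashStore is an in-process store, so a ThreadPool with one thread per
    # rank is enough to exercise the collectives; no subprocesses are needed.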
    def test_barrier(self) -> None:
        store = dist.HashStore()

        world_size = 2

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            collectives.barrier("foo", timedelta(seconds=10), True)

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

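    # One designated sender (rank 2) publishes a payload and every other
    # rank receives the same bytes.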
    def test_broadcast(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                collectives.broadcast_send("foo", b"data", timeout)
            else:
                out = collectives.broadcast_recv("foo", timeout)
                self.assertEqual(out, b"data")

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

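    # Rank 2 collects one value from every rank (including its own); the
    # results come back as bytes, ordered by rank.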
    def test_gather(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                out = collectives.gather_recv("foo", str(rank), timeout)
                self.assertEqual(out, [b"0", b"1", b"2", b"3"])
            else:
                collectives.gather_send("foo", str(rank), timeout)

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

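    # Rank 2 supplies one value per rank; every rank, the sender included,
    # receives the slice matching its own rank.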
    def test_scatter(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                out = collectives.scatter_send(
                    "foo", [str(i) for i in range(world_size)], timeout
                )
            else:
                out = collectives.scatter_recv("foo", timeout)
            self.assertEqual(out, str(rank).encode())

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

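    # Every rank contributes its rank id and all ranks receive the total:
    # 0 + 1 + 2 + 3 = 6.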
    def test_all_sum(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            out = collectives.all_sum("foo", rank, timeout)
            self.assertEqual(out, sum(range(world_size)))

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

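    # With no sender ever calling broadcast_send, the receive should fail
    # with a store wait timeout rather than hanging.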
    def test_broadcast_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.broadcast_recv("foo", timeout)

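    # Only rank 1 participates, so the gather error should name the ranks
    # that never checked in.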
    def test_gather_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.gather_recv("foo", "data", timeout)

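    # No sender ever scatters, so the receive surfaces a store wait timeout.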
    def test_scatter_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.scatter_recv("foo", timeout)

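    # Like gather, all_gather reports exactly which ranks are missing.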
    def test_all_gather_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "all_gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_gather("foo", "data", timeout)

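    # A lone rank in a world of 4 cannot satisfy the barrier; the error
    # lists the absent ranks.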
    def test_barrier_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.barrier("foo", timeout, True)

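    # all_sum evidently synchronizes via the barrier path internally, which
    # is why the expected error message below is a barrier failure.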
    def test_all_sum_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_sum("foo", 1, timeout)

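    # Each key identifies exactly one collective operation; reusing a key
    # for any collective, of any kind, raises.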
    def test_unique(self) -> None:
        store = dist.HashStore()

        collectives = dist._StoreCollectives(store, 1, 1)
        collectives.broadcast_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.broadcast_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.broadcast_recv("foo")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.gather_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.gather_recv("foo", "asdf")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.scatter_send("foo", ["asdf"])

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.scatter_recv("foo")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.all_gather("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.all_sum("foo", 2)

    def test_simple_user_func(self) -> None:
        store = dist.HashStore()
        world_size = 4

        def f(rank: int) -> None:
            # the user needs to create a concrete child collectives object,
            # but simple_user_func does not need to change for different child collectives
            store_collectives = dist._StoreCollectives(store, rank, world_size)
            out = simple_user_func(store_collectives, rank)
            self.assertEqual(out, sum(range(world_size)))

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()