# Owner(s): ["oncall: distributed"]

from datetime import timedelta
from multiprocessing.pool import ThreadPool

import torch
import torch.distributed as dist
from torch.testing._internal.common_utils import run_tests, TestCase


# Simple example of user code that takes the base class ControlCollectives
# and executes multiple different collectives. It is written against the
# abstract interface, so it runs unchanged on any concrete implementation.
def simple_user_func(collectives: dist._ControlCollectives, rank: int) -> int:
    timeout = timedelta(seconds=10)
    # first a barrier
    collectives.barrier("1", timeout, True)
    # then an all_sum
    out = collectives.all_sum("2", rank, timeout)
    return out


class TestCollectives(TestCase):
    # Each test simulates a world of `world_size` ranks as threads sharing a
    # single in-memory HashStore.
    def test_barrier(self) -> None:
        store = dist.HashStore()

        world_size = 2

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            collectives.barrier("foo", timedelta(seconds=10), True)

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_broadcast(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                collectives.broadcast_send("foo", b"data", timeout)
            else:
                out = collectives.broadcast_recv("foo", timeout)
                self.assertEqual(out, b"data")

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_gather(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                out = collectives.gather_recv("foo", str(rank), timeout)
                self.assertEqual(out, [b"0", b"1", b"2", b"3"])
            else:
                collectives.gather_send("foo", str(rank), timeout)

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_scatter(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                out = collectives.scatter_send(
                    "foo", [str(i) for i in range(world_size)], timeout
                )
            else:
                out = collectives.scatter_recv("foo", timeout)
            # scatter_send also yields the sender's own shard, so every rank
            # can check that it received its slot.
            self.assertEqual(out, str(rank).encode())

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_all_sum(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            out = collectives.all_sum("foo", rank, timeout)
            self.assertEqual(out, sum(range(world_size)))

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    # The timeout tests below run only rank 1, so each collective must time
    # out waiting for the ranks that never show up.
    def test_broadcast_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.broadcast_recv("foo", timeout)

    def test_gather_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.gather_recv("foo", "data", timeout)

    def test_scatter_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.scatter_recv("foo", timeout)

    def test_all_gather_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "all_gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_gather("foo", "data", timeout)

    def test_barrier_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.barrier("foo", timeout, True)

    def test_all_sum_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        # all_sum appears to synchronize via a barrier, which is why a barrier
        # failure is what surfaces here.
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_sum("foo", 1, timeout)

    def test_unique(self) -> None:
        store = dist.HashStore()

        # Keys are single use: reusing "foo" with any collective is an error.
        collectives = dist._StoreCollectives(store, 1, 1)
        collectives.broadcast_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.broadcast_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.broadcast_recv("foo")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.gather_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.gather_recv("foo", "asdf")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.scatter_send("foo", ["asdf"])

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.scatter_recv("foo")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.all_gather("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.all_sum("foo", 2)

    def test_simple_user_func(self) -> None:
        store = dist.HashStore()
        world_size = 4

        def f(rank: int) -> None:
            # The user needs to create the concrete collectives implementation,
            # but simple_user_func does not need to change across implementations.
            store_collectives = dist._StoreCollectives(store, rank, world_size)
            out = simple_user_func(store_collectives, rank)
            self.assertEqual(out, sum(range(world_size)))

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()