# Owner(s): ["oncall: distributed"]

import os
import sys
from functools import partial, wraps

import torch
import torch.distributed as dist


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))


def with_comms(func=None):
    # Decorator that initializes the default process group before the test body
    # runs and destroys it afterwards. Skips the test when the NCCL backend is
    # selected but there are not enough GPUs for the configured world size.
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        self.dist_init()
        func(self, *args, **kwargs)
        self.destroy_comms()

    return wrapper


class TestObjectCollectives(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = BACKEND
        self._spawn_processes()

    @property
    def device(self):
        return (
            torch.device(self.rank)
            if BACKEND == dist.Backend.NCCL
            else torch.device("cpu")
        )

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # Set the device so NCCL collectives run on this rank's GPU.
        if BACKEND == dist.Backend.NCCL:
            torch.cuda.set_device(self.rank)

    @with_comms()
    def test_all_gather_object(self):
        output = [None] * dist.get_world_size()
        dist.all_gather_object(object_list=output, obj=self.rank)

        for i, v in enumerate(output):
            self.assertEqual(i, v, f"rank: {self.rank}")

    @with_comms()
    def test_gather_object(self):
        output = [None] * dist.get_world_size() if self.rank == 0 else None
        dist.gather_object(obj=self.rank, object_gather_list=output)

        if self.rank == 0:
            for i, v in enumerate(output):
                self.assertEqual(i, v, f"rank: {self.rank}")

    @with_comms()
    def test_send_recv_object_list(self):
        val = 99 if self.rank == 0 else None
        object_list = [val] * dist.get_world_size()
        if self.rank == 0:
            dist.send_object_list(object_list, 1)
        if self.rank == 1:
            dist.recv_object_list(object_list, 0)

        if self.rank < 2:
            self.assertEqual(99, object_list[0])
        else:
            self.assertEqual(None, object_list[0])

    @with_comms()
    def test_broadcast_object_list(self):
        val = 99 if self.rank == 0 else None
        object_list = [val] * dist.get_world_size()
        # TODO test with broadcast_object_list's device argument
        dist.broadcast_object_list(object_list=object_list)

        self.assertEqual(99, object_list[0])

    @with_comms()
    def test_scatter_object_list(self):
        input_list = list(range(dist.get_world_size())) if self.rank == 0 else None
        output_list = [None]
        dist.scatter_object_list(
            scatter_object_output_list=output_list,
            scatter_object_input_list=input_list,
        )

        self.assertEqual(self.rank, output_list[0])

    # Test Object Collectives With Sub Pg

    def setup_sub_pg(self):
        # Pair adjacent ranks (0-1, 2-3, ...) into a two-rank subgroup.
        rank = dist.get_rank()
        base_rank = rank - (rank % 2)
        ranks = [base_rank, base_rank + 1]
        my_pg = dist.new_group(ranks, use_local_synchronization=True)
        return rank, ranks, my_pg

    @with_comms()
    def test_subpg_scatter_object(self):
        rank, ranks, my_pg = self.setup_sub_pg()
        out_list = [None]
        dist.scatter_object_list(out_list, ranks, src=ranks[0], group=my_pg)
        self.assertEqual(rank, out_list[0])

    @with_comms()
    def test_subpg_all_gather_object(self):
        rank, ranks, my_pg = self.setup_sub_pg()
        out_list = [None] * len(ranks)
        dist.all_gather_object(out_list, rank, group=my_pg)
        self.assertEqual(ranks, out_list)

    @with_comms()
    def test_subpg_gather_object(self):
        rank, ranks, my_pg = self.setup_sub_pg()
        out_list = [None] * len(ranks) if rank == ranks[0] else None
        dist.gather_object(rank, out_list, dst=ranks[0], group=my_pg)
        if rank == ranks[0]:
            self.assertEqual(ranks, out_list)

    @with_comms()
    def test_subpg_broadcast_object(self):
        rank, ranks, my_pg = self.setup_sub_pg()
        out_list = [None]
        if rank == ranks[0]:
            out_list[0] = rank
        dist.broadcast_object_list(out_list, src=ranks[0], group=my_pg)
        self.assertEqual(ranks[0], out_list[0])


if __name__ == "__main__":
    run_tests()