# mypy: allow-untyped-defs

import sys
from functools import wraps, partial

import torch
import torch.distributed as dist
from torch.distributed import rpc
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    TEST_SKIPS,
    tp_transports,
)

TEST_GPU_NUM = 4


class ShardedTensorTestBase(MultiProcessTestCase):
    @property
    def world_size(self):
        return TEST_GPU_NUM

    def init_pg(self, backend="nccl"):
        if backend not in ["nccl", "gloo", "mpi"]:
            raise RuntimeError(f"Backend {backend} not supported!")

        dist.init_process_group(
            backend=backend,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if backend == "nccl":
            torch.cuda.set_device(self.rank)

    def init_rpc(self):
        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(_transports=tp_transports())
        rpc_backend_options.init_method = f"file://{self.file_name}"
        # Map devices between this rank and every peer so RPC can transfer
        # CUDA tensors directly.
        for rank in range(self.world_size):
            rpc_backend_options.set_device_map(
                f"worker{rank}", {rank: self.rank, self.rank: rank}
            )

        rpc.init_rpc(
            name=f"worker{self.rank}",
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=rpc_backend_options,
        )

    def init_comms(self, init_rpc=True, backend="nccl"):
        if init_rpc:
            self.init_rpc()
        self.init_pg(backend=backend)

    def destroy_comms(self, destroy_rpc=True):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()

        if destroy_rpc:
            rpc.shutdown()
        dist.destroy_process_group()

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def assert_sharded_tensor_equal(self, st1, st2):
        # Compare shard-by-shard on this rank, then the global metadata.
        st1_local_shards = st1.local_shards()
        st2_local_shards = st2.local_shards()
        self.assertEqual(len(st1_local_shards), len(st2_local_shards))
        for i, st1_local_shard in enumerate(st1_local_shards):
            self.assertEqual(st1_local_shard.tensor, st2_local_shards[i].tensor)
            self.assertEqual(st1_local_shard.metadata, st2_local_shards[i].metadata)

        self.assertEqual(st1.metadata(), st2.metadata())
        self.assertEqual(st1.sharding_spec(), st2.sharding_spec())
        self.assertEqual(len(st1.remote_shards()), len(st2.remote_shards()))


# Decorator that initializes comms (process group + RPC) before a test method
# runs and tears them down afterwards. Usable bare (@with_comms) or with
# arguments (@with_comms(init_rpc=False, backend="gloo")).
def with_comms(func=None, init_rpc=True, backend="nccl"):
    if func is None:
        return partial(
            with_comms,
            init_rpc=init_rpc,
            backend=backend,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # Skip the test when there aren't enough GPUs for the NCCL backend.
        if backend == "nccl" and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        self.init_comms(init_rpc=init_rpc, backend=backend)
        func(self, *args, **kwargs)
        self.destroy_comms(destroy_rpc=init_rpc)

    return wrapper
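

# Example usage (illustrative sketch, not part of this module): a test would
# subclass ShardedTensorTestBase and decorate each test method with @with_comms.
# The class name, test name, sharding spec, and tensor shape below are
# hypothetical; the example is kept as a comment so that importing this helper
# module does not register an extra test case with the runner.
#
#     from torch.distributed._shard import sharded_tensor
#     from torch.distributed._shard.sharding_spec import ChunkShardingSpec
#
#     class TestShardedTensorExample(ShardedTensorTestBase):
#         @with_comms(init_rpc=False)
#         def test_chunked_rand(self):
#             # Shard dim 0 across all ranks, one CUDA device per rank.
#             spec = ChunkShardingSpec(
#                 dim=0,
#                 placements=[f"rank:{r}/cuda:{r}" for r in range(TEST_GPU_NUM)],
#             )
#             st = sharded_tensor.rand(spec, 12, 5)
#             self.assertEqual(torch.Size([12, 5]), st.size())
#             self.assertEqual(1, len(st.local_shards()))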