# Owner(s): ["oncall: distributed"]
import copy
import sys

import torch
import torch.nn as nn
from torch.distributed._shard import shard_module
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharder import Sharder
from torch.distributed._shard.sharding_plan import ShardingPlan
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    TEST_GPU_NUM,
    with_comms,
)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


# a simple collection of embedding bags
class CustomEmbeddingBagCollection(nn.Module):
    def __init__(self, num_bags, num_embeddings_per_bag, num_dims):
        super().__init__()
        self.num_bags = num_bags
        self.embedding_bags: nn.ModuleDict = nn.ModuleDict()

        for i in range(num_bags):
            self.embedding_bags[f"embedding_bag_{i}"] = nn.EmbeddingBag(
                num_embeddings_per_bag, num_dims, mode="sum"
            )

    def forward(self, inputs):
        outputs = []
        for bag in self.embedding_bags.values():
            outputs.append(bag(inputs))
        return torch.cat(outputs)


# a simple sharded version of the embedding bag collection (EBC)
class CustomShardedEBC(nn.Module):
    def __init__(self, ebc, split_idx, specs):
        super().__init__()
        self.split_idx = split_idx
        row_spec, col_spec = specs

        # create embedding bags based on the spec
        self.embedding_bags: nn.ModuleDict = nn.ModuleDict()

        assert self.split_idx < ebc.num_bags
        for i in range(ebc.num_bags):
            bag_key = f"embedding_bag_{i}"
            # bags before the split index are sharded row-wise, the rest column-wise
            if i < self.split_idx:
                shard_module(
                    ebc,
                    plan=ShardingPlan(
                        plan={f"embedding_bags.{bag_key}.weight": row_spec}
                    ),
                )
            else:
                shard_module(
                    ebc,
                    plan=ShardingPlan(
                        plan={f"embedding_bags.{bag_key}.weight": col_spec}
                    ),
                )

            self.embedding_bags[bag_key] = ebc.embedding_bags[bag_key]


class CustomSharder(Sharder):
    def __init__(self, devices, split_sharding_idx):
        self.devices = devices
        self.split_sharding_idx = split_sharding_idx
        self.rowwise_spec = ChunkShardingSpec(dim=0, placements=devices)
        self.colwise_spec = ChunkShardingSpec(dim=1, placements=devices)

    def shard(self, ebc: nn.Module) -> nn.Module:
        if not isinstance(ebc, CustomEmbeddingBagCollection):
            raise RuntimeError(
                "The custom sharder only supports CustomEmbeddingBagCollection"
            )

        return CustomShardedEBC(
            ebc, self.split_sharding_idx, (self.rowwise_spec, self.colwise_spec)
        )


class TestCustomSharder(ShardedTensorTestBase):
    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_custom_sharder(self):
        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.ebc = CustomEmbeddingBagCollection(10, 10, 8)

            def forward(self, inputs):
                return self.ebc(inputs)

        custom_sharder = CustomSharder(
            devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)],
            split_sharding_idx=TEST_GPU_NUM // 2,
        )

        sharding_plan = ShardingPlan(
            plan={
                "ebc": custom_sharder,
            }
        )

        local_model = MyModule().cuda(self.rank)
        sharded_model = copy.deepcopy(local_model)

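        # the plan maps the submodule path "ebc" to the custom sharder, so sharding
        # of that submodule should be delegated to CustomSharder.shard(), which
        # replaces it with the returned CustomShardedEBC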
        # shard the module with the provided sharding plan
        shard_module(sharded_model, sharding_plan)

        # check to make sure the module has already been sharded
        emb_bags = sharded_model.ebc.embedding_bags
        self.assertTrue(isinstance(emb_bags["embedding_bag_0"].weight, ShardedTensor))
        self.assertTrue(isinstance(emb_bags["embedding_bag_9"].weight, ShardedTensor))
        self.assertEqual(
            emb_bags["embedding_bag_0"].weight.sharding_spec(),
            custom_sharder.rowwise_spec,
        )
        self.assertEqual(
            emb_bags["embedding_bag_9"].weight.sharding_spec(),
            custom_sharder.colwise_spec,
        )

        # make sure we can run sharded computation and compare outputs
        # with the local model version
        input = torch.arange(8).reshape((2, 4)).cuda(self.rank)
        local_output = local_model(input)
        sharded_output = sharded_model(input)

        self.assertEqual(local_output, sharded_output)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_custom_sharder_errors(self):
        custom_sharder = CustomSharder(
            devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)],
            split_sharding_idx=TEST_GPU_NUM // 2,
        )

        # an empty module path is not allowed for a custom sharder
        sharding_plan = ShardingPlan(
            plan={
                "": custom_sharder,
            }
        )

        sharded_model = CustomEmbeddingBagCollection(10, 10, 8).cuda(self.rank)

        with self.assertRaisesRegex(
            KeyError, "path must not be empty for custom sharder!"
        ):
            # shard the module with the provided sharding plan
            shard_module(sharded_model, sharding_plan)

        # test a conflicting sharding plan: the parameter being sharded lives inside
        # the submodule tree that is already assigned to the custom sharder
        spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
        sharding_plan = ShardingPlan(
            plan={
                "embedding_bags.embedding_bag_0.weight": spec,
                "embedding_bags": custom_sharder,
            }
        )

        with self.assertRaisesRegex(
            RuntimeError, "should not conflict with the submodule tree"
        ):
            # shard the module with the provided sharding plan
            shard_module(sharded_model, sharding_plan)


if __name__ == "__main__":
    run_tests()