# Owner(s): ["oncall: distributed"]
from collections import OrderedDict
from copy import deepcopy

import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,
    PrepareModuleInput,
    PrepareModuleOutput,
    RowwiseParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    MLPStacked,
    with_comms,
)


class DummyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return x


class TensorParallelAPITests(DTensorTestBase):
    @property
    def world_size(self):
        gpu_num = torch.cuda.device_count()
        return gpu_num if gpu_num % 2 == 0 and gpu_num > 4 else 4

    def _compare_params(
        self,
        local_module,
        dist_module,
        rank0_only,
        skip_rowwise_bias=False,
        compare_grad=False,
    ):
        replicate = [Replicate()]
        for name, param in local_module.named_parameters():
            dist_param = dist_module.get_parameter(name)
            param = param.grad if compare_grad else param
            dist_param = dist_param.grad if compare_grad else dist_param
            if (
                (not rank0_only)
                or (self.rank == 0)
                or (
                    name not in ["net2.bias"]
                    and not skip_rowwise_bias
                    or name not in ["bias", "net2.bias"]
                )
            ):
                self.assertEqual(
                    param,
                    dist_param.redistribute(
                        device_mesh=dist_param.device_mesh, placements=replicate
                    ).to_local(),
                    f"{name} not equal between dist and non-dist",
                )

    def _compare_module(
        self, local_module, dist_module, inp_size, rank0_only=True, rowwise=False
    ):
        LR = 0.25  # the learning rate we use for testing
        local_optim = torch.optim.SGD(local_module.parameters(), lr=LR)
        dist_optim = torch.optim.SGD(dist_module.parameters(), lr=LR)
        torch.manual_seed(0)
        inp = torch.rand(*inp_size, device=self.device_type)
        self._compare_params(local_module, dist_module, rank0_only)

        # check forward correctness
        local_output = local_module(inp)
        inp = inp.chunk(self.world_size, dim=-1)[self.rank] if rowwise else inp
        dist_output = dist_module(inp)
        dist_output = (
            dist_output.redistribute(dist_output.device_mesh, [Replicate()]).to_local()
            if isinstance(dist_output, DTensor)
            else dist_output
        )
        self.assertEqual(local_output, dist_output)

        local_output.sum().backward()
        dist_output.sum().backward()

        # check backward and ensure gradients are same
        self._compare_params(local_module, dist_module, rank0_only, rowwise, True)

        local_optim.step()
        dist_optim.step()
        self._compare_params(local_module, dist_module, rank0_only, rowwise)

    @with_comms
    def test_parallelize_mlp_with_module_api(self):
        inp_size = [12, 10]
        model = MLPModule(self.device_type)
        model_tp = deepcopy(model)

        # Parallelize module.
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        model_tp = parallelize_module(
            model_tp,
            device_mesh,
            {
                "net1": ColwiseParallel(output_layouts=Replicate()),
                "net2": ColwiseParallel(output_layouts=Replicate()),
            },
        )
        self._compare_module(model, model_tp, inp_size, rank0_only=False)

    @with_comms
    def test_parallelize_mlp_with_module_api_nested(self):
        inp_size = [12, 10]
        model = torch.nn.Sequential(
            OrderedDict([("dummy_encoder", MLPModule(self.device_type))])
        )
        model_tp = deepcopy(model)

        # Parallelize module.
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        model_tp = parallelize_module(
            model_tp,
            device_mesh,
            {
                "dummy_encoder.net1": ColwiseParallel(output_layouts=Replicate()),
                "dummy_encoder.net2": ColwiseParallel(output_layouts=Replicate()),
            },
        )
        self._compare_module(model, model_tp, inp_size, rank0_only=False)

    @with_comms
    def test_linear_row_wise_parallel(self):
        # test RowwiseParallel
        inp_size = [9, 16]
        rowwise = RowwiseParallel()

        torch.manual_seed(5)
        model = torch.nn.Linear(16, 10, device=self.device_type)
        model_tp = deepcopy(model)

        # parallelize model_tp
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        model_tp = parallelize_module(model_tp, device_mesh, rowwise)

        # let each rank generate unique local input
        torch.manual_seed(self.rank)
        self._compare_module(model, model_tp, inp_size, rowwise=True)

    @with_comms
    def test_linear_col_wise_parallel(self):
        # test ColwiseParallel
        inp_size = [8, 10]
        colwise = ColwiseParallel(output_layouts=Replicate())

        torch.manual_seed(5)
        model = torch.nn.Linear(10, 16, device=self.device_type)
        model_tp = deepcopy(model)

        # parallelize model_tp
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        model_tp = parallelize_module(model_tp, device_mesh, colwise)

        self._compare_module(model, model_tp, inp_size)

    @with_comms
    def test_prepare_module_input(self):
        module = DummyModule()
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        parallelize_module(
            module,
            device_mesh,
            PrepareModuleInput(
                input_layouts=Shard(0), desired_input_layouts=Replicate()
            ),
        )
        inp = torch.rand(5, 7, device=self.device_type)
        output = module(inp).redistribute(device_mesh, [Shard(0)]).to_local()
        self.assertEqual(inp, output)

    @with_comms
    def test_prepare_module_output(self):
        module = DummyModule()
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        parallelize_module(
            module,
            device_mesh,
            PrepareModuleOutput(
                output_layouts=Replicate(), desired_output_layouts=Shard(0)
            ),
        )
        torch.manual_seed(15)
        inp = torch.rand(16, 7, device=self.device_type)
        dtensor = DTensor.from_local(inp, device_mesh, [Replicate()], run_check=False)
        output = module(dtensor)
        inp = dtensor.redistribute(device_mesh, [Shard(0)]).to_local()
        self.assertEqual(inp, output)

    @with_comms
    def test_parallelize_module_with_star(self):
        inp_size = [12, 10]
        model = MLPModule(self.device_type)
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        model_tp = deepcopy(model)
        # "net*" is a wildcard plan key that matches both net1 and net2.
        model_tp = parallelize_module(
            model_tp,
            device_mesh,
            {
                "net*": ColwiseParallel(output_layouts=Replicate()),
            },
        )
        self._compare_module(model, model_tp, inp_size, rank0_only=False)

    @with_comms
    def test_parallelize_module_with_question(self):
        inp_size = [12, 10]
        model = MLPModule(self.device_type)
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        model_tp = deepcopy(model)
        # "net?" matches a single trailing character, i.e. net1 and net2.
        model_tp = parallelize_module(
            model_tp,
            device_mesh,
            {
                "net?": ColwiseParallel(output_layouts=Replicate()),
            },
        )
        self._compare_module(model, model_tp, inp_size, rank0_only=False)

    @with_comms
    def test_parallelize_module_with_digit(self):
        inp_size = [12, 10]
        model = MLPModule(self.device_type)
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        model_tp = deepcopy(model)
        # "net[1-2]" matches net1 and net2 via a character-range pattern.
        model_tp = parallelize_module(
            model_tp,
            device_mesh,
            {
                "net[1-2]": ColwiseParallel(output_layouts=Replicate()),
            },
        )
        self._compare_module(model, model_tp, inp_size, rank0_only=False)

    @with_comms
    def test_parallelize_module_multi_wildcard(self):
        inp_size = [12, 10]
        model = MLPStacked(self.device_type, n_layers=2)
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        model_tp = deepcopy(model)
        # Wildcards also work inside nested module paths (layers.0.net1, layers.1.net1, ...).
        model_tp = parallelize_module(
            model_tp,
            device_mesh,
            {
                "layers.*.net[1]": ColwiseParallel(),
                "layers.*.net[2]": RowwiseParallel(),
            },
        )
        self._compare_module(model, model_tp, inp_size, rank0_only=False)


if __name__ == "__main__":
    run_tests()