1# Owner(s): ["oncall: distributed"] 2import copy 3import sys 4from collections import OrderedDict 5from typing import Dict, List, Optional, Tuple 6 7import torch 8from torch import distributed as dist 9from torch.distributed._tensor import ( 10 DeviceMesh, 11 distribute_module, 12 DTensor, 13 init_device_mesh, 14 Replicate, 15 Shard, 16) 17from torch.distributed.fsdp.fully_sharded_data_parallel import ( 18 CPUOffload, 19 FullyShardedDataParallel as FSDP, 20 ShardingStrategy, 21) 22from torch.distributed.tensor.debug import CommDebugMode 23from torch.distributed.tensor.parallel import ( 24 ColwiseParallel, 25 parallelize_module, 26 RowwiseParallel, 27) 28from torch.testing._internal.common_distributed import skip_if_lt_x_gpu 29from torch.testing._internal.common_fsdp import FSDPTest 30from torch.testing._internal.common_utils import ( 31 instantiate_parametrized_tests, 32 run_tests, 33 TEST_WITH_DEV_DBG_ASAN, 34) 35from torch.testing._internal.distributed._tensor.common_dtensor import ( 36 MLPModule, 37 RMSNormPython, 38) 39 40 41if not dist.is_available(): 42 print("Distributed not available, skipping tests", file=sys.stderr) 43 sys.exit(0) 44 45if TEST_WITH_DEV_DBG_ASAN: 46 print( 47 "Skip dev-asan as torch + multiprocessing spawn have known issues", 48 file=sys.stderr, 49 ) 50 sys.exit(0) 51 52 53class SimpleModel(torch.nn.Module): 54 def __init__(self) -> None: 55 super().__init__() 56 self.net1 = torch.nn.Linear(5, 8) 57 self.relu = torch.nn.ReLU() 58 self.net2 = torch.nn.Linear(8, 4) 59 self.net3 = torch.nn.Linear(4, 12) 60 61 def forward(self, x): 62 return self.net3(self.net2(self.relu(self.net1(x)))) 63 64 @staticmethod 65 def get_sharded_param_names() -> List[str]: 66 return ["net1.weight", "net1.bias", "net2.weight"] 67 68 @staticmethod 69 def get_non_sharded_param_names() -> List[str]: 70 return ["net3.weight", "net3.bias"] 71 72 73def distribute_rmsnorm(module, device_mesh): 74 def prepare_input_fn(mod, inputs, device_mesh): 75 shard_tensor = DTensor.from_local(inputs[0], device_mesh, [Shard(0)]) 76 return shard_tensor 77 78 def prepare_output_fn(mod, outputs, device_mesh): 79 return outputs.to_local() 80 81 return distribute_module( 82 module, device_mesh, input_fn=prepare_input_fn, output_fn=prepare_output_fn 83 ) 84 85 86class TestTPFSDPIntegration(FSDPTest): 87 def _get_params_and_sharding_info( 88 self, 89 model: SimpleModel, 90 sharded_param_names: List[str], 91 tensor_parallel_size: int, 92 ) -> Tuple[Dict[str, int], Dict[str, Tuple[torch.Size, int]]]: 93 """ """ 94 assert ( 95 type(model) is SimpleModel 96 ), "Expects a `SimpleModel` since the sharding cases on the model definition" 97 param_name_to_numel = OrderedDict() 98 param_name_to_sharding_info = OrderedDict() 99 for param_name, param in model.named_parameters(): 100 if param_name not in sharded_param_names: 101 param_name_to_numel[param_name] = param.numel() 102 else: 103 param_name_to_numel[param_name] = param.numel() // tensor_parallel_size 104 param_name_to_sharding_info[param_name] = ( 105 param.size(), 106 0 if "net1" in param_name else 1, 107 ) 108 return param_name_to_numel, param_name_to_sharding_info 109 110 def _get_sub_pgs(self, tensor_parallel_size: int): 111 """ 112 Generates TP and FSDP subprocess groups. ``tensor_parallel_size`` gives 113 the TP process group size. 114 115 For example, if the global world size is 8 and the tensor parallel size 116 is 2, then this creates: 117 - 4 TP subprocess groups: [0, 1], [2, 3], [4, 5], [6, 7] 118 - 2 FSDP subprocess groups: [0, 2, 4, 6], [1, 3, 5, 7] 119 """ 120 # 2-D mesh is [dp, tp] 121 twod_mesh = DeviceMesh( 122 device_type="cuda", 123 mesh=torch.arange(0, self.world_size).view(-1, tensor_parallel_size), 124 ) 125 126 fsdp_pg = twod_mesh.get_group(mesh_dim=0) 127 tp_pg = twod_mesh.get_group(mesh_dim=1) 128 return twod_mesh, fsdp_pg, tp_pg 129 130 def _sync_tp_grads( 131 self, 132 tp_fsdp_model: FSDP, 133 tp_pg: dist.ProcessGroup, 134 param_name_to_numel: Dict[str, int], 135 non_sharded_param_names: List[str], 136 ) -> None: 137 """ 138 Syncs the tensor parallel parameters' gradients following the data 139 parallel paradigm where gradients are averaged over ranks (in this 140 case, the ones in the tensor parallel process group). 141 """ 142 tp_world_size = tp_pg.size() 143 fsdp_world_size = self.world_size // tp_world_size 144 assert ( 145 type(tp_fsdp_model) is FSDP 146 and len([m for m in tp_fsdp_model.modules() if type(m) is FSDP]) == 1 147 ), ( 148 "The following logic assumes a single top-level-only FSDP wrapping " 149 "the model with TP already applied" 150 ) 151 for flat_param in tp_fsdp_model.params: 152 splits = tuple(param_name_to_numel.values()) 153 # Create a mask over the gradient elements to manually reduce 154 unsharded_size = torch.Size([flat_param.numel() * fsdp_world_size]) 155 unsharded_zeros = torch.zeros(unsharded_size, device=flat_param.device) 156 per_param_masks = unsharded_zeros.split(splits) 157 for param_idx, param_name in enumerate( 158 param_name_to_numel.keys() 159 ): # assumes fixed order 160 if param_name not in non_sharded_param_names: 161 per_param_masks[param_idx][:] = 1 162 unsharded_mask = ( 163 torch.cat(per_param_masks).contiguous().type(torch.BoolTensor) 164 ) 165 sharded_mask = unsharded_mask.chunk(fsdp_world_size)[ 166 self.rank // tp_world_size 167 ] 168 grad_device = flat_param.grad.device 169 grad = flat_param.grad.detach().clone().cuda(self.rank) 170 dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=tp_pg) 171 grad = grad.to(grad_device) 172 flat_param.grad[~sharded_mask] = grad[~sharded_mask] 173 # Average *all* gradient elements to match the FSDP only semantics 174 flat_param.grad /= tp_world_size 175 176 def _get_grads_as_flattened( 177 self, 178 model: FSDP, 179 uses_tp: bool, 180 param_name_to_numel: Dict[str, int], 181 param_name_to_sharding_info: Dict[str, Tuple[torch.Size, int]], 182 tp_pg: Optional[dist.ProcessGroup], 183 fsdp_pg: Optional[dist.ProcessGroup], 184 sharded_param_names: Optional[List[str]], 185 ) -> torch.Tensor: 186 """ 187 Returns all unsharded gradients as a single flattened tensor. This 188 returns the same value on all ranks. 189 """ 190 local_grads_as_flattened = ( 191 torch.cat( 192 [ 193 torch.flatten(param.grad) 194 if param.grad is not None 195 else torch.zeros_like(torch.flatten(param)) 196 for param in model.parameters() 197 ] 198 ) 199 .contiguous() 200 .cuda(self.rank) 201 ) 202 all_grads_as_flattened = torch.cat( 203 [torch.empty_like(local_grads_as_flattened) for _ in range(fsdp_pg.size())] 204 ).contiguous() 205 dist.all_gather_into_tensor( 206 all_grads_as_flattened, local_grads_as_flattened, group=fsdp_pg 207 ) 208 if not uses_tp: 209 return all_grads_as_flattened 210 splits = tuple(param_name_to_numel.values()) 211 all_grads_per_param = list(all_grads_as_flattened.split(splits)) 212 for param_idx, param_name in enumerate( 213 param_name_to_numel.keys() 214 ): # assumes fixed order 215 if param_name in sharded_param_names: 216 local_tensor_size = list(param_name_to_sharding_info[param_name][0]) 217 sharding_dim = param_name_to_sharding_info[param_name][1] 218 local_tensor_size[sharding_dim] //= tp_pg.size() 219 local_tensor = all_grads_per_param[param_idx].view(*local_tensor_size) 220 local_tensors = [ 221 torch.empty_like(local_tensor) for _ in range(tp_pg.size()) 222 ] 223 dist.all_gather(local_tensors, local_tensor, group=tp_pg) 224 all_grads_per_param[param_idx] = torch.cat( 225 local_tensors, dim=sharding_dim 226 ).reshape(-1) 227 return torch.cat(all_grads_per_param).contiguous() 228 229 @skip_if_lt_x_gpu(4) 230 def test_fsdp_tp_integration(self): 231 self.run_subtests( 232 { 233 "cpu_offload": [ 234 CPUOffload(offload_params=False), 235 CPUOffload(offload_params=True), 236 ], 237 "sharding_strategy": [None, ShardingStrategy.SHARD_GRAD_OP], 238 "use_orig_params": [False, True], 239 }, 240 self._test_fsdp_tp_integration, 241 ) 242 243 def _test_fsdp_tp_integration( 244 self, cpu_offload, sharding_strategy, use_orig_params 245 ): 246 """ 247 Tests training for TP + FSDP integration by comparing an FSDP-only 248 model with a TP + FSDP model. 249 """ 250 tensor_parallel_size = 2 251 LR = 3e-5 252 torch.manual_seed(0) 253 model = SimpleModel().cuda(self.rank) 254 tp_fsdp_model = copy.deepcopy(model) 255 sharded_param_names = SimpleModel.get_sharded_param_names() 256 non_sharded_param_names = SimpleModel.get_non_sharded_param_names() 257 ( 258 param_name_to_numel, 259 param_name_to_sharding_info, 260 ) = self._get_params_and_sharding_info( 261 model, 262 sharded_param_names, 263 tensor_parallel_size, 264 ) 265 266 input_seed = self.rank 267 torch.manual_seed(input_seed + 1) 268 inp_size = [2, 3, 5] 269 inp = torch.rand(*inp_size).cuda(self.rank) 270 self.assertEqual(model(inp), tp_fsdp_model(inp)) # sanity check 271 272 mesh_1d = init_device_mesh("cuda", (self.world_size,)) 273 fsdp_model = FSDP( 274 model, 275 cpu_offload=cpu_offload, 276 device_mesh=mesh_1d, 277 sharding_strategy=sharding_strategy, 278 use_orig_params=use_orig_params, 279 ) 280 mesh_2d = init_device_mesh( 281 "cuda", 282 (self.world_size // tensor_parallel_size, tensor_parallel_size), 283 mesh_dim_names=["dp", "tp"], 284 ) 285 # Shard with TP and then wrap with FSDP 286 sequence_parallelize_plan = { 287 "net1": ColwiseParallel(input_layouts=Shard(0)), 288 "net2": RowwiseParallel(output_layouts=Shard(0)), 289 } 290 tp_fsdp_model = parallelize_module( 291 tp_fsdp_model, 292 mesh_2d["tp"], 293 sequence_parallelize_plan, 294 ) 295 tp_pg = mesh_2d["tp"].get_group(mesh_dim=0) 296 assert isinstance(tp_fsdp_model.net1.weight, DTensor) 297 assert isinstance(tp_fsdp_model.net2.weight, DTensor) 298 tp_fsdp_model = FSDP( 299 tp_fsdp_model, 300 cpu_offload=cpu_offload, 301 device_mesh=mesh_2d["dp"], 302 sharding_strategy=sharding_strategy, 303 use_orig_params=use_orig_params, 304 ) 305 fsdp_pg = mesh_2d["dp"].get_group(mesh_dim=0) 306 307 # Check the forward by checking output equality 308 fsdp_out = fsdp_model(inp) 309 tp_fsdp_out = tp_fsdp_model(inp) 310 self.assertEqual(fsdp_out, tp_fsdp_out) 311 312 # Check the backward by checking gradient equality 313 fsdp_out.sum().backward() 314 tp_fsdp_out.sum().backward() 315 self._sync_tp_grads( 316 tp_fsdp_model, 317 tp_pg, 318 param_name_to_numel, 319 non_sharded_param_names, 320 ) 321 model_grads = self._get_grads_as_flattened( 322 fsdp_model, 323 False, 324 param_name_to_numel, 325 param_name_to_sharding_info, 326 None, 327 self.process_group, 328 None, 329 ) 330 model_tp_grads = self._get_grads_as_flattened( 331 tp_fsdp_model, 332 True, 333 param_name_to_numel, 334 param_name_to_sharding_info, 335 tp_pg, 336 fsdp_pg, 337 sharded_param_names, 338 ) 339 self.assertEqual(model_grads, model_tp_grads) 340 341 # Check the optimizer step by performing a second forward pass 342 fsdp_optim = torch.optim.SGD(fsdp_model.parameters(), lr=LR) 343 tp_fsdp_optim = torch.optim.SGD(tp_fsdp_model.parameters(), lr=LR) 344 fsdp_optim.step() 345 tp_fsdp_optim.step() 346 torch.manual_seed(input_seed + 16) 347 inp = torch.rand(*inp_size).cuda(self.rank) 348 fsdp_out = fsdp_model(inp) 349 tp_fsdp_out = tp_fsdp_model(inp) 350 self.assertEqual(fsdp_out, tp_fsdp_out) 351 352 @skip_if_lt_x_gpu(4) 353 def test_fsdp_tp_extension_grad(self): 354 """ 355 Tests TP + FSDP extension with correct gradient (i.e. no ACT) 356 """ 357 mesh_2d = init_device_mesh( 358 "cuda", (self.world_size // 2, 2), mesh_dim_names=["dp", "tp"] 359 ) 360 361 class TestModel(torch.nn.Module): 362 def __init__(self) -> None: 363 super().__init__() 364 self.mlp = MLPModule("cuda") 365 self.mlp_norm = RMSNormPython(10) 366 367 def forward(self, x): 368 return self.mlp(self.mlp_norm(x)) 369 370 model = TestModel().cuda(self.rank) 371 372 # Shard with TP and test gradient 373 tp_mesh = mesh_2d["tp"] 374 tp_model = parallelize_module( 375 model, 376 tp_mesh, 377 { 378 "mlp.net1": ColwiseParallel(input_layouts=Shard(0)), 379 "mlp.net2": RowwiseParallel(output_layouts=Shard(0)), 380 }, 381 ) 382 distribute_rmsnorm(tp_model.mlp_norm, tp_mesh) 383 384 fsdp_2d_model = FSDP(tp_model, device_mesh=mesh_2d["dp"]) 385 comm_mode = CommDebugMode() 386 387 with comm_mode: 388 fsdp_2d_model(torch.rand(2, 10).cuda(self.rank)).sum().backward() 389 390 funcol = torch.ops.c10d_functional 391 c10d_ops = torch.ops.c10d 392 comm_counts = comm_mode.get_comm_counts() 393 self.assertEqual(comm_mode.get_total_counts(), 7) 394 # TP comms 395 self.assertEqual(comm_counts[funcol.reduce_scatter_tensor], 2) 396 self.assertEqual(comm_counts[funcol.all_gather_into_tensor], 2) 397 self.assertEqual(comm_counts[funcol.all_reduce], 1) 398 # FSDP comms 399 self.assertEqual(comm_counts[c10d_ops._allgather_base_], 1) 400 self.assertEqual(comm_counts[c10d_ops._reduce_scatter_base_], 1) 401 402 grads = [p.grad for p in fsdp_2d_model.parameters() if p.grad is not None] 403 404 for grad in grads: 405 self.assertFalse(grad.isnan().any().item()) 406 407 @skip_if_lt_x_gpu(4) 408 def test_fsdp_tp_sync_module_state(self): 409 mesh_2d = init_device_mesh( 410 "cuda", (self.world_size // 2, 2), mesh_dim_names=["dp", "tp"] 411 ) 412 tp_mesh = mesh_2d["tp"] 413 dp_mesh = mesh_2d["dp"] 414 415 # set random seed for each rank 416 torch.manual_seed(mesh_2d.get_rank()) 417 418 class TestModel(torch.nn.Module): 419 def __init__(self) -> None: 420 super().__init__() 421 replicated_dt = DTensor.from_local( 422 torch.randn(8, 8), tp_mesh, [Replicate()], run_check=False 423 ) 424 replicated_buffer_dt = DTensor.from_local( 425 torch.randn(8, 8), tp_mesh, [Replicate()], run_check=False 426 ) 427 self.param = torch.nn.Parameter(replicated_dt) 428 self.buf = torch.nn.Buffer(replicated_buffer_dt) 429 430 def forward(self, x): 431 return self.param + self.buffer + 1 432 433 model = TestModel() 434 435 def assert_local_shard_across_ranks(local_tensor, group, check_equal=True): 436 gathered_tensors = [ 437 torch.empty_like(local_tensor) for _ in range(group.size()) 438 ] 439 dist.all_gather(gathered_tensors, local_tensor, group=group) 440 # on dp mesh dim local tensor does not equal 441 tensor_to_compare = gathered_tensors[0] 442 for tensor in gathered_tensors[1:]: 443 if check_equal: 444 self.assertTrue(torch.equal(tensor, tensor_to_compare)) 445 else: 446 self.assertFalse(torch.equal(tensor, tensor_to_compare)) 447 448 dp_group = dp_mesh.get_group() 449 450 # check on dp mesh dim param local tensor does not equal 451 local_param = model.param.to_local() 452 assert_local_shard_across_ranks(local_param, dp_group, check_equal=False) 453 # check on dp mesh dim buffer local tensor does not equal 454 local_buf = model.buf.to_local() 455 assert_local_shard_across_ranks(local_buf, dp_group, check_equal=False) 456 457 # wrap with fsdp sync param should sync dp mesh dim 458 fsdp_mod = FSDP(model, device_mesh=dp_mesh, sync_module_states=True) 459 with fsdp_mod.summon_full_params(fsdp_mod): 460 # on dp mesh dim local param does equal after sync_module_states 461 local_param = fsdp_mod.param.to_local() 462 assert_local_shard_across_ranks(local_param, dp_group, check_equal=True) 463 464 # on dp mesh dim local buf does equal after sync_module_states 465 local_buf = fsdp_mod.buf.to_local() 466 assert_local_shard_across_ranks(local_buf, dp_group, check_equal=True) 467 468 469instantiate_parametrized_tests(TestTPFSDPIntegration) 470 471if __name__ == "__main__": 472 run_tests() 473