# Owner(s): ["oncall: distributed"]

import copy
import functools
import io
from copy import deepcopy
from typing import List, Type

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._composable import replicate
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed._tensor import DTensor, init_device_mesh, Replicate, Shard
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
    StateDictOptions,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import (
    _get_module_fsdp_state,
    clean_tensor_name,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform
from torch.distributed.tensor.parallel.fsdp import DTensorExtensions
from torch.distributed.tensor.parallel.input_reshard import input_reshard
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLP, MLPStack
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skipIfRocm,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    ModelArgs,
    Transformer,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(5, 8)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(8, 4)
        self.net3 = nn.Linear(4, 12)

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        return x

    def get_input(self):
        return torch.rand(4, 5, device="cuda")


class SimpleModelUneven(nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Linear(5, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 15)
        self.net3 = nn.Linear(15, 30)
        self.net4 = nn.Linear(30, 5)

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        x = self.net4(x)
        return x

    def get_input(self):
        return torch.rand(4, 5, device="cuda")


class TestFullyShard2DTraining(FSDPTest):
    global c10d_ops
    global funcol
    c10d_ops = torch.ops.c10d
    funcol = torch.ops.c10d_functional

    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    def init_global_mesh(self) -> DeviceMesh:
        # Prefer to test with >=4 GPUs, but for 2 GPUs, use 2-way TP
        dp_size = 2 if self.world_size > 2 else 1
        return init_device_mesh(
            "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
        )

    # TODO: remove this test when uneven sharding is supported for FSDP+TP
    @skip_if_lt_x_gpu(2)
    def test_2d_uneven_shard_raise_error(self):
        global_mesh = self.init_global_mesh()
        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
        model = MLPStack(3)
        with self.assertRaisesRegex(NotImplementedError, "uneven sharding"):
            model.parallelize(tp_mesh, dp_mesh, False)

    @skip_if_lt_x_gpu(2)
    @skipIfRocm
    def test_train_parity_2d_mlp(self):
        global_mesh = self.init_global_mesh()
        self.run_subtests(
            {
                "reshard_after_forward": [False, True],
                "use_activation_checkpointing": [False, True],
                # TODO: change "mlp_dim" back to [3, 16, 17] when uneven sharding
                # is supported for FSDP+TP
                "mlp_dim": [4, 16, 20],
            },
            functools.partial(self._test_train_parity_2d_mlp, global_mesh),
        )

    def _test_train_parity_2d_mlp(
        self,
        global_mesh: DeviceMesh,
        reshard_after_forward: bool,
        use_activation_checkpointing: bool,
        mlp_dim: int,
    ):
        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
        dp_pg = dp_mesh.get_group()  # used for `replicate()`

        torch.manual_seed(42)
        model = MLPStack(mlp_dim)
        ref_model = copy.deepcopy(model).cuda()
        replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
        model.parallelize(
            tp_mesh,
            dp_mesh,
            use_activation_checkpointing,
            reshard_after_forward=reshard_after_forward,
        )
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)

        torch.manual_seed(42 + dp_pg.rank() + 1)
        device = torch.device("cuda")
        for iter_idx in range(10):
            inp = torch.randn((8, mlp_dim), device=device)
            losses: List[torch.Tensor] = []
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                losses.append(_model(inp).sum())
                losses[-1].backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    @skipIfRocm
    def test_tp_with_fsdp_offloading(self):
        global_mesh = init_device_mesh(
            "cuda", (1, self.world_size), mesh_dim_names=("dp", "tp")
        )
        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
        torch.manual_seed(42)
        mlp_dim = 16
        model = MLPStack(mlp_dim)
        ref_model = copy.deepcopy(model).cuda()
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
        # Parallelize with N-way TP and 1-way FSDP
        model.parallelize(
            tp_mesh,
            dp_mesh,
            use_activation_checkpointing=False,
            reshard_after_forward=True,
            offload_policy=CPUOffloadPolicy(),
        )
        for param in model.parameters():
            self.assertEqual(param.device.type, "cpu")
        num_mlps = sum(isinstance(module, MLP) for module in model.modules())
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)

        # NOTE: We still see the FSDP all-gather/reduce-scatter c10d ops
        # called, but they will just be no-ops without issuing any kernels.
        # We prefer to keep the no-op check at the c10d level, not in FSDP.
        inp = torch.randn((4, mlp_dim), device="cuda")  # same on all ranks
        for iter_idx in range(10):
            ref_optim.zero_grad()
            optim.zero_grad()

            with CommDebugMode() as fwd_comm_mode:
                loss = model(inp).sum()

            fwd_comm_counts = fwd_comm_mode.get_comm_counts()
            self.assertEqual(len(fwd_comm_counts), 2)
            self.assertEqual(fwd_comm_counts[funcol.all_reduce], num_mlps)
            self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
            ref_loss = ref_model(inp).sum()
            self.assertEqual(loss, ref_loss)

            with CommDebugMode() as bwd_comm_mode:
                loss.backward()
            bwd_comm_counts = bwd_comm_mode.get_comm_counts()
            self.assertEqual(len(bwd_comm_counts), 3)
            # First MLP's input gradient does not need to be all-reduced
            self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1)
            self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
            self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_mlps)
            ref_loss.backward()

            optim.step()
            ref_optim.step()

    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_train_parity_2d_transformer_checkpoint_resume(self):
        """
        Tests train parity of a 2D transformer without checkpointing against a
        2D transformer with a checkpoint save/load.
        """
        self.run_subtests(
            {
                "use_seq_parallel": [False, True],
                # If reusing, then load into the same model/optimizer instance;
                # else construct new ones (requiring eager optim state init)
                "reuse_model_optim": [False, True],
                "optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
                # TODO: need to update `parallelize` before including foreach=True for testing
                "foreach": [False],
            },
            self._test_train_parity_2d_transformer_checkpoint_resume,
        )

    def _test_train_parity_2d_transformer_checkpoint_resume(
        self,
        use_seq_parallel: bool,
        reuse_model_optim: bool,
        optimizer_class: Type[torch.optim.Optimizer],
        foreach: bool,
    ):
        def train_step(
            _model: nn.Module, _optim: torch.optim.Optimizer, _inp: torch.Tensor
        ) -> torch.Tensor:
            loss = _model(_inp).sum()
            loss.backward()
            _optim.step()
            _optim.zero_grad()
            return loss

        def parallelize(_model: Transformer, mesh: DeviceMesh, use_seq_parallel: bool):
            _model = Transformer.parallelize(_model, mesh["tp"], use_seq_parallel)
            for layer in _model.layers:
                fully_shard(layer, mesh=mesh["dp"])
            fully_shard(_model, mesh=mesh["dp"])
            return _model

        global_mesh = self.init_global_mesh()
        # Baseline: run two iterations without checkpointing
        seed = 42
        torch.manual_seed(seed)
        model_args = ModelArgs(dropout_p=0.0)
        model_no_cp = parallelize(
            Transformer(model_args), global_mesh, use_seq_parallel
        )
        optim_no_cp = optimizer_class(
            model_no_cp.parameters(), lr=1e-2, foreach=foreach
        )

        torch.manual_seed(42 + global_mesh["dp"].get_local_rank() + 1)
        inp = torch.randint(0, model_args.vocab_size, (3, 16), device="cuda")
        loss_no_cp1 = train_step(model_no_cp, optim_no_cp, inp)
        loss_no_cp2 = train_step(model_no_cp, optim_no_cp, inp)

        # Test: run one iteration, save checkpoint, zero states or init new
        # model/optimizer, load checkpoint, and run another iteration
        torch.manual_seed(seed)
        model_cp = parallelize(Transformer(model_args), global_mesh, use_seq_parallel)
        optim_cp = optimizer_class(model_cp.parameters(), lr=1e-2, foreach=foreach)

        loss_cp1 = train_step(model_cp, optim_cp, inp)
        self.assertEqual(loss_no_cp1, loss_cp1)

        sharded_sd = {
            "model": get_model_state_dict(model_cp),
            # Use `get_optimizer_state_dict` to handle eager optim state init
            # when constructing a new optimizer instance
            "optim": get_optimizer_state_dict(model_cp, optim_cp),
        }
        dcp.save(
            state_dict=sharded_sd,
            storage_writer=dcp.FileSystemWriter(self.temp_dir),
        )
        if reuse_model_optim:
            with torch.no_grad():
                for param in model_cp.parameters():
                    param.zero_()
                optim_sd = optim_cp.state_dict()
                for param_states in optim_sd["state"].values():
                    for state_value in param_states.values():
                        if torch.is_tensor(state_value):
                            state_value.zero_()
        else:
            torch.manual_seed(seed + 1)  # different seed
            model_cp = parallelize(
                Transformer(model_args), global_mesh, use_seq_parallel
            )
            optim_cp = optimizer_class(model_cp.parameters(), lr=1e-2, foreach=foreach)
        self.assertNotEqual(loss_no_cp2, train_step(model_cp, optim_cp, inp))

        sharded_sd = {
            "model": get_model_state_dict(model_cp),
            "optim": get_optimizer_state_dict(model_cp, optim_cp),
        }
        dcp.load(
            state_dict=sharded_sd,
            storage_reader=dcp.FileSystemReader(self.temp_dir),
        )
        self.assertGreater(len(optim_cp.state_dict()["state"]), 0)

        loss_cp2 = train_step(model_cp, optim_cp, inp)
        self.assertEqual(loss_no_cp2, loss_cp2)


class TestFullyShard2DStateDict(DTensorTestBase):
    @property
    def backend(self):
        # need to specify gloo backend for testing cpu offload
        return "cpu:gloo,cuda:nccl"

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_fully_shard_tp_2d_set_full_state_dict(self):
        dummy_model = SimpleModel().cuda()
        mesh_2d = init_device_mesh(
            "cuda",
            (2, self.world_size // 2),
            mesh_dim_names=("dp", "tp"),
        )
        tp_mesh = mesh_2d["tp"]
        dp_mesh = mesh_2d["dp"]
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
            "net3": ColwiseParallel(),
        }
        model = parallelize_module(dummy_model, tp_mesh, parallelize_plan)
        fully_shard(model, mesh=dp_mesh)
        optim = torch.optim.Adam(model.parameters(), lr=0.01)
        model(model.get_input()).sum().backward()
        optim.step()
        # ref_msd and ref_osd are both default sharded state dicts
        ref_msd = copy.deepcopy(get_model_state_dict(model))
        ref_osd = copy.deepcopy(get_optimizer_state_dict(model, optimizers=optim))

        options = StateDictOptions(
            full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True
        )
        full_msd = get_model_state_dict(model, options=options)
        full_osd = get_optimizer_state_dict(model, optimizers=optim, options=options)
        # Load full_msd and full_osd into the model and optim.
        # This loads each rank's slice of the full tensors into its local DTensors.
        set_model_state_dict(model, full_msd, options=options)
        set_optimizer_state_dict(
            model, optimizers=optim, optim_state_dict=full_osd, options=options
        )

        # Check that after setting the full state dicts, the model and optim default
        # sharded state dicts are the same as the initial default sharded state dicts.
        new_msd = get_model_state_dict(model)
        new_osd = get_optimizer_state_dict(model, optimizers=optim)
        self.assertEqual(ref_msd, new_msd)
        self.assertEqual(ref_osd, new_osd)


class Test2dFSDP1ParallelIntegration(DTensorTestBase):
    def init_model(self, device_type, model_parallel_size=2):
        torch.manual_seed(0)
        model = MLPModule(device_type)
        torch.manual_seed(0)
        twod_model = MLPModule(device_type)
        model = DDP(model)

        # 2-D mesh is [dp, tp]
        world_size = dist.get_world_size()
        mesh_2d = init_device_mesh(
            device_type,
            (world_size // model_parallel_size, model_parallel_size),
            mesh_dim_names=("dp", "tp"),
        )

        dp_pg = mesh_2d.get_group(mesh_dim=0)

        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        twod_model = parallelize_module(twod_model, mesh_2d["tp"], parallelize_plan)
        _pre_dp_module_transform(twod_model)
        # TODO: Add tests when using gradient_as_bucket_view and static_graph for DDP.
        twod_model = DDP(twod_model, process_group=dp_pg)
        return model, twod_model, dp_pg

    def _check_module(self, m1, m2, check_grad=False):
        named_parameters = dict(m1.named_parameters())
        for name, param_m2 in m2.named_parameters():
            if name not in named_parameters:
                print(name, named_parameters.keys())
            self.assertTrue(name in named_parameters)
            param_m1 = named_parameters[name]
            if check_grad:
                param_m2 = param_m2.grad
                param_m1 = param_m1.grad
            if isinstance(param_m2, DTensor):
                replicate = [Replicate()]
                param_m2 = param_m2.redistribute(
                    device_mesh=param_m2.device_mesh, placements=replicate
                ).to_local()
            self.assertEqual(param_m2, param_m1)

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_2d_ddp_integration_functionality(self) -> None:
        model, twod_model, dp_pg = self.init_model(self.device_type)
        optim = torch.optim.Adam(model.parameters(), lr=3e-5)
        twod_optim = torch.optim.Adam(twod_model.parameters(), lr=3e-5)

        # Create input
        input_seed = dist.get_rank(dp_pg)
        torch.manual_seed(input_seed + 1)
        input = torch.rand(4, 10, device=self.device_type)

        output = model(input)
        twod_output = twod_model(input)
        self.assertEqual(output, twod_output)

        output.sum().backward()
        twod_output.sum().backward()
        self._check_module(model, twod_model, check_grad=True)

        optim.step()
        twod_optim.step()
        self._check_module(model, twod_model)

        torch.manual_seed(input_seed + 1004)
        input = torch.rand(16, 10, device=self.device_type)

        output = model(input)
        twod_output = twod_model(input)
        self.assertEqual(output, twod_output)

        # TODO: Add save/load verification for 2D.


# TODO: add additional tests for multi_param_group, optim_in_backward,
# and fsdp_nested.
class TestNew2dParallelTraining(DTensorTestBase):
    def _compare_params(self, m1, m2):
        with FSDP.summon_full_params(m1):
            with FSDP.summon_full_params(m2):
                for n_p1, n_p2 in zip(m1.named_parameters(), m2.named_parameters()):
                    p1 = n_p1[1]
                    p2 = n_p2[1]
                    if n_p1[0] != n_p2[0]:
                        self.assertTrue(n_p1[0] in n_p2[0])
                    name = n_p1[0]
                    if name == "net2.bias" and self.rank != 0:
                        continue
                    if type(p2) is DTensor:
                        p2 = p2.redistribute(p2.device_mesh, [Replicate()]).to_local()
                    self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}")

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_raise_invalid_tp_composition(self):
        with self.assertRaisesRegex(
            RuntimeError, r"Found TP device_mesh on the \d dimension of its parent mesh"
        ):
            mesh_2d = init_device_mesh(
                self.device_type, (2, self.world_size // 2), mesh_dim_names=("tp", "dp")
            )
            parallelize_plan = {
                "net1": ColwiseParallel(),
                "net2": RowwiseParallel(),
            }
            model_2d = parallelize_module(
                SimpleModel().cuda(), mesh_2d["tp"], parallelize_plan
            )

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_2d_fsdp_state_enable_extension(self):
        mesh_2d = init_device_mesh(
            self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
        )
        model = FSDP(
            SimpleModel().cuda(),
            device_mesh=mesh_2d["dp"],
        )
        fsdp_state = _get_module_fsdp_state(model)
        self.assertTrue(isinstance(fsdp_state._fsdp_extension, DTensorExtensions))

    def _test_2d_e2e_training(
        self,
        use_orig_params=False,
        recompute_activation=False,
    ) -> None:
        torch.manual_seed(0)
        model = SimpleModel().cuda(self.rank)
        model = FSDP(model, use_orig_params=use_orig_params)
        optim = torch.optim.Adam(model.parameters(), lr=0.01)

        torch.manual_seed(0)
        mesh_2d = init_device_mesh(
            self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
        )
        tp_mesh = mesh_2d["tp"]
        dp_mesh = mesh_2d["dp"]
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model_2d = parallelize_module(SimpleModel().cuda(), tp_mesh, parallelize_plan)
        model_2d = FSDP(
            model_2d,
            device_mesh=dp_mesh,
            use_orig_params=use_orig_params,
        )
        optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)

        if recompute_activation:
            model_2d = input_reshard(model_2d, mesh_2d["tp"], 0)

        # Check that named parameters return the same names at least.
        param_names_2d = [
            clean_tensor_name(name) for name, _ in model_2d.named_parameters()
        ]
        for name, _ in model.named_parameters():
            name = clean_tensor_name(name)
            if name not in param_names_2d:
                print(name, param_names_2d)
            self.assertTrue(name in param_names_2d)
        self._compare_params(model, model_2d)

        # TODO: add additional tests for multi_param_group and optim_in_backward.

        for i in range(5):
            # Ensure inputs are the same across TP ranks.
            # TODO: add a get_group_rank() to DeviceMesh.
            torch.manual_seed(i + dist.get_rank(dp_mesh.get_group(mesh_dim=0)))
            input = torch.rand(4, 5).cuda(self.rank)
            output = model(input)
            output_2d = model_2d(input)
            self.assertEqual(output, output_2d)
            output.sum().backward()
            output_2d.sum().backward()
            optim.step()
            optim_2d.step()
            self.assertEqual(model(input), model_2d(input))

        # Ensure all params are still the same after the optimizer update.
        self._compare_params(model, model_2d)

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_2d_e2e_training_default(self):
        self._test_2d_e2e_training()

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_2d_e2e_training_use_orig_params(self):
        self._test_2d_e2e_training(use_orig_params=True)

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_2d_e2e_training_not_use_orig_params(self):
        # TODO: need to revisit input_reshard API about why it failed multi-gpu tests.
        # self._test_2d_e2e_training(recompute_activation=True)
        self._test_2d_e2e_training(recompute_activation=False)


# TODO: update all state dict unit tests to use distributed.checkpoint.state_dict,
# and consolidate all the state_dict tests in test.distributed.checkpoint.
class TestNew2dParallelStateDict(DTensorTestBase):
    @property
    def backend(self):
        # need to specify gloo backend for testing cpu offload
        return "cpu:gloo,cuda:nccl"

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_fsdp_2d_extension(self):
        """
        Test whether _fsdp_extension from the FSDP state has been set correctly.
        """
        mesh_2d = init_device_mesh(
            self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
        )
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
            "net3": ColwiseParallel(),
        }
        model_2d = parallelize_module(
            SimpleModel().cuda(),
            mesh_2d["tp"],
            parallelize_plan=parallelize_plan,
        )
        model_2d = FSDP(model_2d, device_mesh=mesh_2d["dp"], use_orig_params=True)
        model_2d_fsdp_state = _get_module_fsdp_state(model_2d)
        self.assertTrue(
            isinstance(model_2d_fsdp_state._fsdp_extension, DTensorExtensions)
        )

        mesh_1d = init_device_mesh("cuda", (self.world_size,))
        model_1d = FSDP(SimpleModel().cuda(), device_mesh=mesh_1d, use_orig_params=True)
        model_1d_fsdp_state = _get_module_fsdp_state(model_1d)
        self.assertEqual(model_1d_fsdp_state._fsdp_extension, None)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @parametrize("is_even_sharded_model", [True, False])
    def test_2d_state_dict(self, is_even_sharded_model):
        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven

        # Create a model without wrapper
        torch.manual_seed(0)
        no_wrap_model = simple_model().cuda(self.rank)
        no_wrap_state_dict = no_wrap_model.state_dict()

        # Create a model and shard it with 2D FSDP + TP
        torch.manual_seed(0)
        mesh_2d = init_device_mesh(
            self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
        )
        tp_mesh = mesh_2d["tp"]
        dp_mesh = mesh_2d["dp"]
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model_2d = parallelize_module(simple_model().cuda(), tp_mesh, parallelize_plan)
        model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True)

        FSDP.set_state_dict_type(
            model_2d,
            StateDictType.SHARDED_STATE_DICT,
        )
        state_dict_2d = model_2d.state_dict()

        for no_wrap_items, two_d_items in zip(
            no_wrap_state_dict.items(), state_dict_2d.items()
        ):
            no_wrap_k, no_wrap_v = no_wrap_items
            two_d_k, two_d_v = two_d_items

            self.assertEqual(no_wrap_k, two_d_k)

            # check that all values in the 2D state_dict are DTensors
            self.assertTrue(isinstance(two_d_v, DTensor))
            self.assertEqual(len(two_d_v.placements), 2)
            # the outer dimension is the FSDP dimension and the placement is always Shard(0)
            self.assertEqual(two_d_v.placements[0], Shard(0))
            self.assertEqual(two_d_v.device_mesh, mesh_2d)

            # check that the parameter values are the same between the 2D model
            # and the model without wrapper
            all_gather_two_d_v = two_d_v.redistribute(
                mesh_2d, (Replicate(), Replicate())
            )
            self.assertEqual(
                torch.allclose(no_wrap_v, all_gather_two_d_v.to_local()), True
            )

    @with_comms
    @skip_if_lt_x_gpu(4)
    @parametrize("is_even_sharded_model", [True, False])
    def test_2d_load_state_dict(self, is_even_sharded_model):
        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven

        torch.manual_seed(0)
        mesh_2d = init_device_mesh(
            self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
        )
        tp_mesh = mesh_2d["tp"]
        dp_mesh = mesh_2d["dp"]
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model_2d = parallelize_module(simple_model().cuda(), tp_mesh, parallelize_plan)
        model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True)
        optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)

        FSDP.set_state_dict_type(
            model_2d,
            StateDictType.SHARDED_STATE_DICT,
        )
        checkpoint = io.BytesIO()
        torch.save(model_2d.state_dict(), checkpoint)
        # Deepcopy to save the current state_dict to compare with the state_dict loaded back below.
        ref_state_dict = deepcopy(model_2d.state_dict())

        # Update the parameters so model_2d.state_dict() will be different from ref_state_dict.
        model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
        optim_2d.step()

        # Load ref_state_dict back.
        checkpoint.seek(0)
        load_ref_state_dict = torch.load(checkpoint)
        model_2d.load_state_dict(load_ref_state_dict)
        new_state_dict = model_2d.state_dict()

        # Check whether new_state_dict is the same as ref_state_dict.
        for (k1, v1), (k2, v2) in zip(ref_state_dict.items(), new_state_dict.items()):
            # check whether the FQNs are the same
            self.assertEqual(k1, k2)

            self.assertEqual(type(v1), DTensor)
            self.assertEqual(type(v2), DTensor)
            # check whether the DTensors are the same
            # TODO: 2D DTensor comparison is not supported at this time, so we compare
            # the spec and the local tensor for now.
            # TODO: Update this to compare the two DTensors once 2D DTensor comparison is supported.
            self.assertEqual(v1.to_local(), v2.to_local())
            self.assertEqual(v1.device_mesh, v2.device_mesh)
            self.assertEqual(v1.placements, v2.placements)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @parametrize("is_even_sharded_model", [True, False])
    def test_2d_optim_state_dict(self, is_even_sharded_model):
        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven

        # Create a model without wrapper
        torch.manual_seed(0)
        no_wrap_model = simple_model().cuda(self.rank)
        no_wrap_state_dict = no_wrap_model.state_dict()
        no_wrap_optim = torch.optim.Adam(no_wrap_model.parameters(), lr=0.01)
        no_wrap_model(no_wrap_model.get_input().cuda(self.rank)).sum().backward()
        no_wrap_optim.step()
        no_wrap_osd = get_optimizer_state_dict(no_wrap_model, optimizers=no_wrap_optim)

        # Create a model and shard it with 2D FSDP + TP
        torch.manual_seed(0)
        mesh_2d = init_device_mesh(
            self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
        )
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model_2d = parallelize_module(
            simple_model().cuda(), mesh_2d["tp"], parallelize_plan
        )
        model_2d = FSDP(model_2d, device_mesh=mesh_2d["dp"], use_orig_params=True)
        FSDP.set_state_dict_type(
            model_2d,
            StateDictType.SHARDED_STATE_DICT,
        )
        optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)
        model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
        optim_2d.step()
        optim_2d_osd = get_optimizer_state_dict(model_2d, optimizers=optim_2d)
        ref_optim_2d_osd = deepcopy(optim_2d_osd)

        no_wrap_osd_states = no_wrap_osd["state"]
        optim_2d_osd_states = optim_2d_osd["state"]

        self.assertEqual(len(no_wrap_osd_states), len(optim_2d_osd_states))
        self.assertEqual(no_wrap_osd_states.keys(), optim_2d_osd_states.keys())
        for fqn, states in no_wrap_osd_states.items():
            dist_states = optim_2d_osd_states.get(fqn)

            for state_name, state in states.items():
                dist_state = dist_states.get(state_name)
                # If a state is a DTensor, we all-gather it in both the DP and TP
                # dimensions to compare with the no_wrap state.
                if isinstance(dist_state, DTensor):
                    dist_state = (
                        dist_state.cuda()
                        .redistribute(placements=(Replicate(), Replicate()))
                        .to_local()
                    )
                self.assertTrue(isinstance(dist_state, torch.Tensor))
                self.assertTrue(torch.allclose(state, dist_state))

        # Update the parameters so the 2D optim states will be different from ref_optim_2d_osd.
        model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
        optim_2d.step()

        set_optimizer_state_dict(
            model_2d, optimizers=optim_2d, optim_state_dict=ref_optim_2d_osd
        )
        new_optim_2d_osd = get_optimizer_state_dict(model_2d, optimizers=optim_2d)

        ref_optim_2d_osd_states = ref_optim_2d_osd["state"]
        new_optim_2d_osd_states = new_optim_2d_osd["state"]

        # Compare the new optim state dict after load with the reference one
        self.assertEqual(len(ref_optim_2d_osd_states), len(new_optim_2d_osd_states))
        self.assertEqual(ref_optim_2d_osd_states.keys(), new_optim_2d_osd_states.keys())
        for fqn, states in ref_optim_2d_osd_states.items():
            new_states = new_optim_2d_osd_states.get(fqn)

            for state_name, state in states.items():
                new_state = new_states.get(state_name)

                if isinstance(new_state, DTensor):
                    self.assertEqual(new_state.placements, state.placements)
                    self.assertEqual(new_state.device_mesh, state.device_mesh)
                    self.assertTrue(
                        torch.allclose(new_state.to_local(), state.to_local())
                    )
                else:
                    self.assertEqual(new_state, state)


instantiate_parametrized_tests(TestNew2dParallelStateDict)

if __name__ == "__main__":
    run_tests()