# Owner(s): ["oncall: distributed"]

import contextlib
import copy

import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DTensor, init_device_mesh
from torch.distributed._tensor.experimental import implicit_replication
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
    StateDictOptions,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.fsdp.wrap import always_wrap_policy
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
from torch.utils._pytree import tree_all_only


class TestFullyShardWithDistributedStateDict(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    def _get_base_model(self, mlp_dim: int = 2):
        base_model = nn.Sequential(
            MLP(mlp_dim),
            nn.Sequential(MLP(mlp_dim), nn.Linear(mlp_dim, mlp_dim)),
            MLP(mlp_dim),
        )
        return base_model

    @skip_if_lt_x_gpu(2)
    def test_1d_fsdp_get_model_state_dict(self):
        self.run_subtests(
            {"mlp_dim": [2, 3, 4, 5]},
            self._test_1d_fsdp_get_model_state_dict,
        )

    def _test_1d_fsdp_get_model_state_dict(self, mlp_dim: int):
        """
        Test parity between model.state_dict() and get_model_state_dict().
        """
        base_model = self._get_base_model(mlp_dim)
        # Default is `reshard_after_forward=True`
        model1 = copy.deepcopy(base_model)
        for module in model1:
            fully_shard(module)
        fully_shard(model1)

        # osd: original state dict, dsd: distributed state dict
        osd = model1.state_dict()
        dsd = get_model_state_dict(model1)
        self.assertEqual(osd, dsd)

        # Check `reshard_after_forward=False` after a forward
        model2 = copy.deepcopy(base_model)
        for module in model2:
            fully_shard(module, reshard_after_forward=False)
        fully_shard(model2, reshard_after_forward=False)
        inp = torch.randn((2, mlp_dim), device="cuda")
        model2(inp)  # parameters are not resharded after this forward
        # Check that state dict hooks reshard
        osd_2 = model2.state_dict()
        dsd_2 = get_model_state_dict(model2)
        self.assertEqual(osd_2, dsd_2)

    @skip_if_lt_x_gpu(2)
    def test_1d_fsdp_cpu_offload_full_model_state_dict(self):
        """
        Test that full_state_dict and cpu_offload work for the FSDP2 state dict.
        """
        orig_model = self._get_base_model()
        fsdp_model = copy.deepcopy(orig_model)
        for module in fsdp_model:
            fully_shard(module)
        fully_shard(fsdp_model)

        osd = orig_model.state_dict()
        dsd = get_model_state_dict(
            fsdp_model,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )

        cpu_device = torch.device("cpu")

        def is_cpu(v):
            if isinstance(v, DTensor):
                return v.device == torch.device("cpu")
            else:
                return v.device == cpu_device

        if self.rank == 0:
            self.assertEqual(osd, dsd)
            self.assertTrue(tree_all_only((torch.Tensor, DTensor), is_cpu, osd))
        else:
            self.assertEqual(dsd, {})

    @skip_if_lt_x_gpu(2)
    def test_save_with_fsdp1_and_load_with_fsdp2(self):
        self.run_subtests(
            {
                "state_dict_type": [
                    StateDictType.FULL_STATE_DICT,
                    StateDictType.SHARDED_STATE_DICT,
                ]
            },
            self._test_save_with_fsdp1_and_load_with_fsdp2,
        )

    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def _test_save_with_fsdp1_and_load_with_fsdp2(
        self, state_dict_type: StateDictType
    ):
        """
        Test that we can save a model with FSDP1 and load it with FSDP2.
        """

        # Save state dict with model wrapped with FSDP1
        fsdp1_model = FSDP(
            self._get_base_model().cuda(),
            use_orig_params=True,
            auto_wrap_policy=always_wrap_policy,
        )

        fsdp1_optim = torch.optim.AdamW(fsdp1_model.parameters(), lr=0.1)

        fsdp1_model(torch.randn((2,), device=self.rank)).sum().backward()
        fsdp1_optim.step()

        with FSDP.state_dict_type(fsdp1_model, state_dict_type):
            fsdp1_state_dict = {
                "model": fsdp1_model.state_dict(),
                "optim": FSDP.sharded_optim_state_dict(fsdp1_model, fsdp1_optim),
            }
            dcp.save(
                fsdp1_state_dict,
                checkpoint_id=self.temp_dir,
            )

        fsdp1_full_msd = get_model_state_dict(
            fsdp1_model,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )
        fsdp1_full_osd = get_optimizer_state_dict(
            fsdp1_model,
            fsdp1_optim,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )

        # Load state dict into model with FSDP2 applied
        fsdp2_model = self._get_base_model()
        for module in fsdp2_model:
            fully_shard(module)
        fully_shard(fsdp2_model)
        fsdp2_optim = torch.optim.AdamW(fsdp2_model.parameters(), lr=0.1)

        fsdp2_state_dict = {
            "model": get_model_state_dict(fsdp2_model),
            "optim": get_optimizer_state_dict(fsdp2_model, fsdp2_optim),
        }
        dcp.load(
            fsdp2_state_dict,
            checkpoint_id=self.temp_dir,
        )
        fsdp2_model.load_state_dict(fsdp2_state_dict["model"])
        fsdp2_optim.load_state_dict(fsdp2_state_dict["optim"])

        fsdp2_full_msd = get_model_state_dict(
            fsdp2_model,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )
        fsdp2_full_osd = get_optimizer_state_dict(
            fsdp2_model,
            fsdp2_optim,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )

        # Compare full state dicts to make sure they are the same.
        self.assertEqual(fsdp2_full_msd, fsdp1_full_msd)
        self.assertEqual(fsdp1_full_osd, fsdp2_full_osd)

    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_save_with_fsdp1_and_load_with_fsdp2_tp(self):
        """
        Test that we can save a model with FSDP1 and load it with FSDP2 + TP on a 2D mesh.
        """

        def _get_base_model(mlp_dim: int = 2):
            base_model = nn.Sequential(MLP(mlp_dim), MLP(mlp_dim), MLP(mlp_dim))
            return base_model

        # init device mesh
        dp_size = 2
        global_mesh = init_device_mesh(
            "cuda",
            (dp_size, self.world_size // dp_size),
            mesh_dim_names=("dp", "tp"),
        )
        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]

        # Save state dict with original model
        base_model = _get_base_model().cuda()
        base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1)

        # Save state dict with model wrapped with FSDP1
        fsdp1_model = FSDP(
            copy.deepcopy(base_model),
            device_mesh=global_mesh,
            use_orig_params=True,
            auto_wrap_policy=always_wrap_policy,
        )

        fsdp1_optim = torch.optim.AdamW(fsdp1_model.parameters(), lr=0.1)

        # one-step training to modify state dict
        inp = torch.randn((2,), device=self.rank)
        base_model(inp).sum().backward()
        base_optim.step()
        fsdp1_model(inp).sum().backward()
        fsdp1_optim.step()

        # obtain the full state dict
        base_msd = get_model_state_dict(
            base_model,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )
        base_osd = get_optimizer_state_dict(
            base_model,
            base_optim,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )

        # obtain the sharded state dict
        fsdp1_msd = get_model_state_dict(
            fsdp1_model,
            options=StateDictOptions(full_state_dict=False),
        )
        fsdp1_osd = get_optimizer_state_dict(
            fsdp1_model,
            fsdp1_optim,
            options=StateDictOptions(full_state_dict=False),
        )

        # save state dict to temp dir
        source_state_dict = {
            "model_full": base_msd,
            "optim_full": base_osd,
            "model_sharded": fsdp1_msd,
            "optim_sharded": fsdp1_osd,
        }
        dcp.save(
            source_state_dict,
            checkpoint_id=self.temp_dir,
        )

        # FSDP + TP
        fsdp2_tp_model = _get_base_model()
        fsdp2_tp_model = parallelize_module(
            fsdp2_tp_model,
            device_mesh=tp_mesh,
            parallelize_plan={
                "0.in_proj": ColwiseParallel(),
                "0.out_proj": RowwiseParallel(),
                "1.in_proj": ColwiseParallel(),
                "1.out_proj": RowwiseParallel(),
                "2.in_proj": ColwiseParallel(),
                "2.out_proj": RowwiseParallel(),
            },
        )
        for module in fsdp2_tp_model:
            fully_shard(module, mesh=dp_mesh)
        fully_shard(fsdp2_tp_model, mesh=dp_mesh)

        fsdp2_tp_optim = torch.optim.AdamW(fsdp2_tp_model.parameters(), lr=0.1)

        # Load state dict into model with FSDP2 + TP applied
        for src_state_dict_type in ["full", "sharded"]:
            msd_name = f"model_{src_state_dict_type}"
            osd_name = f"optim_{src_state_dict_type}"
            fsdp2_tp_state_dict = {
                msd_name: get_model_state_dict(fsdp2_tp_model),
                osd_name: get_optimizer_state_dict(fsdp2_tp_model, fsdp2_tp_optim),
            }
            # load state dict from temp dir
            dcp.load(
                fsdp2_tp_state_dict,
                checkpoint_id=self.temp_dir,
            )
            fsdp2_tp_model.load_state_dict(fsdp2_tp_state_dict[msd_name])
            fsdp2_tp_optim.load_state_dict(fsdp2_tp_state_dict[osd_name])

            fsdp2_tp_full_msd = get_model_state_dict(
                fsdp2_tp_model,
                options=StateDictOptions(full_state_dict=True, cpu_offload=True),
            )
            fsdp2_tp_full_osd = get_optimizer_state_dict(
                fsdp2_tp_model,
                fsdp2_tp_optim,
                options=StateDictOptions(full_state_dict=True, cpu_offload=True),
            )

            # Compare full state dicts to make sure they are the same.
            self.assertEqual(base_msd, fsdp2_tp_full_msd)
            self.assertEqual(base_osd, fsdp2_tp_full_osd)

    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_save_with_tp_and_load_with_fsdp2_tp(self):
        """
        Test that we can save a model with TP and load it with FSDP2 + TP on a 2D mesh.
        """

        def _get_base_model(mlp_dim: int = 2):
            base_model = nn.Sequential(MLP(mlp_dim), MLP(mlp_dim), MLP(mlp_dim))
            return base_model

        tp_parallelize_plan = {
            "0.in_proj": ColwiseParallel(),
            "0.out_proj": RowwiseParallel(),
            "1.in_proj": ColwiseParallel(),
            "1.out_proj": RowwiseParallel(),
            "2.in_proj": ColwiseParallel(),
            "2.out_proj": RowwiseParallel(),
        }

        # init device mesh
        dp_size = 2
        global_mesh_1d = init_device_mesh(
            "cuda", (self.world_size,), mesh_dim_names=("tp",)
        )
        global_mesh_2d = init_device_mesh(
            "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
        )
        dp_mesh, tp_mesh = global_mesh_2d["dp"], global_mesh_2d["tp"]

        # Save state dict with original model
        base_model = _get_base_model().cuda()
        base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1)

        # Save state dict with TP model
        tp_model = copy.deepcopy(base_model)
        tp_model = parallelize_module(
            tp_model,
            device_mesh=global_mesh_1d,
            parallelize_plan=tp_parallelize_plan,
        )
        tp_model_optim = torch.optim.AdamW(tp_model.parameters(), lr=0.1)

        # one-step training to modify state dict
        inp = torch.randn((2,), device=self.rank)
        base_model(inp).sum().backward()
        base_optim.step()
        tp_model(inp).sum().backward()
        tp_model_optim.step()

        # obtain the full state dict
        base_msd = get_model_state_dict(
            base_model,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )
        base_osd = get_optimizer_state_dict(
            base_model,
            base_optim,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )

        # obtain sharded state dict
        tp_msd = get_model_state_dict(
            tp_model,
            options=StateDictOptions(full_state_dict=False),
        )
        tp_osd = get_optimizer_state_dict(
            tp_model,
            tp_model_optim,
            options=StateDictOptions(full_state_dict=False),
        )

        # save state dict to temp dir
        source_state_dict = {
            "model_full": base_msd,
            "optim_full": base_osd,
            "model_sharded": tp_msd,
            "optim_sharded": tp_osd,
        }
        dcp.save(
            source_state_dict,
            checkpoint_id=self.temp_dir,
        )

        # FSDP + TP
        fsdp2_tp_model = _get_base_model()
        fsdp2_tp_model = parallelize_module(
            fsdp2_tp_model,
            device_mesh=tp_mesh,
            parallelize_plan=tp_parallelize_plan,
        )
        for module in fsdp2_tp_model:
            fully_shard(module, mesh=dp_mesh)
        fully_shard(fsdp2_tp_model, mesh=dp_mesh)
        fsdp2_tp_optim = torch.optim.AdamW(fsdp2_tp_model.parameters(), lr=0.1)

        # Load state dict into model with FSDP2 + TP applied
        for src_state_dict_type in ["full", "sharded"]:
            msd_name = f"model_{src_state_dict_type}"
            osd_name = f"optim_{src_state_dict_type}"
            fsdp2_tp_state_dict = {
                msd_name: get_model_state_dict(fsdp2_tp_model),
                osd_name: get_optimizer_state_dict(fsdp2_tp_model, fsdp2_tp_optim),
            }
            # load state dict from temp dir
            dcp.load(
                fsdp2_tp_state_dict,
                checkpoint_id=self.temp_dir,
            )
            fsdp2_tp_model.load_state_dict(fsdp2_tp_state_dict[msd_name])
            fsdp2_tp_optim.load_state_dict(fsdp2_tp_state_dict[osd_name])

            fsdp2_tp_full_msd = get_model_state_dict(
                fsdp2_tp_model,
                options=StateDictOptions(full_state_dict=True, cpu_offload=True),
            )
            fsdp2_tp_full_osd = get_optimizer_state_dict(
                fsdp2_tp_model,
                fsdp2_tp_optim,
                options=StateDictOptions(full_state_dict=True, cpu_offload=True),
            )

            # Compare full state dicts to make sure they are the same.
            self.assertEqual(base_msd, fsdp2_tp_full_msd)
            self.assertEqual(base_osd, fsdp2_tp_full_osd)

    @skip_if_lt_x_gpu(4)
    def test_save_with_fsdp2_tp_and_load_with_tp(self):
        self.run_subtests(
            {"allow_implicit_replication": [True, False]},
            self._test_save_with_fsdp2_tp_and_load_with_tp,
        )

    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def _test_save_with_fsdp2_tp_and_load_with_tp(
        self, allow_implicit_replication: bool
    ):
        """
        Test that we can save a model with FSDP2 + TP on a 2D mesh and load it with TP.
        """

        def _get_base_model(mlp_dim: int = 2):
            base_model = nn.Sequential(MLP(mlp_dim), MLP(mlp_dim), MLP(mlp_dim))
            return base_model

        cm = (
            implicit_replication()
            if allow_implicit_replication
            else contextlib.nullcontext()
        )
        tp_parallelize_plan = {
            "0.in_proj": ColwiseParallel(),
            "0.out_proj": RowwiseParallel(),
            "1.in_proj": ColwiseParallel(),
            "1.out_proj": RowwiseParallel(),
            "2.in_proj": ColwiseParallel(),
            "2.out_proj": RowwiseParallel(),
        }
        if allow_implicit_replication:
            # Intentionally pop the plans for some TP layers so that the model
            # is not fully tensor parallelized.
            tp_parallelize_plan.pop("0.in_proj")
            tp_parallelize_plan.pop("0.out_proj")

        with cm:
            # init device mesh
            dp_size = 2
            global_mesh_1d = init_device_mesh(
                "cuda", (self.world_size,), mesh_dim_names=("tp",)
            )
            global_mesh_2d = init_device_mesh(
                "cuda",
                (dp_size, self.world_size // dp_size),
                mesh_dim_names=("dp", "tp"),
            )
            dp_mesh, tp_mesh = global_mesh_2d["dp"], global_mesh_2d["tp"]

            for save_full_state_dict in [True, False]:
                # Save state dict with original model
                base_model = _get_base_model().cuda()
                base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1)

                # Save state dict with FSDP2 + TP model
                fsdp2_tp_model = copy.deepcopy(base_model)
                fsdp2_tp_model = parallelize_module(
                    fsdp2_tp_model,
                    device_mesh=tp_mesh,
                    parallelize_plan=tp_parallelize_plan,
                )
                for module in fsdp2_tp_model:
                    fully_shard(module, mesh=dp_mesh)
                fully_shard(fsdp2_tp_model, mesh=dp_mesh)
                fsdp2_tp_optim = torch.optim.AdamW(
                    fsdp2_tp_model.parameters(), lr=0.1
                )

                # one-step training to modify state dict
                inp = torch.randn((2,), device=self.rank)
                base_model(inp).sum().backward()
                base_optim.step()
                fsdp2_tp_model(inp).sum().backward()
                fsdp2_tp_optim.step()

                # obtain the unsharded state dict
                base_msd = get_model_state_dict(
                    base_model,
                    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
                )
                base_osd = get_optimizer_state_dict(
                    base_model,
                    base_optim,
                    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
                )

                # obtain FSDP2 + TP state dict
                fsdp2_tp_msd = get_model_state_dict(
                    fsdp2_tp_model,
                    options=StateDictOptions(full_state_dict=save_full_state_dict),
                )
                fsdp2_tp_osd = get_optimizer_state_dict(
                    fsdp2_tp_model,
                    fsdp2_tp_optim,
                    options=StateDictOptions(full_state_dict=save_full_state_dict),
                )

                fsdp2_tp_state_dict = {"model": fsdp2_tp_msd, "optim": fsdp2_tp_osd}
                dcp.save(fsdp2_tp_state_dict, checkpoint_id=self.temp_dir)

                fsdp2_tp_full_msd = get_model_state_dict(
                    fsdp2_tp_model,
                    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
                )
                fsdp2_tp_full_osd = get_optimizer_state_dict(
                    fsdp2_tp_model,
                    fsdp2_tp_optim,
                    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
                )

                # Load state dict into model with TP applied
                tp_model = _get_base_model()
                tp_model = parallelize_module(
                    tp_model,
                    device_mesh=global_mesh_1d,
                    parallelize_plan=tp_parallelize_plan,
                )
                tp_optim = torch.optim.AdamW(tp_model.parameters(), lr=0.1)

                tp_state_dict = {
                    "model": get_model_state_dict(tp_model),
                    "optim": get_optimizer_state_dict(tp_model, tp_optim),
                }
                dcp.load(tp_state_dict, checkpoint_id=self.temp_dir)
                tp_model.load_state_dict(tp_state_dict["model"])
                tp_optim.load_state_dict(tp_state_dict["optim"])

                tp_full_msd = get_model_state_dict(
                    tp_model,
                    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
                )
                tp_full_osd = get_optimizer_state_dict(
                    tp_model,
                    tp_optim,
                    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
                )

                # Compare full state dicts to make sure they are the same.
                self.assertEqual(base_msd, tp_full_msd)
                self.assertEqual(base_osd, tp_full_osd)
                self.assertEqual(fsdp2_tp_full_msd, tp_full_msd)
                self.assertEqual(fsdp2_tp_full_osd, tp_full_osd)


if __name__ == "__main__":
    run_tests()