1# Owner(s): ["oncall: distributed"] 2 3import copy 4import functools 5import sys 6from itertools import chain 7from typing import Callable, Tuple, Type, Union 8 9import torch 10import torch.distributed as dist 11import torch.nn as nn 12from torch.distributed._composable import fully_shard, replicate 13 14# importing fully_shard as FSDP2 since the original fully_shard is used in this test. 15# TODO: remove old composable fully_shard so that we don't have to import new fully_shard as FSDP2 16from torch.distributed._composable.fsdp import fully_shard as FSDP2 17from torch.distributed._shard.sharded_tensor import ShardedTensor 18from torch.distributed._tensor import DTensor, init_device_mesh 19from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 20 apply_activation_checkpointing, 21) 22from torch.distributed.checkpoint import state_dict as ptd_state_dict 23from torch.distributed.checkpoint.state_dict import ( 24 _patch_model_state_dict, 25 _patch_optimizer_state_dict, 26 get_model_state_dict, 27 get_optimizer_state_dict, 28 get_state_dict, 29 set_model_state_dict, 30 set_optimizer_state_dict, 31 StateDictOptions, 32) 33from torch.distributed.fsdp import ( 34 FullyShardedDataParallel as FSDP, 35 ShardingStrategy, 36 StateDictType, 37) 38from torch.distributed.fsdp.wrap import ModuleWrapPolicy 39from torch.distributed.optim import _apply_optimizer_in_backward 40from torch.nn.parallel import DistributedDataParallel as DDP 41from torch.optim import Optimizer 42from torch.testing._internal.common_dist_composable import ( 43 CompositeParamModel, 44 UnitModule, 45) 46from torch.testing._internal.common_distributed import skip_if_lt_x_gpu 47from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN 48from torch.testing._internal.distributed._tensor.common_dtensor import ( 49 DTensorTestBase, 50 MultiProcessTestCase, 51 with_comms, 52) 53from torch.testing._internal.distributed.common_state_dict import VerifyStateDictMixin 54from torch.utils._pytree import tree_all, tree_all_only 55 56 57if not dist.is_available(): 58 print("Distributed not available, skipping tests", file=sys.stderr) 59 sys.exit(0) 60 61if TEST_WITH_DEV_DBG_ASAN: 62 print( 63 "Skip dev-asan as torch + multiprocessing spawn have known issues", 64 file=sys.stderr, 65 ) 66 sys.exit(0) 67 68 69class TestStateDict(DTensorTestBase, VerifyStateDictMixin): 70 """Tests state_dict and load_state_dict""" 71 72 @property 73 def world_size(self) -> int: 74 return min(4, torch.cuda.device_count()) 75 76 def _test_save_load( 77 self, 78 init_model_optim: Callable, 79 test_frozen: bool = False, 80 ) -> None: 81 options = StateDictOptions(ignore_frozen_params=test_frozen) 82 # Initialize original model and distributed model. 83 model, optim, copy_optim, dist_model, dist_optim = init_model_optim() 84 85 # Train 10 steps. 86 for i in range(10): 87 batch = torch.rand(8, 100, device="cuda") 88 model(batch).sum().backward() 89 optim.step() 90 dist_model(batch).sum().backward() 91 if not isinstance(dist_optim, list): 92 dist_optim.step() 93 dist_optim.zero_grad() 94 else: 95 for _dist_optim in dist_optim: 96 _dist_optim.zero_grad() 97 optim.zero_grad() 98 99 # Get the state_dict, and compare the result 100 msd = model.state_dict() 101 osd = optim.state_dict() 102 dist_msd, dist_osd = get_state_dict( 103 dist_model, optimizers=dist_optim, options=options 104 ) 105 self._verify_msd(msd, dist_msd, options) 106 self._verify_osd_by_load(model, optim, copy_optim, dist_osd) 107 self._verify_osd(model, optim, osd, dist_osd) 108 109 # Initialize a completely new model to simulate checkpoint load. 110 _, _, _, dist_model, dist_optim = init_model_optim() 111 112 # Simulate DCP distributed load. We need to first get the state_dict and 113 # pass them to DCP to load the saved state_dict from the storage. 114 # Then finally we can call set_state_dict(). 115 if not isinstance(dist_optim, list): 116 dist_optim = [dist_optim] 117 if test_frozen: 118 # We won't be able to load the partial state_dict back. 119 return 120 # Since we already have the state_dict saved before, no need to call DCP. 121 # We can directly load them back. This asser is to ensure that optimizer 122 # state storage are initialized. 123 # self.assertEqual(len(curr_dist_osd[STATE]), len(dist_osd[STATE])) 124 set_model_state_dict( 125 dist_model, 126 model_state_dict=dist_msd, 127 options=options, 128 ) 129 set_optimizer_state_dict( 130 dist_model, 131 optimizers=dist_optim, 132 optim_state_dict=dist_osd, 133 options=options, 134 ) 135 136 # Check if the new state_dict are the same 137 dist_msd, dist_osd = get_state_dict( 138 dist_model, optimizers=dist_optim, options=options 139 ) 140 self._verify_msd(msd, dist_msd, options) 141 # TODO: Ditto 142 # self._verify_osd_by_load(model, optim, copy_optim, dist_osd) 143 self._verify_osd(model, optim, osd, dist_osd) 144 145 # Test _patch_model_state_dict, and _patch_optimizer_state_dict 146 _patch_model_state_dict(dist_model, options=options) 147 _patch_optimizer_state_dict(dist_model, optimizers=dist_optim, options=options) 148 dist_msd = dist_model.state_dict() 149 dist_osd = dist_optim[0].state_dict() 150 self._verify_msd(msd, dist_msd, options) 151 self._verify_osd_by_load(model, optim, copy_optim, dist_osd) 152 self._verify_osd(model, optim, osd, dist_osd) 153 154 def _test_fsdp( 155 self, 156 *, 157 use_orig_params: bool, 158 use_composable: bool, 159 use_dtensor: bool, 160 wrapping: Tuple[nn.Module] = (), 161 compile_model: bool = False, 162 optimizer_class: Type[Optimizer], 163 ) -> None: 164 if not use_orig_params and use_composable: 165 return 166 167 # TODO: remove this return after we complete the composable API side change for device_mesh 168 if use_composable and use_dtensor: 169 return 170 171 def init_model_optim(): 172 if use_dtensor: 173 device_mesh = init_device_mesh("cuda", (self.world_size,)) 174 175 orig_model = CompositeParamModel(device=torch.device("cuda")) 176 orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3) 177 copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3) 178 if wrapping: 179 strategy = set(wrapping) 180 else: 181 strategy = {UnitModule} 182 if use_composable: 183 dist_model = fully_shard( 184 copy.deepcopy(orig_model), policy=ModuleWrapPolicy(strategy) 185 ) 186 else: 187 if use_dtensor: 188 device_mesh = init_device_mesh("cuda", (self.world_size,)) 189 dist_model = FSDP( 190 copy.deepcopy(orig_model), 191 auto_wrap_policy=ModuleWrapPolicy(strategy), 192 use_orig_params=use_orig_params, 193 device_mesh=device_mesh, 194 ) 195 else: 196 dist_model = FSDP( 197 copy.deepcopy(orig_model), 198 auto_wrap_policy=ModuleWrapPolicy(strategy), 199 use_orig_params=use_orig_params, 200 ) 201 202 if compile_model: 203 dist_model = torch.compile(dist_model) 204 dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3) 205 return orig_model, orig_optim, copy_optim, dist_model, dist_optim 206 207 self._test_save_load(init_model_optim) 208 209 @with_comms 210 @skip_if_lt_x_gpu(2) 211 def test_fsdp(self) -> None: 212 self.run_subtests( 213 { 214 "use_orig_params": [True, False], 215 "use_composable": [True, False], 216 "use_dtensor": [True, False], 217 "wrapping": [tuple(), (nn.Linear, UnitModule)], 218 "optimizer_class": [torch.optim.Adam, torch.optim.AdamW], 219 }, 220 self._test_fsdp, 221 ) 222 223 @with_comms 224 @skip_if_lt_x_gpu(2) 225 def test_compiled_fsdp(self) -> None: 226 self.run_subtests( 227 { 228 "use_orig_params": [True], 229 "use_composable": [False], 230 "use_dtensor": [False], 231 "wrapping": [tuple()], 232 "optimizer_class": [torch.optim.Adam, torch.optim.AdamW], 233 }, 234 self._test_fsdp, 235 ) 236 237 def _test_fsdp2( 238 self, 239 *, 240 reshard_after_forward: Union[bool, int], 241 optimizer_class: Type[Optimizer], 242 compile_model: bool, 243 foreach: bool = True, 244 ): 245 def init_model_optim(): 246 orig_model = CompositeParamModel(device=torch.device("cuda")) 247 orig_optim = optimizer_class( 248 orig_model.parameters(), lr=1e-3, foreach=foreach 249 ) 250 copy_optim = optimizer_class( 251 orig_model.parameters(), lr=1e-3, foreach=foreach 252 ) 253 254 dist_model = FSDP2( 255 copy.deepcopy(orig_model), 256 reshard_after_forward=reshard_after_forward, 257 ) 258 259 if compile_model: 260 dist_model = torch.compile(dist_model) 261 dist_optim = optimizer_class( 262 dist_model.parameters(), lr=1e-3, foreach=foreach 263 ) 264 265 return orig_model, orig_optim, copy_optim, dist_model, dist_optim 266 267 self._test_save_load(init_model_optim) 268 269 @with_comms 270 @skip_if_lt_x_gpu(2) 271 def test_fsdp2(self) -> None: 272 self.run_subtests( 273 { 274 "reshard_after_forward": [True, False], 275 "optimizer_class": [torch.optim.Adam, torch.optim.AdamW], 276 "compile_model": [True, False], 277 }, 278 self._test_fsdp2, 279 ) 280 281 def _test_ddp(self, use_composable: bool, optimizer_class: Type[Optimizer]) -> None: 282 def init_model_optim(): 283 orig_model = CompositeParamModel(device=torch.device("cuda")) 284 orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3) 285 copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3) 286 if use_composable: 287 dist_model = replicate(copy.deepcopy(orig_model)) 288 else: 289 dist_model = DDP(copy.deepcopy(orig_model)) 290 dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3) 291 return orig_model, orig_optim, copy_optim, dist_model, dist_optim 292 293 self._test_save_load(init_model_optim) 294 295 @with_comms 296 @skip_if_lt_x_gpu(2) 297 def test_ddp(self) -> None: 298 self.run_subtests( 299 { 300 "use_composable": [True, False], 301 "optimizer_class": [torch.optim.Adam, torch.optim.AdamW], 302 }, 303 self._test_ddp, 304 ) 305 306 def _test_fsdp_ddp( 307 self, 308 use_composable: bool, 309 optimizer_class: Type[Optimizer], 310 optim_in_backward: bool = False, 311 test_frozen: bool = False, 312 ) -> None: 313 def init_model_optim(): 314 orig_model = CompositeParamModel(device=torch.device("cuda")) 315 if test_frozen: 316 for param in chain( 317 orig_model.u1.parameters(), orig_model.u2.parameters() 318 ): 319 param.requires_grad = False 320 orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3) 321 copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3) 322 dist_model = copy.deepcopy(orig_model) 323 if use_composable: 324 replicate(dist_model.l) 325 fully_shard(dist_model, policy=ModuleWrapPolicy({UnitModule})) 326 else: 327 dist_model.l = DDP(dist_model.l) 328 dist_model = FSDP( 329 copy.deepcopy(orig_model), 330 auto_wrap_policy=ModuleWrapPolicy({UnitModule}), 331 use_orig_params=optim_in_backward, 332 ignored_modules=[dist_model.l], 333 ) 334 if optim_in_backward: 335 _apply_optimizer_in_backward( 336 optimizer_class, dist_model.parameters(), {"lr": 1e-3} 337 ) 338 dist_optim = [ 339 p._in_backward_optimizers[0] for p in dist_model.parameters() 340 ] 341 else: 342 dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3) 343 return orig_model, orig_optim, copy_optim, dist_model, dist_optim 344 345 self._test_save_load(init_model_optim, test_frozen) 346 347 @with_comms 348 @skip_if_lt_x_gpu(2) 349 def test_fsdp_ddp(self) -> None: 350 self.run_subtests( 351 { 352 "use_composable": [True, False], 353 "optimizer_class": [torch.optim.Adam, torch.optim.AdamW], 354 }, 355 self._test_fsdp_ddp, 356 ) 357 358 @with_comms 359 @skip_if_lt_x_gpu(2) 360 def test_frozen_parameters(self) -> None: 361 self.run_subtests( 362 { 363 "use_composable": [True], 364 "optimizer_class": [torch.optim.Adam, torch.optim.AdamW], 365 "test_frozen": [True], 366 }, 367 self._test_fsdp_ddp, 368 ) 369 370 # TODO: enable use_dtensor once 2D device_mesh support is fully landed. 371 """ 372 @with_comms 373 @skip_if_lt_x_gpu(2) 374 def test_use_dtensor(self) -> None: 375 self._test_fsdp_ddp(use_composable=False, use_dtensor=True) 376 """ 377 378 # TODO: enable the test after FSDP + apply_optimizer_in_backward works. 379 # Disable this test as it is broken after 380 # https://github.com/pytorch/pytorch/pull/108298. 381 """ 382 @with_comms 383 @skip_if_lt_x_gpu(2) 384 def test_apply_optimizer_in_backward(self) -> None: 385 self.run_subtests( 386 {"use_composable": [True, False]}, 387 self._test_fsdp_ddp, 388 optim_in_backward=True, 389 ) 390 """ 391 392 def _test_single_gpu(self, optimizer_class: Type[Optimizer]) -> None: 393 def init_model_optim(): 394 orig_model = CompositeParamModel(device=torch.device("cuda")) 395 orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3) 396 copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3) 397 model_copy = copy.deepcopy(orig_model) 398 optim_copy = optimizer_class(model_copy.parameters(), lr=1e-3) 399 return orig_model, orig_optim, copy_optim, model_copy, optim_copy 400 401 self._test_save_load(init_model_optim) 402 403 @with_comms 404 @skip_if_lt_x_gpu(1) 405 def test_single_gpu(self) -> None: 406 self.run_subtests( 407 {"optimizer_class": [torch.optim.Adam, torch.optim.AdamW]}, 408 self._test_single_gpu, 409 ) 410 411 @with_comms 412 @skip_if_lt_x_gpu(1) 413 def test_strict(self) -> None: 414 model = CompositeParamModel(device=torch.device("cuda")) 415 416 model_state_dict = get_model_state_dict(model) 417 key = next(iter(model_state_dict.keys())) 418 model_state_dict["abc"] = torch.zeros(10) 419 with self.assertRaisesRegex(RuntimeError, "Unexpected key"): 420 set_model_state_dict(model, model_state_dict=model_state_dict) 421 model_state_dict.pop(key) 422 incompatible_keys = set_model_state_dict( 423 model, 424 model_state_dict=model_state_dict, 425 options=StateDictOptions(strict=False), 426 ) 427 self.assertEqual(incompatible_keys.missing_keys, [key]) 428 self.assertEqual(incompatible_keys.unexpected_keys, ["abc"]) 429 model_state_dict.pop("abc") 430 with self.assertRaisesRegex(RuntimeError, "Missing key"): 431 set_model_state_dict(model, model_state_dict=model_state_dict) 432 433 def _test_cpu_offload_full_state_dict( 434 self, optimizer_class: Type[Optimizer] 435 ) -> None: 436 orig_model = CompositeParamModel(device=torch.device("cuda")) 437 device_mesh = init_device_mesh("cuda", (self.world_size,)) 438 dist_model = FSDP( 439 copy.deepcopy(orig_model), 440 auto_wrap_policy=ModuleWrapPolicy({UnitModule}), 441 use_orig_params=True, 442 device_mesh=device_mesh, 443 ) 444 445 dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3) 446 447 mst, ost = get_state_dict( 448 dist_model, 449 dist_optim, 450 options=StateDictOptions(cpu_offload=True), 451 ) 452 453 cpu_device = torch.device("cpu") 454 455 def is_cpu(v): 456 if isinstance(v, DTensor): 457 return v.device == cpu_device 458 elif isinstance(v, ShardedTensor): 459 shards = v.local_shards() 460 if not shards: 461 return True 462 return shards[0].tensor.device == cpu_device 463 else: 464 return v.device == cpu_device 465 466 self.assertTrue( 467 tree_all_only((torch.Tensor, DTensor, ShardedTensor), is_cpu, mst) 468 ) 469 self.assertTrue( 470 tree_all_only((torch.Tensor, DTensor, ShardedTensor), is_cpu, ost) 471 ) 472 473 mst, ost = get_state_dict( 474 dist_model, dist_optim, options=StateDictOptions(full_state_dict=True) 475 ) 476 477 self.assertTrue( 478 tree_all(lambda v: not isinstance(v, (DTensor, ShardedTensor)), mst) 479 ) 480 self.assertTrue( 481 tree_all(lambda v: not isinstance(v, (DTensor, ShardedTensor)), ost) 482 ) 483 484 mst, ost = get_state_dict( 485 dist_model, 486 dist_optim, 487 options=StateDictOptions(full_state_dict=True, cpu_offload=True), 488 ) 489 490 if self.rank == 0: 491 self.assertTrue( 492 tree_all_only((torch.Tensor, DTensor, ShardedTensor), is_cpu, mst) 493 ) 494 self.assertTrue( 495 tree_all_only((torch.Tensor, DTensor, ShardedTensor), is_cpu, ost) 496 ) 497 else: 498 self.assertEqual(mst, {}) 499 self.assertEqual(ost, {}) 500 501 @with_comms 502 @skip_if_lt_x_gpu(2) 503 def test_cpu_offload_full_state_dict(self) -> None: 504 self.run_subtests( 505 {"optimizer_class": [torch.optim.Adam, torch.optim.AdamW]}, 506 self._test_cpu_offload_full_state_dict, 507 ) 508 509 @with_comms 510 @skip_if_lt_x_gpu(1) 511 def test_activation_ckpt_fqns_ddp(self) -> None: 512 """Tests that activation checkpointing prefixes are removed from module names""" 513 model = CompositeParamModel(device=torch.device("cuda")) 514 original_keys = get_model_state_dict(model).keys() 515 516 apply_activation_checkpointing(model) 517 model = DDP(model) 518 new_keys = get_model_state_dict(model).keys() 519 520 self.assertEqual(original_keys, new_keys) 521 522 @with_comms 523 @skip_if_lt_x_gpu(1) 524 def test_activation_ckpt_fqns_fsdp1(self) -> None: 525 self.run_subtests( 526 {"use_orig_params": [True, False]}, 527 self._test_activation_ckpt_fqns_fsdp1, 528 ) 529 530 def _test_activation_ckpt_fqns_fsdp1(self, use_orig_params: bool) -> None: 531 """Tests that activation checkpointing prefixes are removed from module names""" 532 model = CompositeParamModel(device=torch.device("cuda")) 533 original_keys = get_model_state_dict(model).keys() 534 535 apply_activation_checkpointing(model) 536 model = FSDP(model, use_orig_params=use_orig_params) 537 new_keys = get_model_state_dict(model).keys() 538 539 self.assertEqual(original_keys, new_keys) 540 541 @with_comms 542 @skip_if_lt_x_gpu(1) 543 def test_extra_state(self) -> None: 544 model = CompositeParamModel(device=torch.device("cuda")) 545 546 def get_extra_state(self): 547 return "MyState" 548 549 def set_extra_state(self, state): 550 return 551 552 UnitModule.get_extra_state = get_extra_state 553 UnitModule.set_extra_state = set_extra_state 554 555 ddp_model = DDP(copy.deepcopy(model)) 556 set_model_state_dict(ddp_model, get_model_state_dict(ddp_model)) 557 self.assertEqual(model.state_dict()["u1._extra_state"], "MyState") 558 self.assertEqual(model.state_dict(), get_model_state_dict(ddp_model)) 559 560 @with_comms 561 @skip_if_lt_x_gpu(1) 562 def test_non_persistent_buffers(self) -> None: 563 model = CompositeParamModel(device=torch.device("cuda")) 564 model.register_buffer( 565 "dont_save_me", torch.rand(100, device="cuda"), persistent=False 566 ) 567 ddp_model = DDP(copy.deepcopy(model)) 568 set_model_state_dict(ddp_model, get_model_state_dict(ddp_model)) 569 self.assertEqual(model.state_dict(), get_model_state_dict(ddp_model)) 570 571 def _test_broadcast_from_rank0(self, wrapper) -> None: 572 model = CompositeParamModel(device=torch.device("cuda")) 573 optim = torch.optim.Adam(model.parameters()) 574 fsdp_model = wrapper(copy.deepcopy(model)) 575 fsdp_optim = torch.optim.Adam(fsdp_model.parameters()) 576 577 batch = torch.rand(8, 100, device="cuda") 578 model(batch).sum().backward() 579 optim.step() 580 states, optim_states = get_state_dict(model, optim) 581 582 fsdp_model(batch).sum().backward() 583 fsdp_optim.step() 584 585 def check(equal): 586 fsdp_states = get_model_state_dict( 587 fsdp_model, 588 options=StateDictOptions(full_state_dict=True), 589 ) 590 fsdp_optim_states = get_optimizer_state_dict( 591 fsdp_model, 592 fsdp_optim, 593 options=StateDictOptions(full_state_dict=True), 594 ) 595 if equal: 596 self.assertEqual(states, fsdp_states) 597 self.assertEqual(optim_states, fsdp_optim_states) 598 else: 599 self.assertNotEqual(states, fsdp_states) 600 self.assertNotEqual(optim_states, fsdp_optim_states) 601 602 check(equal=True) 603 fsdp_model(batch).sum().backward() 604 fsdp_optim.step() 605 check(equal=False) 606 607 # Drop the states to simulate loading from rank0 608 if dist.get_rank() > 0: 609 load_states = {} 610 load_states2 = {} 611 load_optim_states = {} 612 else: 613 load_states = copy.deepcopy(states) 614 load_states2 = copy.deepcopy(states) 615 load_optim_states = copy.deepcopy(optim_states) 616 617 set_model_state_dict( 618 fsdp_model, 619 model_state_dict=load_states, 620 options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True), 621 ) 622 set_optimizer_state_dict( 623 fsdp_model, 624 fsdp_optim, 625 optim_state_dict=load_optim_states, 626 options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True), 627 ) 628 629 check(equal=True) 630 # Verify the `strict` flag. 631 load_states = load_states2 632 if load_states: 633 key = next(iter(load_states.keys())) 634 load_states.pop(key) 635 with self.assertRaisesRegex(RuntimeError, "Missing key"): 636 set_model_state_dict( 637 fsdp_model, 638 model_state_dict=load_states, 639 options=StateDictOptions( 640 broadcast_from_rank0=True, full_state_dict=True 641 ), 642 ) 643 644 @with_comms 645 @skip_if_lt_x_gpu(2) 646 def test_broadcast_from_rank0(self) -> None: 647 device_mesh = init_device_mesh("cuda", (self.world_size,)) 648 self.run_subtests( 649 { 650 "wrapper": [ 651 functools.partial(FSDP2, mesh=device_mesh), 652 functools.partial(FSDP, device_mesh=device_mesh), 653 ] 654 }, 655 self._test_broadcast_from_rank0, 656 ) 657 658 @with_comms 659 @skip_if_lt_x_gpu(4) 660 def test_broadcast_from_rank0_hsdp(self) -> None: 661 device_mesh = init_device_mesh("cuda", (2, self.world_size // 2)) 662 self.run_subtests( 663 { 664 "wrapper": [ 665 functools.partial( 666 FSDP, 667 device_mesh=device_mesh, 668 sharding_strategy=ShardingStrategy.HYBRID_SHARD, 669 ), 670 ] 671 }, 672 self._test_broadcast_from_rank0, 673 ) 674 675 @with_comms 676 @skip_if_lt_x_gpu(2) 677 def test_fsdp_root_not_initialized(self) -> None: 678 # This test verifies that FSDP root is not initialized but we should 679 # still be able to get the state_dict without errors because 680 # fsdp_model.state_dict() will trigger the FSDP initialization. 681 device_mesh = init_device_mesh("cuda", (self.world_size,)) 682 model = CompositeParamModel(device=torch.device("cuda")) 683 fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh) 684 fsdp_optim = torch.optim.Adam(fsdp_model.parameters()) 685 get_model_state_dict(fsdp_model) 686 get_optimizer_state_dict(fsdp_model, fsdp_optim) 687 688 @with_comms 689 @skip_if_lt_x_gpu(2) 690 def test_optim_state_dict_param_matching(self) -> None: 691 # This test verifies parameters between optim and optim_state_dict 692 # "initial_lr" is added to optim_state_dict, but not to the new optim 693 # We test whether "initial_lr" appear in optim after 694 # set_optimizer_state_dict. 695 device = "cuda" 696 torch.manual_seed(0) 697 model = nn.Sequential( 698 *[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)] 699 ) 700 for layer in model: 701 fully_shard(layer) 702 fully_shard(model) 703 optim = torch.optim.Adam(model.parameters(), lr=1e-2) 704 torch.optim.lr_scheduler.LambdaLR( 705 optim, lr_lambda=[lambda epoch: 0.95**epoch] 706 ) 707 opt_state_dict = ptd_state_dict.get_optimizer_state_dict( 708 model, 709 optim, 710 options=ptd_state_dict.StateDictOptions( 711 full_state_dict=True, cpu_offload=True 712 ), 713 ) 714 if dist.get_rank() == 0: 715 self.assertTrue("initial_lr" in opt_state_dict["param_groups"][0]) 716 717 optim = torch.optim.Adam(model.parameters(), lr=1e-2) 718 self.assertTrue("initial_lr" not in optim.param_groups[0]) 719 720 ptd_state_dict.set_optimizer_state_dict( 721 model, 722 optim, 723 optim_state_dict=opt_state_dict, 724 options=ptd_state_dict.StateDictOptions( 725 broadcast_from_rank0=True, full_state_dict=True 726 ), 727 ) 728 if dist.get_rank() == 0: 729 self.assertTrue("initial_lr" in optim.param_groups[0]) 730 731 @with_comms 732 @skip_if_lt_x_gpu(2) 733 def test_flattened_osd(self) -> None: 734 device_mesh = init_device_mesh("cuda", (self.world_size,)) 735 model = CompositeParamModel(device=torch.device("cuda")) 736 fsdp_model = FSDP2(copy.deepcopy(model), mesh=device_mesh) 737 fsdp_optim = torch.optim.AdamW(fsdp_model.parameters()) 738 batch = torch.rand(8, 100, device="cuda") 739 fsdp_model(batch).sum().backward() 740 fsdp_optim.step() 741 fsdp_optim.zero_grad() 742 osd1 = get_optimizer_state_dict(fsdp_model, fsdp_optim) 743 osd2 = get_optimizer_state_dict( 744 fsdp_model, 745 fsdp_optim, 746 options=StateDictOptions(flatten_optimizer_state_dict=True), 747 ) 748 fsdp_optim2 = torch.optim.AdamW(fsdp_model.parameters()) 749 set_optimizer_state_dict( 750 fsdp_model, optimizers=fsdp_optim2, optim_state_dict=osd2 751 ) 752 self.assertEqual(fsdp_optim.state_dict(), fsdp_optim2.state_dict()) 753 set_optimizer_state_dict( 754 fsdp_model, optimizers=fsdp_optim2, optim_state_dict=osd1 755 ) 756 self.assertEqual(fsdp_optim.state_dict(), fsdp_optim2.state_dict()) 757 758 @with_comms 759 @skip_if_lt_x_gpu(1) 760 def test_deprecate_partial(self) -> None: 761 model = CompositeParamModel(device=torch.device("cuda")) 762 763 model_state_dict1 = get_model_state_dict(model) 764 model_state_dict1 = copy.deepcopy(model_state_dict1) 765 with self.assertWarnsRegex( 766 FutureWarning, 767 "Getting submodules only model/optim state_dict is deprecated", 768 ): 769 model_state_dict2 = get_model_state_dict(model, submodules={model.l}) 770 model_state_dict2 = copy.deepcopy(model_state_dict2) 771 with self.assertWarnsRegex( 772 FutureWarning, 773 "Getting submodules only model/optim state_dict is deprecated", 774 ): 775 model_state_dict3 = get_model_state_dict( 776 model, 777 submodules={model.l}, 778 options=StateDictOptions(keep_submodule_prefixes=False), 779 ) 780 model_state_dict3 = copy.deepcopy(model_state_dict3) 781 self.assertEqual(len(model_state_dict2), 2) 782 self.assertEqual(len(model_state_dict3), 2) 783 for key in model_state_dict3.keys(): 784 full_fqn = f"l.{key}" 785 value1 = model_state_dict1[full_fqn] 786 value2 = model_state_dict2[full_fqn] 787 value3 = model_state_dict3[key] 788 self.assertEqual(value1, value2) 789 self.assertEqual(value2, value3) 790 791 zeros_state_dict = { 792 k: torch.zeros_like(v) for k, v in model_state_dict1.items() 793 } 794 model.load_state_dict(zeros_state_dict) 795 set_model_state_dict( 796 model, 797 model_state_dict=model_state_dict2, 798 options=StateDictOptions(strict=False), 799 ) 800 self.assertEqual(model.l.weight, model_state_dict1["l.weight"]) 801 self.assertEqual(model.l.bias, model_state_dict1["l.bias"]) 802 803 model.load_state_dict(zeros_state_dict) 804 with self.assertWarnsRegex(FutureWarning, "Passing model_state_dict as a "): 805 set_model_state_dict( 806 model, 807 model_state_dict={model.l: model_state_dict3}, 808 options=StateDictOptions(strict=False), 809 ) 810 self.assertEqual(model.l.weight, model_state_dict1["l.weight"]) 811 self.assertEqual(model.l.bias, model_state_dict1["l.bias"]) 812 813 @with_comms 814 @skip_if_lt_x_gpu(1) 815 def test_deprecate_fsdp_api(self) -> None: 816 device_mesh = init_device_mesh("cuda", (self.world_size,)) 817 model = CompositeParamModel(device=torch.device("cuda")) 818 fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh) 819 with self.assertWarnsRegex( 820 FutureWarning, 821 r"FSDP.state_dict_type\(\) and FSDP.set_state_dict_type\(\) are being deprecated", 822 ): 823 with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT): 824 fsdp_model.state_dict() 825 826 with self.assertRaisesRegex(AssertionError, "FutureWarning not triggered"): 827 with self.assertWarnsRegex( 828 FutureWarning, 829 r"FSDP.state_dict_type\(\) and FSDP.set_state_dict_type\(\) are being deprecated", 830 ): 831 get_model_state_dict(model) 832 833 @with_comms 834 @skip_if_lt_x_gpu(2) 835 def test_shared_weight(self): 836 class TiedEmbeddingModel(nn.Module): 837 def __init__(self, vocab_size, embedding_dim): 838 super().__init__() 839 self.embedding = nn.Embedding(vocab_size, embedding_dim) 840 self.decoder = nn.Linear(embedding_dim, vocab_size) 841 self.decoder.weight = self.embedding.weight # Tying weights 842 843 def forward(self, input): 844 input = (input * 10).to(torch.int) 845 embedded = self.embedding(input) 846 output = self.decoder(embedded) 847 return output 848 849 def init_model_optim(): 850 device_mesh = init_device_mesh("cuda", (self.world_size,)) 851 orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda")) 852 orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3) 853 copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3) 854 dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh) 855 dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-3) 856 return orig_model, orig_optim, copy_optim, dist_model, dist_optim 857 858 self._test_save_load(init_model_optim) 859 860 861class TestNoComm(MultiProcessTestCase): 862 def setUp(self) -> None: 863 super().setUp() 864 self._spawn_processes() 865 866 @skip_if_lt_x_gpu(1) 867 def test_no_dist(self) -> None: 868 model = CompositeParamModel(device=torch.device("cuda")) 869 optim = torch.optim.AdamW(model.parameters(), lr=1e-3) 870 871 self.assertFalse(dist.is_initialized()) 872 msd = get_model_state_dict( 873 model, options=StateDictOptions(full_state_dict=True, cpu_offload=True) 874 ) 875 for v in msd.values(): 876 self.assertFalse(v.is_cuda) 877 self.assertEqual(model.state_dict(), msd) 878 set_model_state_dict(model, model.state_dict()) 879 osd = get_optimizer_state_dict( 880 model, 881 optim, 882 options=StateDictOptions(full_state_dict=True, cpu_offload=True), 883 ) 884 set_optimizer_state_dict(model, optim, osd) 885 set_optimizer_state_dict(model, optim, optim.state_dict()) 886 887 888if __name__ == "__main__": 889 run_tests() 890