# Owner(s): ["oncall: distributed"]

import io
import itertools
import sys
from contextlib import nullcontext
from copy import deepcopy
from functools import partial
from typing import Any, Dict

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed._shard.sharded_tensor import (
    init_from_local_shards,
    Shard,
    ShardedTensor,
)
from torch.distributed._state_dict_utils import (
    _all_gather_sharded_tensor,
    _gather_state_dict,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.fsdp import (
    CPUOffload,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    LocalStateDictConfig,
    MixedPrecision,
    ShardedStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp._common_utils import FSDP_PREFIX
from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM
from torch.distributed.fsdp.wrap import enable_wrap, ModuleWrapPolicy, wrap
from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    _assert_module_states,
    _broadcast_state_dict,
    _get_state_dict,
    _zero_model,
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    get_full_params,
    SkipModel,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

INNER_SHAPE = [4, 4]
OUTER_SHAPE = [4, 5]
BUFFER_SHAPE = [5, 5]

NON_ROOT_FSDP_PREFIX = "non_fsdp_lin"

_UNFLATTENED_STATE_DICT_IMPLS = ["state_dict", "sharded_state_dict"]
_FLATTENED_STATE_DICT_IMPLS = ["local_state_dict"]
_SUPPORTED_STATE_DICT_IMPLS = (
    _UNFLATTENED_STATE_DICT_IMPLS + _FLATTENED_STATE_DICT_IMPLS
)

STATE_DICT_MAPPING = {
    "state_dict": StateDictType.FULL_STATE_DICT,
    "local_state_dict": StateDictType.LOCAL_STATE_DICT,
    "sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
}


class Model(Module):
    def __init__(
        self,
        wrap_fsdp,
        register_buffers=False,
        ignore_inner=False,
        mixed_precision=False,
        process_group=None,
    ):
        super().__init__()
        self.inner = Linear(*INNER_SHAPE)
        if register_buffers:
            self.inner.buffer = nn.Buffer(torch.randn(BUFFER_SHAPE))
            self.inner.register_buffer(
                "non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False
            )
        if wrap_fsdp:
            self.inner = FSDP(
                self.inner,
                ignored_modules=([self.inner] if ignore_inner else []),
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16,
                    reduce_dtype=torch.float16,
                    buffer_dtype=torch.float16,
                )
                if mixed_precision
                else None,
                process_group=process_group,
            )
        self.outer = Linear(*OUTER_SHAPE)
        if register_buffers:
            self.outer.buffer = nn.Buffer(torch.randn(BUFFER_SHAPE))
            self.outer.register_buffer(
                "non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False
            )

    def forward(self, x):
        # Forward twice.
        i = self.inner(x)
        j = self.inner(x)
        return self.outer(i + j)


class TestDummyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
        self.net3 = self.net2
        self.random_parameter = nn.Parameter(torch.Tensor(10))
        self.shared_parameter = self.random_parameter

    def forward(self, x):
        return self.net3(self.net2(self.net1(x)))

    def get_input(self):
        return torch.rand(8, 8, device="cuda")


class TestFSDPStateDict(FSDPTest):
    @property
    def world_size(self):
        return min(torch.cuda.device_count(), 2)

    def _broadcast_state_dict(self, model, state_dict):
        # TODO (rohan-varma): remove model
        return _broadcast_state_dict(self.rank, state_dict)

    def _state_compare(self, model, model_new, assert_fn, state_generator="parameters"):
        state_base = list(getattr(model, state_generator)())
        state_new = list(getattr(model_new, state_generator)())
        # Regardless of `assert_fn`, the number of parameters should be the same
        self.assertEqual(len(state_base), len(state_new))
        assert_fn(state_base, state_new)

    def _compare_models(
        self, model, model_new, assert_fn, check_fp16=False, check_buffers=True
    ):
        assert assert_fn in (self.assertEqual, self.assertNotEqual)
        with FSDP.summon_full_params(model):
            with FSDP.summon_full_params(model_new):
                self._state_compare(model, model_new, assert_fn)
                if check_buffers:
                    has_buffers = any(
                        len(list(m.buffers())) for m in (model, model_new)
                    )
                    if has_buffers:
                        self._state_compare(
                            model, model_new, assert_fn, state_generator="buffers"
                        )
                if check_fp16:
                    for tensor in model_new.parameters():
                        self.assertEqual(tensor.dtype, torch.float16)

    def _get_simple_nested_model(
        self, *fsdp_args, wrap=True, checkpoint_wrap=False, **fsdp_kwargs
    ):
        if wrap:
            lin1 = nn.Linear(10, 10, bias=False).cuda()
            lin2 = nn.Linear(10, 10, bias=False).cuda()
            if checkpoint_wrap:
                lin1 = checkpoint_wrapper(lin1)
                lin2 = checkpoint_wrapper(lin2)
            seq = nn.Sequential(FSDP(lin1, *fsdp_args, **fsdp_kwargs), lin2)
            if checkpoint_wrap:
                seq = checkpoint_wrapper(seq)
            model = FSDP(seq, *fsdp_args, **fsdp_kwargs)
        else:
            model = nn.Sequential(
                nn.Linear(10, 10, bias=False).cuda(),
                nn.Linear(10, 10, bias=False).cuda(),
            )
        return model

    def _get_simple_model(self, *fsdp_args, checkpoint_wrap=False, **fsdp_kwargs):
        lin = nn.Linear(10, 10, bias=False).cuda()
        if checkpoint_wrap:
            lin = checkpoint_wrapper(lin)
        model = FSDP(lin, *fsdp_args, **fsdp_kwargs)
        return model

    def _get_multibuffer_nested_model(
        self, *fsdp_args, wrap=True, checkpoint_wrap=False, **fsdp_kwargs
    ):
        full_p = torch.float32
        lin_mp = fsdp_kwargs.pop("mixed_precision", None)
        bn_mp = (
            MixedPrecision(param_dtype=full_p, reduce_dtype=full_p, buffer_dtype=full_p)
            if lin_mp
            else None
        )
        if wrap:
            lin1 = nn.Linear(10, 10, bias=False).cuda()
            bn1 = nn.BatchNorm1d(10).cuda()
            lin2 = nn.Linear(10, 10, bias=False).cuda()
            if checkpoint_wrap:
                lin1 = checkpoint_wrapper(lin1)
                bn1 = checkpoint_wrapper(bn1)
                lin2 = checkpoint_wrapper(lin2)
            seq = nn.Sequential(
                FSDP(lin1, *fsdp_args, mixed_precision=lin_mp, **fsdp_kwargs),
                FSDP(bn1, *fsdp_args, mixed_precision=bn_mp, **fsdp_kwargs),
                lin2,
            )
            if checkpoint_wrap:
                seq = checkpoint_wrapper(seq)
            model = FSDP(seq, *fsdp_args, **fsdp_kwargs)
        else:
            model = nn.Sequential(
                nn.Linear(10, 10, bias=False).cuda(),
                nn.BatchNorm1d(10).cuda(),
                nn.Linear(10, 10, bias=False).cuda(),
            )
        return model

    def _get_non_fsdp_root_module(self, *fsdp_args, wrap=True, **fsdp_kwargs):
        class FSDPContainer(nn.Module):
            def __init__(self, fsdp_1, fsdp_2):
                super().__init__()
                self.non_fsdp_lin = nn.Linear(10, 10, bias=False).cuda()
                self.fsdp_1 = fsdp_1
                self.fsdp_2 = fsdp_2

            def forward(self, x):
                x = self.non_fsdp_lin(x)
                x = self.fsdp_1(x)
                x = self.fsdp_2(x)
                return x

        return FSDPContainer(
            self._get_simple_nested_model(*fsdp_args, wrap=wrap, **fsdp_kwargs),
            self._get_simple_nested_model(*fsdp_args, wrap=wrap, **fsdp_kwargs),
        )

    def _get_state_dict_mgr(
        self,
        model: nn.Module,
        state_dict_type: str,
        state_dict_rank0_and_offload: bool,
    ):
        _state_dict_type = STATE_DICT_MAPPING[state_dict_type]
        if state_dict_type == "state_dict":
            config = FullStateDictConfig(
                rank0_only=state_dict_rank0_and_offload,
                offload_to_cpu=state_dict_rank0_and_offload,
            )
        elif state_dict_type == "local_state_dict":
            config = LocalStateDictConfig(
                offload_to_cpu=state_dict_rank0_and_offload,
            )
        elif state_dict_type == "sharded_state_dict":
            config = ShardedStateDictConfig(
                offload_to_cpu=state_dict_rank0_and_offload,
            )
        else:
            raise ValueError("Unsupported state_dict_type")
        return FSDP.state_dict_type(model, _state_dict_type, config)

    def _validate_state_dict_contents(
        self, model, fsdp_state_dict, state_dict_rank0_and_offload, ignore_keys=None
    ):
        if state_dict_rank0_and_offload:
            if self.rank == 0:
                self.assertNotEqual(fsdp_state_dict, {})
                for key, tensor in fsdp_state_dict.items():
                    if ignore_keys and key in ignore_keys:
                        continue
                    self.assertEqual(
                        tensor.device,
                        torch.device("cpu"),
                        f"{key} is unexpectedly on device {tensor.device}",
                    )
            else:
                # For non-FSDP roots, the non FSDP portion can still have parameters on rank 0,
                # so bypass the check for now.
                if isinstance(model, FSDP):
                    self.assertEqual(
                        fsdp_state_dict,
                        {},
                        f"Expected empty state_dict but got {fsdp_state_dict} on rank {dist.get_rank()}",
                    )

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
    @parametrize(
        "checkpoint_wrap",
        ["source", "dest", "both", "source_after_wrap", "both_after_wrap"],
    )
    @parametrize("rank0_only_and_offload", [False, True])
    def test_fsdp_state_dict_with_activation_checkpoint(
        self, state_dict_type, checkpoint_wrap, rank0_only_and_offload
    ):
        """Tests saving the state dict, zeroing a target model's parameters, and
        loading the state dict, where the source and target models may have a
        checkpoint wrapper."""

        def apply_ac_to_linears(model) -> None:
            non_reentrant_wrapper = partial(
                checkpoint_wrapper,
                offload_to_cpu=False,
                checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            )
            apply_activation_checkpointing(
                model,
                checkpoint_wrapper_fn=non_reentrant_wrapper,
                check_fn=lambda submodule: isinstance(submodule, nn.Linear),
            )

        for model_call in [
            partial(self._get_simple_model),
            partial(self._get_simple_nested_model),
        ]:
            model = model_call(checkpoint_wrap=(checkpoint_wrap in ("source", "both")))
            if checkpoint_wrap in ("source_after_wrap", "both_after_wrap"):
                apply_ac_to_linears(model)
            with self._get_state_dict_mgr(
                model, state_dict_type, rank0_only_and_offload
            ):
                state_dict = _gather_state_dict(_get_state_dict(model, False, False))
            # Possibly wrap new model in activation checkpoint wrapper to test save/
            # load with this wrapper
            model_new = model_call(
                checkpoint_wrap=(checkpoint_wrap in ("dest", "both"))
            )
            if checkpoint_wrap == "both_after_wrap":
                apply_ac_to_linears(model_new)
            _zero_model(model_new)
            self._compare_models(model, model_new, self.assertNotEqual)
            if rank0_only_and_offload:
                state_dict = self._broadcast_state_dict(model, state_dict)
            # Would fail if checkpoint_wrapper did not correctly implement state_dict pre/post hooks
            model_new.load_state_dict(state_dict, strict=True)
            self._compare_models(model, model_new, self.assertEqual)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
    @parametrize("rank0_only_and_offload", [False, True])
    def test_state_dict_with_manual_ac_wrapper(
        self,
        state_dict_type: str,
        rank0_only_and_offload: bool,
    ):
        """
        Tests saving and loading a state dict for a model manually wrapped with
        ``FSDP(CheckpointWrapper(module))``, where the ``CheckpointWrapper`` is
        wrapped before FSDP.

        TODO: Investigate why the test above does not cover everything in this
        test and de-duplicate afterwards.
384 """ 385 if state_dict_type == "sharded_state_dict" and rank0_only_and_offload: 386 return # not supported 387 model_ac = TransformerWithSharedParams.init( 388 self.process_group, 389 FSDPInitMode.NO_FSDP, 390 CUDAInitMode.CUDA_BEFORE, 391 ) 392 # Manually wrap FSDP without AC 393 model_no_ac = deepcopy(model_ac) 394 for i, layer in enumerate(model_no_ac.transformer.encoder.layers): 395 model_no_ac.transformer.encoder.layers[i] = FSDP(layer) 396 for i, layer in enumerate(model_no_ac.transformer.decoder.layers): 397 model_no_ac.transformer.decoder.layers[i] = FSDP(layer) 398 model_no_ac.transformer = FSDP(model_no_ac.transformer) 399 400 # Manually wrap FSDP with AC as `FSDP(CheckpointWrapper(module))` 401 for i, layer in enumerate(model_ac.transformer.encoder.layers): 402 layer = checkpoint_wrapper(layer) 403 model_ac.transformer.encoder.layers[i] = FSDP(layer) 404 for i, layer in enumerate(model_ac.transformer.decoder.layers): 405 layer = checkpoint_wrapper(layer) 406 model_ac.transformer.decoder.layers[i] = FSDP(layer) 407 model_ac.transformer = FSDP(model_ac.transformer) 408 409 # Save, load, and compare the two models 410 with self._get_state_dict_mgr( 411 model_no_ac, state_dict_type, rank0_only_and_offload 412 ): 413 state_dict_no_ac = model_no_ac.state_dict() 414 with self._get_state_dict_mgr( 415 model_ac, state_dict_type, rank0_only_and_offload 416 ): 417 state_dict_ac = model_ac.state_dict() 418 self.assertEqual(state_dict_ac.keys(), state_dict_no_ac.keys()) 419 if rank0_only_and_offload: 420 state_dict_no_ac = self._broadcast_state_dict(model_no_ac, state_dict_no_ac) 421 state_dict_ac = self._broadcast_state_dict(model_ac, state_dict_ac) 422 with self._get_state_dict_mgr( 423 model_no_ac, state_dict_type, rank0_only_and_offload 424 ): 425 model_no_ac.load_state_dict(state_dict_no_ac) 426 with self._get_state_dict_mgr( 427 model_ac, state_dict_type, rank0_only_and_offload 428 ): 429 model_ac.load_state_dict(state_dict_ac) 430 self._compare_models(model_ac, model_no_ac, self.assertEqual) 431 432 @skip_if_lt_x_gpu(2) 433 @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS) 434 def test_state_dict_with_shared_parameters(self, state_dict_type): 435 auto_wrap_policy = ModuleWrapPolicy( 436 {TransformerEncoderLayer, TransformerDecoderLayer} 437 ) 438 model_creator = partial( 439 TransformerWithSharedParams.init, 440 self.process_group, 441 FSDPInitMode.RECURSIVE, 442 CUDAInitMode.CUDA_BEFORE, 443 {"auto_wrap_policy": auto_wrap_policy}, 444 ) 445 446 fsdp_model = model_creator() 447 with self._get_state_dict_mgr(fsdp_model, state_dict_type, False): 448 state_dict = fsdp_model.state_dict() 449 450 new_model = model_creator() 451 _zero_model(new_model, zero_buffers=True) 452 with self._get_state_dict_mgr(new_model, state_dict_type, False): 453 new_model.load_state_dict(state_dict) 454 455 @skip_if_lt_x_gpu(2) 456 @parametrize("use_orig_params", [False, True]) 457 def test_state_dict_rank0_offload_save_load_flow(self, use_orig_params: bool): 458 """Tests saving a model checkpoint only on rank 0 and loading it only 459 on rank 0 with ``sync_module_states=True`` to emulate the workflow to 460 avoid redundant CPU memory usage.""" 461 auto_wrap_policy = ModuleWrapPolicy( 462 {TransformerEncoderLayer, TransformerDecoderLayer} 463 ) 464 fsdp_kwargs = { 465 "auto_wrap_policy": auto_wrap_policy, 466 "use_orig_params": use_orig_params, 467 } 468 fsdp_model = TransformerWithSharedParams.init( 469 self.process_group, 470 FSDPInitMode.RECURSIVE, 471 CUDAInitMode.CUDA_BEFORE, 472 
            fsdp_kwargs,
        )
        # Force model parameters and buffers to be nonzero
        with FSDP.summon_full_params(fsdp_model):
            for tensor in itertools.chain(
                fsdp_model.parameters(), fsdp_model.buffers()
            ):
                if torch.count_nonzero(tensor) == 0:
                    with torch.no_grad():
                        tensor.add_(torch.ones_like(tensor))
        with self._get_state_dict_mgr(fsdp_model, "state_dict", True):
            state_dict = deepcopy(_get_state_dict(fsdp_model))
        # Initialize a non-wrapped model on all ranks
        new_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
        )
        _zero_model(new_model, zero_buffers=True)
        # Only load the checkpoint on rank 0
        if self.rank == 0:
            new_model.load_state_dict(state_dict, strict=True)
        _assert_module_states(
            new_model,
            process_group=self.process_group,
            assert_fn=self.assertNotEqual,
        )
        # Broadcast the module states from rank 0 with `sync_module_states=True`
        new_fsdp_model = FSDP(
            new_model,
            device_id=torch.cuda.current_device(),
            auto_wrap_policy=auto_wrap_policy,
            sync_module_states=True,
        )
        # Check FSDP models are equal across ranks
        with FSDP.summon_full_params(new_fsdp_model):
            _assert_module_states(
                new_fsdp_model,
                process_group=self.process_group,
                assert_fn=self.assertEqual,
            )
        # Check FSDP models correctly loaded the checkpoint
        with FSDP.summon_full_params(fsdp_model):
            with FSDP.summon_full_params(new_fsdp_model):
                params = list(fsdp_model.parameters())
                params_new = list(new_fsdp_model.parameters())
                self.assertEqual(params, params_new)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("fp16", [True, False])
    @parametrize("state_dict_rank0_and_offload", [True, False])
    @parametrize("use_orig_params", [True, False])
    def test_basic_save_and_load_state_dict(
        self,
        state_dict_type: str,
        cpu_offload: bool,
        fp16: bool,
        state_dict_rank0_and_offload: bool,
        use_orig_params: bool,
    ):
        """
        Tests that we can save a state_dict and load it into a blank model
        with various configs such as fp16 and cpu offload and parameters
        match as expected.
541 """ 542 if (state_dict_rank0_and_offload and state_dict_type != "state_dict") or ( 543 use_orig_params and state_dict_type not in _UNFLATTENED_STATE_DICT_IMPLS 544 ): 545 return # not supported 546 device = torch.device(self.rank) 547 for model_call in [ 548 partial( 549 self._get_non_fsdp_root_module, 550 cpu_offload=cpu_offload, 551 use_orig_params=use_orig_params, 552 ), 553 partial( 554 self._get_simple_nested_model, 555 cpu_offload=cpu_offload, 556 use_orig_params=use_orig_params, 557 ), 558 partial( 559 self._get_simple_model, 560 cpu_offload=cpu_offload, 561 use_orig_params=use_orig_params, 562 ), 563 ]: 564 model = model_call() 565 if fp16: 566 model.half() 567 # Run a forward/backward to compute gradients to test the case 568 # where there are gradients populated 569 inp = torch.randn((3, 10), device=device) 570 if fp16: 571 inp = inp.half() 572 model(inp).sum().backward() 573 574 ctx = self._get_state_dict_mgr( 575 model, state_dict_type, state_dict_rank0_and_offload 576 ) 577 with ctx: 578 fsdp_state_dict = _get_state_dict( 579 model, cpu_offload.offload_params, fp16 580 ) 581 582 ignore_keys = [ 583 k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k 584 ] 585 586 self._validate_state_dict_contents( 587 model, 588 fsdp_state_dict, 589 state_dict_rank0_and_offload, 590 ignore_keys=ignore_keys, 591 ) 592 if fp16: 593 # Verify fp16 is the type 594 for tensor in fsdp_state_dict.values(): 595 self.assertEqual(tensor.dtype, torch.float16) 596 597 model_new = model_call() 598 if not cpu_offload.offload_params: 599 model_new = model_new.cuda() 600 if fp16: 601 model_new.half() 602 # Run a forward/backward to compute gradients to test the case 603 # where there are gradients populated 604 inp = torch.randn((3, 10), device=device) 605 if fp16: 606 inp = inp.half() 607 model_new(inp).sum().backward() 608 609 # zero the model to ensure parameters are different. 610 _zero_model(model_new, zero_buffers=True) 611 self._compare_models(model, model_new, self.assertNotEqual) 612 613 # Verify parameters are the same in the new model. 614 if state_dict_rank0_and_offload: 615 fsdp_state_dict = self._broadcast_state_dict(model, fsdp_state_dict) 616 with FSDP.state_dict_type(model_new, STATE_DICT_MAPPING[state_dict_type]): 617 model_new.load_state_dict(fsdp_state_dict, strict=True) 618 619 self._compare_models(model, model_new, self.assertEqual, check_fp16=fp16) 620 621 @skip_if_lt_x_gpu(2) 622 @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS) 623 @parametrize( 624 "cpu_offload", 625 [CPUOffload(offload_params=True), CPUOffload(offload_params=False)], 626 ) 627 @parametrize("mixed_precision", [True, False]) 628 @parametrize("state_dict_rank0_and_offload", [True, False]) 629 @parametrize("use_orig_params", [True, False]) 630 def test_buffers_save_and_load_state_dict( 631 self, 632 state_dict_type: str, 633 cpu_offload: bool, 634 mixed_precision: bool, 635 state_dict_rank0_and_offload: bool, 636 use_orig_params: bool, 637 ): 638 """ 639 Tests that we can save a state_dict and load it for modules with persistent buffers, including 640 in the context of non-default mixed precision, different ``state_dict_type`` s and CPU offloading. 
641 """ 642 if (state_dict_rank0_and_offload and state_dict_type != "state_dict") or ( 643 use_orig_params and state_dict_type not in _UNFLATTENED_STATE_DICT_IMPLS 644 ): 645 return # not supported 646 mixed_precision = ( 647 MixedPrecision( 648 param_dtype=torch.float16, 649 reduce_dtype=torch.float16, 650 buffer_dtype=torch.float16, 651 ) 652 if mixed_precision 653 else None 654 ) 655 model_call = partial( 656 self._get_multibuffer_nested_model, 657 cpu_offload=cpu_offload, 658 use_orig_params=use_orig_params, 659 mixed_precision=mixed_precision, 660 ) 661 model = model_call() 662 ctx = self._get_state_dict_mgr( 663 model, state_dict_type, state_dict_rank0_and_offload 664 ) 665 with ctx: 666 fsdp_state_dict = _get_state_dict(model, cpu_offload.offload_params, False) 667 668 self._validate_state_dict_contents( 669 model, fsdp_state_dict, state_dict_rank0_and_offload 670 ) 671 672 model_new = model_call() 673 if not cpu_offload.offload_params: 674 model_new = model_new.cuda() 675 676 # zero the model to ensure parameters are different. 677 _zero_model(model_new, zero_buffers=True) 678 self._compare_models(model, model_new, self.assertNotEqual) 679 680 # Verify parameters are the same in the new model. 681 if state_dict_rank0_and_offload: 682 fsdp_state_dict = self._broadcast_state_dict(model, fsdp_state_dict) 683 with FSDP.state_dict_type(model_new, STATE_DICT_MAPPING[state_dict_type]): 684 model_new.load_state_dict(fsdp_state_dict, strict=True) 685 686 self._compare_models(model, model_new, self.assertEqual) 687 688 @skip_if_lt_x_gpu(2) 689 @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS) 690 @parametrize("mixed_precision", [True, False]) 691 @parametrize("state_dict_rank0_and_offload", [True, False]) 692 def test_save_and_load_after_forward_state_dict( 693 self, state_dict_type, mixed_precision, state_dict_rank0_and_offload 694 ): 695 """ 696 Test that saving after some training results in params being updated as 697 expected. 
698 """ 699 if state_dict_rank0_and_offload and state_dict_type != "state_dict": 700 return 701 torch.cuda.set_device(self.rank) 702 mixed_precision = ( 703 MixedPrecision( 704 param_dtype=torch.float16, 705 reduce_dtype=torch.float16, 706 buffer_dtype=torch.float16, 707 ) 708 if mixed_precision 709 else None 710 ) 711 model = self._get_simple_nested_model(mixed_precision=mixed_precision) 712 optim = torch.optim.SGD(model.parameters(), lr=0.1) 713 initial_params = get_full_params(model) 714 for _ in range(6): 715 inp = torch.randn(1, 10, device=torch.cuda.current_device()) 716 output = model(*inp) 717 loss = output.sum() 718 expected_dtype = torch.float32 if mixed_precision is None else torch.float16 719 self.assertEqual(expected_dtype, loss.dtype) 720 loss.backward() 721 optim.step() 722 723 trained_params = get_full_params(model) 724 # Ensure some training occurred 725 self.assertNotEqual(initial_params, trained_params) 726 # Save a copy of the state_dict 727 fsd_mgr = self._get_state_dict_mgr( 728 model, state_dict_type, state_dict_rank0_and_offload 729 ) 730 with fsd_mgr: 731 state_dict = model.state_dict() 732 if state_dict_type == "state_dict": 733 state_dict = {k: v.clone() for k, v in state_dict.items()} 734 else: 735 for sharded_tensor in state_dict.values(): 736 shard = sharded_tensor._local_shards[0] 737 shard.tensor = shard.tensor.clone().detach_() 738 self._validate_state_dict_contents( 739 model, state_dict, state_dict_rank0_and_offload 740 ) 741 _zero_model(model) 742 743 # Ensure checkpointed params have the full param dtype 744 for tensor in state_dict.values(): 745 self.assertEqual(tensor.dtype, torch.float32) 746 747 # Load state_dict into zeroed model 748 if state_dict_rank0_and_offload: 749 state_dict = self._broadcast_state_dict(model, state_dict) 750 751 with FSDP.state_dict_type(model, STATE_DICT_MAPPING[state_dict_type]): 752 model.load_state_dict(state_dict, strict=True) 753 loaded_params = get_full_params(model) 754 self.assertEqual(loaded_params, trained_params) 755 756 def _initialize_model( 757 self, 758 wrap_fsdp: bool, 759 wrap_ddp: bool = True, 760 register_buffers: bool = False, 761 ): 762 # keep everything deterministic for input data 763 torch.manual_seed(0) 764 765 model = Model(wrap_fsdp, register_buffers=register_buffers).cuda() 766 if wrap_fsdp: 767 model = FSDP(model) 768 elif wrap_ddp: 769 model = DistributedDataParallel(model, device_ids=[self.rank]) 770 return model 771 772 @staticmethod 773 def _state_dict(model: Module, state_dict_type: str): 774 try: 775 enum_val = STATE_DICT_MAPPING[state_dict_type] 776 except KeyError as e: 777 raise ValueError(f"No state_dict type for {state_dict_type}") from e 778 779 with FSDP.state_dict_type(model, enum_val): 780 return model.state_dict() 781 782 @staticmethod 783 def _load_state_dict( 784 model: Module, state_dict_type: str, state_dict: Dict[str, Any] 785 ): 786 try: 787 enum_val = STATE_DICT_MAPPING[state_dict_type] 788 except KeyError as e: 789 raise ValueError(f"No state_dict for {state_dict_type}") from e 790 791 with FSDP.state_dict_type(model, enum_val): 792 return model.load_state_dict(state_dict, strict=True) 793 794 def _dist_train( 795 self, wrap_fsdp: bool, state_dict_type: str = "", move_to_cpu: bool = False 796 ): 797 # TODO: Move this test to common_fsdp. 
        model = self._initialize_model(wrap_fsdp)
        optim = SGD(model.parameters(), lr=0.1)

        in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            blank_model = FSDP(Model(True).cuda())
            _zero_model(blank_model)
            state_dict = self._state_dict(model, state_dict_type)
            if move_to_cpu:
                for key in list(state_dict.keys()):
                    tensor = state_dict[key]
                    if isinstance(tensor, torch.Tensor):
                        state_dict[key] = tensor.cpu()
                    else:
                        shards = tensor.local_shards()
                        if shards:
                            shards[0].tensor = shards[0].tensor.cpu()

            self._load_state_dict(blank_model, state_dict_type, state_dict)
            return get_full_params(blank_model)
        else:
            return list(model.parameters())

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_state_dict_save_load_flow(self, state_dict_type):
        self.run_subtests(
            {"move_to_cpu": [True, False]},
            self._test_state_dict_save_load_flow,
            state_dict_type=state_dict_type,
        )

    def _test_state_dict_save_load_flow(self, state_dict_type, move_to_cpu):
        fsdp_params = self._dist_train(
            wrap_fsdp=True,
            state_dict_type=state_dict_type,
            move_to_cpu=move_to_cpu,
        )
        ddp_params = self._dist_train(wrap_fsdp=False)
        self.assertEqual(ddp_params, fsdp_params)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_fsdp_state_dict_keys(self, state_dict_type):
        state_dict = self._state_dict(self._initialize_model(True), state_dict_type)
        if state_dict_type == "local_state_dict":
            self.assertEqual({FLAT_PARAM, f"inner.{FLAT_PARAM}"}, state_dict.keys())
        elif state_dict_type in ("state_dict", "sharded_state_dict"):
            # Keys should match local model.
            local_model = self._initialize_model(wrap_fsdp=False, wrap_ddp=False)
            local_keys = local_model.state_dict().keys()
            self.assertEqual(state_dict.keys(), local_keys)
        else:
            raise NotImplementedError(f"No test for {state_dict_type}!")

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
    @parametrize("state_dict_rank0_and_offload", [True, False])
    @parametrize("fsdp_root", [True, False])
    def test_state_dict_load_into_local_module(
        self,
        state_dict_type,
        state_dict_rank0_and_offload,
        fsdp_root,
    ):
        """
        Tests that FSDP's state_dict can be loaded into a local model.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        if not fsdp_root:
            model = self._get_non_fsdp_root_module()
        else:
            model = self._initialize_model(wrap_fsdp=True, register_buffers=True)
        optim = SGD(model.parameters(), lr=0.1)
        if not fsdp_root:
            in_data = torch.randn(
                1, 10, requires_grad=True, device=torch.device("cuda")
            )
        else:
            in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        with FSDP.summon_full_params(model):
            fsdp_params = deepcopy(list(model.parameters()))

        # get FSDP state_dict. Note that by default we return full_state_dict.
        sd_mgr = self._get_state_dict_mgr(
            model, state_dict_type, state_dict_rank0_and_offload
        )
        with sd_mgr:
            fsdp_state_dict = model.state_dict()

        ignore_keys = [k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k]
        self._validate_state_dict_contents(
            model,
            fsdp_state_dict,
            state_dict_rank0_and_offload,
            ignore_keys=ignore_keys,
        )
        # Create zeroed local model
        if not fsdp_root:
            blank_local_model = self._get_non_fsdp_root_module(wrap=False)
        else:
            blank_local_model = self._initialize_model(
                wrap_fsdp=False, wrap_ddp=False, register_buffers=True
            )

        # Nothing should be FSDP
        for mod in blank_local_model.modules():
            self.assertFalse(isinstance(mod, FSDP))

        for param in blank_local_model.parameters():
            with torch.no_grad():
                param.zero_()

        fsdp_state_dict = _gather_state_dict(fsdp_state_dict)

        # Load fsdp's full state dict into the local and verify params are as
        # expected.
        if state_dict_rank0_and_offload:
            fsdp_state_dict = self._broadcast_state_dict(model, fsdp_state_dict)

        blank_local_model.load_state_dict(fsdp_state_dict, strict=True)
        local_params = list(blank_local_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize("double_nest", [True])
    def test_state_dict_skip_module(self, state_dict_type, double_nest):
        torch.cuda.set_device(self.rank)

        def _create_module(wrap_fsdp=True):
            LINEAR_SKIP = "linear_skip"
            ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else nullcontext()
            with ctx:
                module = SkipModel(double_nest=double_nest)
                # Full name of linear_skip param tensors in SkipModel, as would be
                # stored in checkpoint.
                linear_skip_tensor_names = [
                    k
                    for k in dict(module.named_parameters()).keys()
                    if LINEAR_SKIP in k
                ]
                # skip SkipModule
                linear_skip = getattr(module, LINEAR_SKIP)
                delattr(module, LINEAR_SKIP)
                # Wrap FSDP
                fsdp = wrap(module)
                # reattach
                setattr(module, LINEAR_SKIP, linear_skip)
                return fsdp, linear_skip_tensor_names

        fsdp, linear_skip_tensor_names = _create_module()
        # Run a forward pass
        inp = torch.randn((1, 10), device=torch.cuda.current_device())
        loss = fsdp(inp)
        loss.sum().backward()

        with FSDP.state_dict_type(fsdp, STATE_DICT_MAPPING[state_dict_type]):
            state_dict = fsdp.state_dict()
        if self.rank == 0 and state_dict_type != "local_state_dict":
            sd_keys = list(state_dict.keys())
            expected = list(SkipModel(double_nest=False).state_dict().keys())
            self.assertEqual(sorted(sd_keys), sorted(expected))
            # TODO: parameters in linear_skip_tensor_names should not be handled
            # by FSDP.state_dict(). Have a check once this is implemented in
            # FSDP.state_dict().

        # Check that it can be loaded into FSDP.
        new_fsdp, _ = _create_module()
        _zero_model(new_fsdp)
        for p1, p2 in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertNotEqual(p1, p2)
        with FSDP.state_dict_type(new_fsdp, STATE_DICT_MAPPING[state_dict_type]):
            if state_dict_type != "local_state_dict":
                # FlatParameter has not supported deepcopy yet.
                state_dict = deepcopy(state_dict)
            new_fsdp.load_state_dict(state_dict, strict=True)
        for p1, p2 in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertEqual(p1, p2)

        # Test that the checkpoint can be loaded into a local model.
        local, _ = _create_module(wrap_fsdp=False)
        for param in local.parameters():
            with torch.no_grad():
                param.zero_()

        with fsdp.summon_full_params(fsdp):
            for p1, p2 in zip(fsdp.parameters(), local.parameters()):
                self.assertNotEqual(p1, p2)

        if state_dict_type == "local_state_dict":
            return
        state_dict = _gather_state_dict(state_dict)
        with fsdp.summon_full_params(fsdp):
            if self.rank == 0:
                local.load_state_dict(state_dict, strict=True)
                for p1, p2 in zip(fsdp.parameters(), local.parameters()):
                    self.assertEqual(p1, p2)

    @skip_if_lt_x_gpu(2)
    def test_wrong_state_dict_config(self):
        model = FSDP(Model(wrap_fsdp=True).cuda())
        with self.assertRaisesRegex(RuntimeError, "Expected state_dict_config of type"):
            with model.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, LocalStateDictConfig()
            ):
                pass

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
    @parametrize("prefix", [True, False])
    @parametrize("ignore_inner", [True, False])
    @parametrize("mixed_precision", [True, False])
    def test_state_dict_with_ignored_modules(
        self, state_dict_type, prefix, ignore_inner, mixed_precision
    ):
        # Initialize an FSDP-wrapped model with an ignored module that includes
        # both parameters and a buffer
        model = Model(
            wrap_fsdp=True,
            register_buffers=True,
            ignore_inner=ignore_inner,
            mixed_precision=mixed_precision,
        ).cuda()
        ignored_modules = [model.outer]
        ignored_tensor_to_tensor_name = {
            model.outer.bias: "outer.bias",
            model.outer.weight: "outer.weight",
        }
        if ignore_inner:
            ignored_tensor_to_tensor_name = {
                **ignored_tensor_to_tensor_name,
                model.inner.bias: "inner.bias",
                model.inner.weight: "inner.weight",
            }
        # Note that when model.inner is not ignored this test also ensures
        # non-ignored buffers are not cloned.
        buffer_to_buffer_name = {
            model.inner.buffer: "inner.buffer",
            model.outer.buffer: "outer.buffer",
        }
        # expect fp16 model.inner.buffer with mixed_precisions
        # expect fp32 sd.inner.buffer after restoring to original precision
        # so skip AssertEqual
        if mixed_precision and not ignore_inner:
            buffer_to_buffer_name.pop(model.inner.buffer)

        fsdp_model = FSDP(
            model,
            ignored_modules=ignored_modules,
            mixed_precision=MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float16,
            )
            if mixed_precision
            else None,
        )
        prefix_str = "foo." if prefix else ""
        with FSDP.state_dict_type(fsdp_model, STATE_DICT_MAPPING[state_dict_type]):
            sd1 = _gather_state_dict(fsdp_model.state_dict(prefix=prefix_str))
        with FSDP.summon_full_params(fsdp_model):
            fsdp_params = deepcopy(list(fsdp_model.parameters()))
        # Check that the ignored parameters and all buffers are not cloned
        for tensor, tensor_name in {
            **ignored_tensor_to_tensor_name,
            **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd1)
            self.assertEqual(
                tensor.data_ptr(),
                sd1[prefixed_tensor_name].data_ptr(),
                f"{prefixed_tensor_name}",
            )
        # should not apply mixed_precision to ignored buffers
        for buffer_name in buffer_to_buffer_name.values():
            prefixed_buffer_name = f"{prefix_str}{buffer_name}"
            self.assertTrue(prefixed_buffer_name in sd1)
            self.assertEqual(sd1[prefixed_buffer_name].dtype, torch.float32)
        # Check that the state dict can be loaded into a non-wrapped version of
        # the model
        nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
        for param in nonwrapped_model.parameters():
            with torch.no_grad():
                param.zero_()

        to_load = {k[len(prefix_str) :]: v for k, v in sd1.items()}
        nonwrapped_model.load_state_dict(to_load, strict=True)
        local_params = list(nonwrapped_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
        # Check that if we save a state dict again, the ignored parameters and
        # buffer still have the same data pointer
        with FSDP.state_dict_type(fsdp_model, STATE_DICT_MAPPING[state_dict_type]):
            sd2 = fsdp_model.state_dict(prefix=prefix_str)
        for tensor, tensor_name in {
            **ignored_tensor_to_tensor_name,
            **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd2)
            self.assertEqual(tensor.data_ptr(), sd2[prefixed_tensor_name].data_ptr())
            self.assertEqual(
                sd1[prefixed_tensor_name].data_ptr(),
                sd2[prefixed_tensor_name].data_ptr(),
            )

    @skip_if_lt_x_gpu(2)
    def test_state_dict_type(self):
        module = SkipModel(double_nest=True)
        with enable_wrap(wrapper_cls=FSDP):
            fsdp = wrap(module)
        with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
            pass
        for module in FSDP.fsdp_modules(fsdp):
            self.assertEqual(module._state_dict_type, StateDictType.FULL_STATE_DICT)

    @skip_if_lt_x_gpu(2)
    def test_local_state_dict_with_empty_ranks(self):
        class Model(Module):
            def __init__(self) -> None:
                super().__init__()
                self.my_tensor = torch.full((1,), 3.1415926)
                self.my_parameter = nn.Parameter(self.my_tensor)

            def forward(self, x):
                return self.my_parameter

        model = FSDP(Model().cuda())
        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
            out = model(None)
            out.backward()

            state_dict = deepcopy(model.state_dict())
            with torch.no_grad():
                with FSDP.summon_full_params(model):
                    self.assertEqual(model.my_parameter.item(), 3.1415926)
                    model.my_parameter.copy_(torch.full((1,), 1.75).cuda())
                    self.assertEqual(model.my_parameter.item(), 1.75)
            model.load_state_dict(state_dict)
            with FSDP.summon_full_params(model):
                self.assertEqual(model.my_parameter.item(), 3.1415926)

    @skip_if_lt_x_gpu(2)
    def test_torch_save_load(self):
        model = Model(wrap_fsdp=True).cuda()
        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
            state_dict = model.state_dict()
            checkpoint = io.BytesIO()
            torch.save(state_dict, checkpoint)
            checkpoint.seek(0)
            state_dict_saved = torch.load(checkpoint)
            for k, v in state_dict_saved.items():
                if isinstance(v, ShardedTensor):
                    self.assertEqual(
                        v._local_shards[0].tensor, state_dict[k]._local_shards[0].tensor
                    )
                else:
                    self.assertEqual(v, state_dict[k])

    @skip_if_lt_x_gpu(2)
    def test_shared_module_and_shared_parameter(self):
        model = FSDP(TestDummyModel().cuda())
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
            state_dict = model.state_dict()
            self.assertEqual(
                state_dict["random_parameter"], state_dict["shared_parameter"]
            )
            self.assertEqual(state_dict["net2.0.bias"], state_dict["net3.0.bias"])
            self.assertEqual(state_dict["net2.0.weight"], state_dict["net3.0.weight"])

    @skip_if_lt_x_gpu(2)
    def test_full_state_dict_missing_unexpected_keys_cleaned(self):
        model = self._get_simple_nested_model()
        sd = model.state_dict()
        # Create a missing key
        sd.pop(next(iter(sd.keys())))
        # Create an unexpected key
        sd["unexpected"] = torch.ones(1)
        missing, unexpected = model.load_state_dict(sd, strict=False)
        assert len(missing) == 1
        assert len(unexpected) == 1
        self.assertTrue(FSDP_PREFIX not in missing[0])
        self.assertTrue(FSDP_PREFIX not in unexpected[0])

    @skip_if_lt_x_gpu(2)
    def test_sharded_load_multi_backend_pg(self):
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer}
        )
        fsdp_kwargs = {
            "auto_wrap_policy": auto_wrap_policy,
            "use_orig_params": True,
        }
        for load_cpu in [True, False]:
            with self.subTest(load_cpu=load_cpu):
                pg = dist.new_group(backend="cpu:gloo,cuda:nccl")
                fsdp_model = TransformerWithSharedParams.init(
                    pg,
                    FSDPInitMode.RECURSIVE,
                    CUDAInitMode.CUDA_BEFORE,
                    fsdp_kwargs,
                )
                FSDP.set_state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT)
                sharded = fsdp_model.state_dict()
                param_copy = [t.clone().detach_() for t in fsdp_model.parameters()]
                with torch.no_grad():
                    for p in fsdp_model.parameters():
                        p.zero_()

                if load_cpu:
                    # Offload to CPU to simulate CPU state_dict load
                    for k, v in sharded.items():
                        sharded[k] = v.cpu()

                fsdp_model.load_state_dict(sharded)
                for p1, p2 in zip(param_copy, fsdp_model.parameters()):
                    self.assertEqual(p1, p2, f"not equal: {p1.sum()} vs {p2.sum()}")

    @skip_if_lt_x_gpu(2)
    def test_world_size_one(self):
        my_pg = None
        for i in range(self.world_size):
            pg = dist.new_group(ranks=[i])
            if i == self.rank:
                my_pg = pg

        model = TransformerWithSharedParams.init(
            my_pg,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
        )
        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
            state_dict = model.state_dict()
            model.load_state_dict(state_dict)

        dist.barrier()


class TestFSDPStateDict4GPUs(FSDPTest):
    @property
    def world_size(self):
        return torch.cuda.device_count()

    @skip_if_lt_x_gpu(4)
    def test_local_state_dict_reshard(self):
        """
        This test demonstrates the ability to do resharding when using
        local_state_dict. Although we do not recommend users to use
        local_state_dict, there are still some corner cases where
        local_state_dict is a better solution.
        """
        model = FSDP(Model(wrap_fsdp=True)).cuda()
        optim = torch.optim.SGD(model.parameters(), lr=0.1)

        batch = torch.randn(4, 4, device=torch.cuda.current_device())
        output = model(batch)
        loss = output.sum()
        loss.backward()
        optim.step()
        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
            state_dict = model.state_dict()

        rank = dist.get_rank()
        new_pg = dist.new_group(ranks=[0, 1])
        resharded_state_dict = {}
        # Mimic resharding from 4 GPUs to 2 GPUs
        for key, value in state_dict.items():
            if isinstance(value, ShardedTensor):
                full_flat_param = _all_gather_sharded_tensor(value)
                if rank < 2:
                    full_numel = full_flat_param.size()
                    chunks = full_flat_param.chunk(2)
                    flat_param = chunks[rank]
                    shard_offset = 0 if rank == 0 else chunks[0].numel()
                    local_shards = [
                        Shard.from_tensor_and_offsets(flat_param, [shard_offset], rank)
                    ]
                    sharded_tensor = init_from_local_shards(
                        local_shards, full_numel, process_group=new_pg
                    )
                    resharded_state_dict[key] = sharded_tensor
            else:
                if rank < 2:
                    resharded_state_dict[key] = value

        if rank < 2:
            model2 = FSDP(
                Model(wrap_fsdp=True, process_group=new_pg), process_group=new_pg
            ).cuda()
            with FSDP.state_dict_type(model2, StateDictType.LOCAL_STATE_DICT):
                model2.load_state_dict(resharded_state_dict)

        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
            full_state_dict1 = model.state_dict()

        if rank < 2:
            with FSDP.state_dict_type(model2, StateDictType.FULL_STATE_DICT):
                full_state_dict2 = model2.state_dict()
            self.assertEqual(full_state_dict1, full_state_dict2)


instantiate_parametrized_tests(TestFSDPStateDict)

if __name__ == "__main__":
    run_tests()