1# Owner(s): ["oncall: distributed"] 2 3import bisect 4import sys 5from copy import deepcopy 6from enum import auto, Enum 7from typing import Any, Callable, Dict, List, Optional, Tuple, Type 8 9import torch 10import torch.nn as nn 11from torch import distributed as dist 12from torch.distributed._shard.sharded_tensor import ShardedTensor 13from torch.distributed._state_dict_utils import _gather_state_dict 14from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 15 _CHECKPOINT_WRAPPED_MODULE, 16 apply_activation_checkpointing, 17) 18from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 19from torch.distributed.fsdp.api import ShardingStrategy 20from torch.distributed.fsdp.fully_sharded_data_parallel import ( 21 FullOptimStateDictConfig, 22 FullStateDictConfig, 23 OptimStateKeyType, 24 ShardedOptimStateDictConfig, 25 ShardedStateDictConfig, 26 StateDictSettings, 27 StateDictType, 28) 29from torch.distributed.optim import _NamedOptimizer 30from torch.testing._internal.common_distributed import skip_if_lt_x_gpu 31from torch.testing._internal.common_fsdp import ( 32 CUDAInitMode, 33 FSDPInitMode, 34 FSDPTest, 35 TransformerWithSharedParams, 36) 37from torch.testing._internal.common_utils import ( 38 instantiate_parametrized_tests, 39 parametrize, 40 run_tests, 41 TEST_WITH_DEV_DBG_ASAN, 42) 43 44 45STATE_DICT_TYPES = [StateDictType.FULL_STATE_DICT, StateDictType.SHARDED_STATE_DICT] 46 47if not dist.is_available(): 48 print("Distributed not available, skipping tests", file=sys.stderr) 49 sys.exit(0) 50 51if TEST_WITH_DEV_DBG_ASAN: 52 print( 53 "Skip dev-asan as torch + multiprocessing spawn have known issues", 54 file=sys.stderr, 55 ) 56 sys.exit(0) 57 58 59class _OSDCommMethod(Enum): 60 """Method for communicating the optimizer state dict for internal tests.""" 61 62 BROADCAST_OBJECT_LIST = auto() 63 SCATTER_FULL_OSD = auto() 64 FLATTEN_SHARDED_OSD = auto() 65 OPTIM_STATE_DICT = auto() 66 67 68class _ModelClass(Enum): 69 """Different model type to test.""" 70 71 NESTED = auto() 72 TRANSFORMER = auto() 73 74 75class Bias(torch.nn.Module): 76 """This module applies a 1D additive bias with dimension ``dim``.""" 77 78 def __init__(self, dim: int) -> None: 79 super().__init__() 80 assert dim > 0 81 torch.manual_seed(0) 82 self.bias = torch.nn.Parameter(torch.randn((dim,))) 83 84 def forward(self, x): 85 return x + self.bias 86 87 88class BlockA(torch.nn.Module): 89 """ 90 Used to define interesting nested structure for FSDP wrapping. 91 BlockA 92 Bias0 93 bias 94 weight 95 Bias1 96 bias 97 """ 98 99 def __init__(self, in_dim: int, out_dim: int) -> None: 100 super().__init__() 101 assert all(v > 0 for v in (in_dim, out_dim)) 102 torch.manual_seed(0) 103 self.bias_module0 = Bias(out_dim) 104 self.weight = torch.nn.Parameter(torch.randn((in_dim, out_dim))) 105 self.bias_module1 = Bias(out_dim) 106 self.relu = torch.nn.ReLU() 107 108 def forward(self, x): 109 x = x @ self.weight 110 x = self.bias_module0(x) 111 x = self.relu(x) # ensure biases have different gradients 112 x = self.bias_module1(x) 113 return x 114 115 116class BlockB(torch.nn.Module): 117 """ 118 Used to define interesting nested structure for FSDP wrapping. 119 BlockB 120 weight 121 Bias 122 bias 123 Bias 124 bias 125 """ 126 127 def __init__(self, in_dim: int, out_dim: int) -> None: 128 super().__init__() 129 assert all(v > 0 for v in (in_dim, out_dim)) 130 torch.manual_seed(0) 131 self.weight = torch.nn.Parameter(torch.randn((in_dim, out_dim))) 132 self.bias_module0 = Bias(out_dim) 133 self.bias_module1 = Bias(out_dim) 134 self.relu = torch.nn.ReLU() 135 136 def forward(self, x): 137 x = x @ self.weight 138 x = self.bias_module0(x) 139 x = self.relu(x) # ensure biases have different gradients 140 x = self.bias_module1(x) 141 return x 142 143 144class NestedModel(torch.nn.Module): 145 def __init__(self) -> None: 146 super().__init__() 147 self.block0 = BlockB(5, 3) 148 self.block1 = BlockB(3, 7) 149 self.bias = torch.nn.Parameter(torch.randn((5,))) 150 self.block2 = torch.nn.Sequential( 151 BlockA(7, 9), 152 BlockA(9, 9), 153 BlockB(9, 5), 154 ) 155 self.relu = torch.nn.ReLU() 156 157 def forward(self, x) -> torch.Tensor: 158 x = self.relu(self.block0(x)) 159 x = self.relu(self.block1(x)) 160 x = self.relu(self.block2(x)) 161 x = x + self.bias 162 return x 163 164 def get_input(self, device): 165 BATCH_SIZE = 8 166 return (torch.randn((BATCH_SIZE, 5)).to(device),) 167 168 def get_loss(self, inp, output): 169 return output.sum() 170 171 def run_backward(self, loss): 172 loss.backward() 173 174 @staticmethod 175 def wrap( 176 model: torch.nn.Module, 177 group: Optional[dist.ProcessGroup] = None, 178 ignore_modules: bool = False, 179 fsdp_kwargs: Optional[Dict[str, Any]] = None, 180 ) -> torch.nn.Module: 181 if fsdp_kwargs is None: 182 fsdp_kwargs = {} 183 # Flatten Bias0; then flatten weight and Bias1 together into `block1` 184 model.block1.bias_module0 = FSDP( 185 model.block1.bias_module0, 186 process_group=group, 187 **fsdp_kwargs, 188 ) 189 model.block1 = FSDP(model.block1, process_group=group, **fsdp_kwargs) 190 # Flatten Bias0; flatten Bias1; then flatten weight into `block2[1]` 191 model.block2[1].bias_module0 = FSDP( 192 model.block2[1].bias_module0, 193 process_group=group, 194 **fsdp_kwargs, 195 ) 196 model.block2[1].bias_module1 = FSDP( 197 model.block2[1].bias_module1, 198 process_group=group, 199 **fsdp_kwargs, 200 ) 201 model.block2[1] = FSDP(model.block2[1], process_group=group, **fsdp_kwargs) 202 # Flatten weight, Bias, bias into `block2[2]` 203 ignored_modules = [model.block2[2].bias_module0] if ignore_modules else None 204 model.block2[2] = FSDP( 205 model.block2[2], 206 process_group=group, 207 ignored_modules=ignored_modules, 208 **fsdp_kwargs, 209 ) 210 return model 211 212 @staticmethod 213 def wrap_alt( 214 model: torch.nn.Module, 215 group: Optional[dist.ProcessGroup] = None, 216 fsdp_kwargs: Optional[Dict[str, Any]] = None, 217 ) -> torch.nn.Module: 218 if fsdp_kwargs is None: 219 fsdp_kwargs = {} 220 model.block0.bias_module0 = FSDP( 221 model.block0.bias_module0, 222 process_group=group, 223 **fsdp_kwargs, 224 ) 225 model.block0 = FSDP(model.block0, process_group=group, **fsdp_kwargs) 226 return model 227 228 @staticmethod 229 def wrap_with_unmanaged_params( 230 model, 231 add_to_fsdp_module: bool, 232 group=None, 233 ) -> Tuple[torch.nn.Module, List[torch.nn.Parameter]]: 234 """Registers unmanaged parameters before wrapping with :meth:`wrap`.""" 235 device = next(model.parameters()).device 236 unmanaged_param = torch.nn.Parameter(torch.randn(5, 5, device=device)) 237 # Either register the parameter to a module to be wrapped with FSDP 238 # (`model.block2[2]`) or a module not to be wrapped with FSDP (`model`) 239 register_module = model.block2[2] if add_to_fsdp_module else model 240 register_module.register_parameter( 241 "unmanaged_param", 242 unmanaged_param, 243 ) 244 # For simplicity, we only add a single unmanaged parameter, but should 245 # be easy to generalize if needed 246 return NestedModel.wrap(model, group), [unmanaged_param] 247 248 @staticmethod 249 def add_unmanaged_param_entry(osd, unmanaged_param, step) -> None: 250 """Adds an entry for the unmanaged parameter ``unmanaged_param`` 251 assuming Adam optimizer and a single parameter group.""" 252 # The unmanaged parameters should be passed to this method in 253 # `model.parameters()` order since their parameter IDs will be assigned 254 # in order of the skipped IDs 255 # Assign a parameter ID to the unmanaged parameter 256 unmanaged_param_id = -1 257 param_ids = osd["param_groups"][0]["params"] 258 for i in range(1, len(param_ids)): 259 diff = param_ids[i] - param_ids[i - 1] 260 if diff != 1: 261 assert diff > 1, f"Invalid IDs: {param_ids[i - 1]} {param_ids[i]}" 262 unmanaged_param_id = param_ids[i - 1] + 1 263 break 264 if unmanaged_param_id == -1: 265 unmanaged_param_id = len(param_ids) # last ID skipped 266 assert unmanaged_param_id >= 0, "One parameter ID should be skipped" 267 # Add a state entry for the unmanaged parameter 268 state_device = next(iter(next(iter(osd["state"].values())).values())).device 269 osd["state"][unmanaged_param_id] = { 270 "step": torch.tensor(float(step), device=state_device), 271 "exp_avg": torch.randn(unmanaged_param.shape, device=state_device), 272 "exp_avg_sq": torch.randn(unmanaged_param.shape, device=state_device), 273 } 274 # Insert the ID into the parameter group in order 275 bisect.insort(osd["param_groups"][0]["params"], unmanaged_param_id) 276 277 # NOTE: We exclude `self.bias` from either parameter group to test the 278 # case where the optimizer input does not include all model parameters 279 def param_group0(self) -> List[torch.nn.Parameter]: 280 # Use `block1`'s parameters for the first parameter group to deviate 281 # from the `model.parameters()` order 282 return list(self.block1.parameters()) 283 284 def param_group1(self) -> List[torch.nn.Parameter]: 285 # Deviate from the `model.parameters()` order further by rearranging 286 # `block2`'s parameters to be before `block0`'s parameters 287 return list(self.block2.parameters()) + list(self.block0.parameters()) 288 289 290# Simple and boring model to test interface and some corner cases that do not 291# require complicated wrapping strategy. 292class TestDummyModel(torch.nn.Module): 293 def __init__(self, no_grad: bool = False): 294 super().__init__() 295 torch.manual_seed(0) 296 self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU()) 297 self.net1[0].weight.requires_grad = not no_grad 298 self.net1[0].bias.requires_grad = not no_grad 299 self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU()) 300 self.net3 = nn.Linear(32, 64) 301 self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8)) 302 303 def forward(self, x): 304 return self.net4(self.net3(self.net2(self.net1(x)))) 305 306 def get_input(self): 307 return torch.rand(8, 8, device="cuda") 308 309 310class TestFSDPOptimState(FSDPTest): 311 def __init__(self, *args, **kwargs): 312 super().__init__(*args, **kwargs) 313 self._model_class = { 314 _ModelClass.NESTED: self._init_nested_model, 315 _ModelClass.TRANSFORMER: self._init_transformer_model, 316 } 317 318 def _init_nested_model( 319 self, 320 wrap: bool, 321 wrap_alt: bool = False, # ignored if `wrap=False` 322 device: torch.device = torch.device("cuda"), 323 group=None, 324 optim_class: Type[torch.optim.Optimizer] = torch.optim.Adam, 325 use_multiple_param_groups: bool = False, 326 use_diff_optim_inputs: bool = False, 327 fsdp_kwargs: Optional[Dict[str, Any]] = None, 328 ): 329 model = NestedModel().to(device) 330 if wrap: 331 model = ( 332 NestedModel.wrap_alt(model, group, fsdp_kwargs) 333 if wrap_alt 334 else NestedModel.wrap(model, group, fsdp_kwargs=fsdp_kwargs) 335 ) 336 if not use_multiple_param_groups: 337 optim_input = list(model.parameters()) 338 else: 339 optim_input = [ 340 {"params": model.param_group0()}, 341 {"params": model.param_group1(), "weight_decay": 0.9}, 342 ] 343 # Use a reversed parameter order for the optimizer input on odd ranks 344 if use_diff_optim_inputs and self.rank % 2 == 1: 345 if isinstance(optim_input[0], dict): 346 for param_group in optim_input: 347 param_group["params"] = list(reversed(param_group["params"])) 348 else: 349 optim_input = list(reversed(optim_input)) 350 optim = optim_class(optim_input, lr=0.01) 351 return model, optim, optim_input 352 353 def _init_transformer_model( 354 self, 355 wrap: bool, 356 device: torch.device = torch.device("cuda"), 357 group=None, 358 optim_class: Type[torch.optim.Optimizer] = torch.optim.Adam, 359 use_multiple_param_groups: bool = False, 360 use_diff_optim_inputs: bool = False, 361 ): 362 if use_multiple_param_groups or use_diff_optim_inputs: 363 # Keep these as arguments for parity with `_init_nested_model()`; 364 # these settings are not implemented since the transformer is 365 # wrapped with FSDP at the top-level, which means that there is 366 # only a single flat parameter, making these booleans vacuous 367 raise NotImplementedError 368 if group is None: 369 group = dist.distributed_c10d._get_default_group() 370 model = TransformerWithSharedParams.init( 371 group, 372 FSDPInitMode.RECURSIVE if wrap else FSDPInitMode.NO_FSDP, 373 CUDAInitMode.CUDA_BEFORE, 374 deterministic=True, 375 ) 376 optim = optim_class(model.parameters(), lr=0.01) 377 return model, optim, None 378 379 def _step_model( 380 self, 381 model: torch.nn.Module, 382 optim: torch.optim.Optimizer, 383 device: torch.device = torch.device("cuda"), 384 num_iters: int = 1, 385 ) -> List[float]: 386 """Performs a forward pass, backward pass, and optimizer step 387 ``num_iters``-many times, and returns the per-iteration losses.""" 388 torch.manual_seed(0) # set seed for determinism 389 losses = [] 390 module = getattr(model, "module", model) 391 for _ in range(num_iters): 392 optim.zero_grad() 393 inp = module.get_input(device) 394 output = model(*inp) 395 loss = module.get_loss(inp, output).to(device) 396 losses.append(loss.item()) 397 module.run_backward(loss) 398 optim.step() 399 return losses 400 401 def _broadcast_full_osd(self, full_osd: Dict[str, Any], group=None): 402 """Broadcasts the full optimizer state dict in place of using 403 ``torch.save()`` and ``torch.load()`` so that all ranks can have it.""" 404 obj_list = [full_osd] 405 dist.broadcast_object_list( 406 obj_list, 407 src=0, 408 group=group, 409 ) 410 full_osd = obj_list[0] 411 return full_osd 412 413 def _are_equal_states( 414 self, 415 state1: Dict[str, Any], 416 state2: Dict[str, Any], 417 ) -> bool: 418 """Checks if ``state1`` and ``state2`` contain the same mappings.""" 419 if set(state1.keys()) != set(state2.keys()): 420 return False 421 for state_name, value1 in state1.items(): 422 value2 = state2[state_name] 423 if type(value1) != type(value2): 424 return False 425 if torch.is_tensor(value1): # tensor state 426 assert torch.is_tensor(value2) 427 # Check the values on CPU to be device-agnostic 428 value1 = value1.cpu() 429 value2 = value2.cpu() 430 if value1.shape != value2.shape or not torch.all( 431 torch.isclose(value1, value2) 432 ): 433 return False 434 else: # non-tensor state 435 if value1 != value2: 436 return False 437 return True 438 439 def _check_same_state( 440 self, 441 fsdp_osd, 442 ref_osd, 443 check_same_param_keys: bool, 444 ): 445 """Checks that ``full_osd`` and ``ref_osd`` have the same "state" part. 446 If ``check_same_param_keys=True``, then checks that the parameter keys 447 match (e.g. when both should be parameter names), and does not check 448 the parameter keys otherwise.""" 449 assert "state" in ref_osd 450 self.assertTrue("state" in fsdp_osd) 451 ref_osd_state = ref_osd["state"] 452 fsdp_osd_state = { 453 k: _gather_state_dict(v) for k, v in fsdp_osd["state"].items() 454 } 455 456 if check_same_param_keys: 457 # Check parameter keys are the same first for earlier erroring 458 ref_osd_param_ids = set(ref_osd_state.keys()) 459 fsdp_osd_param_ids = set(fsdp_osd_state.keys()) 460 self.assertTrue( 461 ref_osd_param_ids == fsdp_osd_param_ids, 462 f"Rank {self.rank}: {(ref_osd_param_ids, fsdp_osd_param_ids)}", 463 ) 464 # Check state values are the same 465 for param_id, param_state in fsdp_osd_state.items(): 466 for state_name, value in param_state.items(): 467 ref_value = ref_osd_state[param_id][state_name] 468 self.assertEqual(value, ref_value) 469 return 470 # Otherwise, only require the parameter keys to be isomorphic (e.g. 471 # between IDs and names) 472 ref_osd_states = list(ref_osd_state.values()) 473 fsdp_osd_states = list(fsdp_osd_state.values()) 474 self.assertEqual(len(ref_osd_states), len(fsdp_osd_states)) 475 # Use brute-force quadratic-time comparison since it is hard to 476 # hash a tensor by value instead of by object 477 for fsdp_osd_state in fsdp_osd_states: 478 # Check for at least one match (may be > 1 in toy edge cases, e.g. 479 # multiple biases); nonetheless, each having >= 1 match and the two 480 # lists having equal length imply that the list contents are equal 481 self.assertTrue( 482 any( 483 self._are_equal_states(fsdp_osd_state, ref_osd_state) 484 for ref_osd_state in ref_osd_states 485 ) 486 ) 487 488 def _check_same_param_groups( 489 self, 490 full_osd, 491 ref_osd, 492 check_same_param_keys: bool, 493 ): 494 """Checks that ``full_osd`` and ``ref_osd`` have the same 495 "param_groups" part. If ``check_same_param_keys=True`, then checks that 496 the parameter keys match (e.g. when both should be parameter names), 497 and does not check the parameter keys otherwise.""" 498 assert "param_groups" in ref_osd 499 self.assertTrue("param_groups" in full_osd) 500 ref_osd_param_groups = ref_osd["param_groups"] 501 full_osd_param_groups = full_osd["param_groups"] 502 self.assertTrue(len(full_osd_param_groups), len(ref_osd_param_groups)) 503 for full_osd_pg, ref_osd_pg in zip( 504 full_osd_param_groups, 505 ref_osd_param_groups, 506 ): 507 self.assertEqual( 508 set(full_osd_pg.keys()), 509 set(ref_osd_pg.keys()), 510 ) 511 for name, full_osd_value in full_osd_pg.items(): 512 if name == "params" and not check_same_param_keys: 513 continue 514 self.assertEqual(full_osd_value, ref_osd_pg[name]) 515 516 @skip_if_lt_x_gpu(2) 517 @parametrize("state_dict_type", STATE_DICT_TYPES) 518 @parametrize("use_multiple_param_groups", [False, True]) 519 @parametrize("rank0_only", [False, True]) 520 @parametrize("use_diff_optim_inputs", [False, True]) 521 def test_optim_state_dict_nested( 522 self, 523 state_dict_type: StateDictType, 524 use_multiple_param_groups: bool, 525 rank0_only: bool, 526 use_diff_optim_inputs: bool, 527 ) -> None: 528 """ 529 Tests :meth:`full_optim_state_dict` and meth:`sharded_optim_state_dict` 530 by comparing the returned dict for an FSDP-wrapped model with that of 531 an equivalent non-wrapped model. 532 533 The test checks the equivalence excluding the parameter keys since the 534 FSDP and normal optimizer state dicts key by names and IDs, 535 respectively. This means that the test can pass even if parameter keys 536 are incorrectly mapped to values. Their correct mapping is tested in 537 other tests that exercise the save/load workflow. 538 """ 539 self.run_subtests( 540 {"use_optim_input": [False, True]}, 541 self._test_optim_state_dict_nested, 542 state_dict_type=state_dict_type, 543 use_multiple_param_groups=use_multiple_param_groups, 544 rank0_only=rank0_only, 545 use_diff_optim_inputs=use_diff_optim_inputs, 546 ) 547 548 def _test_optim_state_dict_nested( 549 self, 550 state_dict_type: StateDictType, 551 use_multiple_param_groups: bool, 552 rank0_only: bool, 553 use_diff_optim_inputs: bool, 554 use_optim_input: bool, 555 ) -> None: 556 if rank0_only and state_dict_type == StateDictType.SHARDED_STATE_DICT: 557 return # not supported 558 NUM_ITERS = 3 559 model1, optim1, optim_input = self._init_nested_model( 560 wrap=True, 561 use_multiple_param_groups=use_multiple_param_groups, 562 use_diff_optim_inputs=use_diff_optim_inputs, 563 ) 564 losses1 = self._step_model(model1, optim1, num_iters=NUM_ITERS) 565 if state_dict_type == StateDictType.FULL_STATE_DICT: 566 if use_optim_input: 567 fsdp_osd = FSDP.full_optim_state_dict( 568 model1, 569 optim1, 570 optim_input, 571 rank0_only=rank0_only, 572 ) 573 else: 574 fsdp_osd = FSDP.full_optim_state_dict( 575 model1, 576 optim1, 577 rank0_only=rank0_only, 578 ) 579 else: 580 fsdp_osd = FSDP.sharded_optim_state_dict(model1, optim1) 581 # Non-target ranks get an empty state dict 582 if rank0_only and self.rank != 0: 583 self.assertEqual(len(fsdp_osd), 0) 584 return 585 model2, optim2, _ = self._init_nested_model( 586 wrap=False, 587 use_multiple_param_groups=use_multiple_param_groups, 588 use_diff_optim_inputs=use_diff_optim_inputs, 589 ) 590 losses2 = self._step_model(model2, optim2, num_iters=NUM_ITERS) 591 ref_osd = optim2.state_dict() 592 # Check the losses to eliminate model drift as a source of error 593 for i, (l1, l2) in enumerate(zip(losses1, losses2)): 594 assert l1 == l2, f"Losses differ on iter {i}: {l1:.5f} {l2:.5f}" 595 # Do not check the parameter keys since the full/sharded optimizer state 596 # dict uses parameter names, while the non-wrapped equivalent uses 597 # parameter IDs 598 check_same_param_keys = False 599 self._check_same_param_groups( 600 fsdp_osd, 601 ref_osd, 602 check_same_param_keys=check_same_param_keys, 603 ) 604 self._check_same_state( 605 fsdp_osd, 606 ref_osd, 607 check_same_param_keys=check_same_param_keys, 608 ) 609 610 @skip_if_lt_x_gpu(2) 611 def test_full_optim_state_dict_keys(self): 612 """Tests that the parameter keys returned by 613 :meth:`full_optim_state_dict` match those of :meth:`state_dict` with 614 full ``state_dict_type`` for a non-FSDP-root model with nested FSDP 615 instances and ignored modules.""" 616 device = torch.device("cuda") 617 model = NestedModel().to(device) 618 wrapped_model = NestedModel.wrap(model, ignore_modules=True) 619 # Add checkpointing to ensure optim_state_dict and state_dict strip out 620 # checkpointing prefixes. 621 apply_activation_checkpointing( 622 model, check_fn=lambda module: isinstance(module, torch.nn.Sequential) 623 ) 624 optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3) 625 self._step_model(model, optim, device) 626 optim_state_dict = FSDP.full_optim_state_dict( 627 wrapped_model, optim, rank0_only=False 628 ) 629 with FSDP.state_dict_type(wrapped_model, StateDictType.FULL_STATE_DICT): 630 state_dict = wrapped_model.state_dict() 631 self.assertEqual(optim_state_dict["state"].keys(), state_dict.keys()) 632 # Check that checkpointing prefix was indeed stripped. 633 for key in optim_state_dict["state"]: 634 self.assertNotIn(_CHECKPOINT_WRAPPED_MODULE, key) 635 636 @skip_if_lt_x_gpu(2) 637 def test_full_optim_state_dict_nested_invalid(self): 638 """Tests that :meth:`full_optim_state_dict` raises an error when 639 nonzero ranks are missing the optimizer state for parameters on rank 640 0.""" 641 device = torch.device("cuda") 642 model = NestedModel.wrap(NestedModel().to(device), None) 643 optim_input = list(model.parameters()) 644 if self.rank != 0: 645 # Exclude a parameter so that nonzero ranks are missing state 646 optim_input = optim_input[:-1] 647 optim = torch.optim.Adam(optim_input, lr=1e-3) 648 self._step_model(model, optim, num_iters=3) 649 error_regex = ( 650 "FSDP currently requires each rank to have at least the " 651 "optimizer states needed by rank 0's optimizer but some ranks " 652 "are missing some of those states" 653 ) 654 with self.assertRaisesRegex(RuntimeError, error_regex): 655 FSDP.full_optim_state_dict(model, optim) 656 657 @skip_if_lt_x_gpu(2) 658 @parametrize("use_multiple_param_groups", [False, True]) 659 @parametrize("wrap_alt", [False, True]) 660 @parametrize("use_diff_optim_inputs", [False, True]) 661 def test_shard_full_optim_state_dict_nested( 662 self, 663 use_multiple_param_groups: bool, 664 wrap_alt: bool, 665 use_diff_optim_inputs: bool, 666 ): 667 """Tests :meth:`shard_full_optim_state_dict` for a non-FSDP-root model 668 with nested FSDP instances.""" 669 self.run_subtests( 670 {"use_optim_input": [False, True]}, 671 self._test_load_optim_state, 672 model_class=_ModelClass.NESTED, 673 use_multiple_param_groups=use_multiple_param_groups, 674 halve_world_size=False, 675 osd_comm_method=_OSDCommMethod.BROADCAST_OBJECT_LIST, 676 use_diff_optim_inputs=use_diff_optim_inputs, 677 wrap_alt=wrap_alt, 678 num_iters=3, 679 ) 680 681 self._test_load_optim_state_with_optim_state_dict( 682 _ModelClass.NESTED, 683 state_dict_settings=StateDictSettings( 684 StateDictType.FULL_STATE_DICT, 685 FullStateDictConfig(), 686 FullOptimStateDictConfig(), 687 ), 688 use_multiple_param_groups=False, 689 halve_world_size=False, 690 use_diff_optim_inputs=use_diff_optim_inputs, 691 wrap_alt=wrap_alt, 692 num_iters=3, 693 ) 694 695 @skip_if_lt_x_gpu(2) 696 def test_shard_full_optim_state_dict_nested_halve_world_size(self): 697 """Tests :meth:`shard_full_optim_state_dict` for a non-FSDP-root model 698 with nested FSDP instances when loading into a new process group with 699 halved world size.""" 700 # To save CI costs, we test with the "harder" settings: 701 use_multiple_param_groups = True 702 use_diff_optim_inputs = True 703 wrap_alt = True 704 self.run_subtests( 705 {"use_optim_input": [False, True]}, 706 self._test_load_optim_state, 707 model_class=_ModelClass.NESTED, 708 use_multiple_param_groups=use_multiple_param_groups, 709 halve_world_size=True, 710 osd_comm_method=_OSDCommMethod.BROADCAST_OBJECT_LIST, 711 use_diff_optim_inputs=use_diff_optim_inputs, 712 wrap_alt=wrap_alt, 713 num_iters=3, 714 ) 715 716 self._test_load_optim_state_with_optim_state_dict( 717 _ModelClass.NESTED, 718 state_dict_settings=StateDictSettings( 719 StateDictType.FULL_STATE_DICT, 720 FullStateDictConfig(), 721 FullOptimStateDictConfig(), 722 ), 723 use_multiple_param_groups=use_multiple_param_groups, 724 halve_world_size=True, 725 use_diff_optim_inputs=use_diff_optim_inputs, 726 wrap_alt=wrap_alt, 727 num_iters=3, 728 ) 729 730 @skip_if_lt_x_gpu(2) 731 def test_shard_full_optim_state_dict_transformer(self) -> None: 732 """Tests :meth:`shard_full_optim_state_dict` for an FSDP-root 733 transformer model with shared parameters.""" 734 self.run_subtests( 735 {"use_optim_input": [False, True]}, 736 self._test_load_optim_state, 737 model_class=_ModelClass.TRANSFORMER, 738 use_multiple_param_groups=False, 739 halve_world_size=True, 740 osd_comm_method=_OSDCommMethod.BROADCAST_OBJECT_LIST, 741 use_diff_optim_inputs=False, 742 num_iters=3, 743 ) 744 745 self._test_load_optim_state_with_optim_state_dict( 746 _ModelClass.TRANSFORMER, 747 state_dict_settings=StateDictSettings( 748 StateDictType.FULL_STATE_DICT, 749 FullStateDictConfig(), 750 FullOptimStateDictConfig(), 751 ), 752 use_multiple_param_groups=False, 753 halve_world_size=True, 754 use_diff_optim_inputs=False, 755 num_iters=3, 756 ) 757 758 @skip_if_lt_x_gpu(2) 759 @parametrize("use_multiple_param_groups", [False, True]) 760 @parametrize("wrap_alt", [False, True]) 761 @parametrize("use_diff_optim_inputs", [False, True]) 762 def test_scatter_full_optim_state_dict_nested( 763 self, 764 use_multiple_param_groups: bool, 765 wrap_alt: bool, 766 use_diff_optim_inputs: bool, 767 ): 768 """Tests :meth:`scatter_full_optim_state_dict` for a non-FSDP-root 769 model with nested FSDP instances.""" 770 self.run_subtests( 771 {"use_optim_input": [False, True]}, 772 self._test_load_optim_state, 773 model_class=_ModelClass.NESTED, 774 use_multiple_param_groups=use_multiple_param_groups, 775 halve_world_size=False, 776 osd_comm_method=_OSDCommMethod.SCATTER_FULL_OSD, 777 use_diff_optim_inputs=use_diff_optim_inputs, 778 wrap_alt=wrap_alt, 779 num_iters=3, 780 ) 781 782 self._test_load_optim_state_with_optim_state_dict( 783 _ModelClass.NESTED, 784 state_dict_settings=StateDictSettings( 785 StateDictType.FULL_STATE_DICT, 786 FullStateDictConfig(), 787 FullOptimStateDictConfig(rank0_only=True), 788 ), 789 use_multiple_param_groups=use_multiple_param_groups, 790 halve_world_size=False, 791 use_diff_optim_inputs=use_diff_optim_inputs, 792 wrap_alt=wrap_alt, 793 num_iters=3, 794 ) 795 796 @skip_if_lt_x_gpu(2) 797 def test_scatter_full_optim_state_dict_nested_halve_world_size(self): 798 """Tests :meth:`scatter_full_optim_state_dict` for a non-FSDP-root 799 model with nested FSDP instances when loading into a new process group 800 with halved world size.""" 801 # To save CI costs, we test with the "harder" settings: 802 use_multiple_param_groups = True 803 use_diff_optim_inputs = True 804 wrap_alt = True 805 self.run_subtests( 806 {"use_optim_input": [False, True]}, 807 self._test_load_optim_state, 808 model_class=_ModelClass.NESTED, 809 use_multiple_param_groups=use_multiple_param_groups, 810 halve_world_size=True, 811 osd_comm_method=_OSDCommMethod.SCATTER_FULL_OSD, 812 use_diff_optim_inputs=use_diff_optim_inputs, 813 wrap_alt=wrap_alt, 814 num_iters=3, 815 ) 816 817 self._test_load_optim_state_with_optim_state_dict( 818 _ModelClass.NESTED, 819 state_dict_settings=StateDictSettings( 820 StateDictType.FULL_STATE_DICT, 821 FullStateDictConfig(), 822 FullOptimStateDictConfig(rank0_only=True), 823 ), 824 use_multiple_param_groups=use_multiple_param_groups, 825 halve_world_size=True, 826 use_diff_optim_inputs=use_diff_optim_inputs, 827 wrap_alt=wrap_alt, 828 num_iters=3, 829 ) 830 831 @skip_if_lt_x_gpu(2) 832 def test_scatter_full_optim_state_dict_transformer(self) -> None: 833 """Tests :meth:`scatter_full_optim_state_dict` for an FSDP-root 834 transformer model with shared parameters.""" 835 self.run_subtests( 836 {"use_optim_input": [False, True]}, 837 self._test_load_optim_state, 838 model_class=_ModelClass.TRANSFORMER, 839 use_multiple_param_groups=False, 840 halve_world_size=True, 841 osd_comm_method=_OSDCommMethod.SCATTER_FULL_OSD, 842 use_diff_optim_inputs=False, 843 num_iters=3, 844 ) 845 846 self._test_load_optim_state_with_optim_state_dict( 847 _ModelClass.TRANSFORMER, 848 state_dict_settings=StateDictSettings( 849 StateDictType.FULL_STATE_DICT, 850 FullStateDictConfig(), 851 FullOptimStateDictConfig(rank0_only=True), 852 ), 853 use_multiple_param_groups=False, 854 halve_world_size=True, 855 use_diff_optim_inputs=False, 856 num_iters=3, 857 ) 858 859 @skip_if_lt_x_gpu(2) 860 def test_flatten_sharded_optim_state_dict_nested(self) -> None: 861 """Tests :meth:`flatten_sharded_optim_state_dict` for an FSDP-root 862 nested model.""" 863 self._test_load_optim_state( 864 _ModelClass.NESTED, 865 use_multiple_param_groups=False, 866 halve_world_size=False, 867 osd_comm_method=_OSDCommMethod.FLATTEN_SHARDED_OSD, 868 use_diff_optim_inputs=False, 869 use_optim_input=False, 870 wrap_alt=True, 871 num_iters=3, 872 ) 873 874 self._test_load_optim_state_with_optim_state_dict( 875 _ModelClass.NESTED, 876 state_dict_settings=StateDictSettings( 877 StateDictType.SHARDED_STATE_DICT, 878 ShardedStateDictConfig(), 879 ShardedOptimStateDictConfig(), 880 ), 881 use_multiple_param_groups=False, 882 halve_world_size=False, 883 use_diff_optim_inputs=False, 884 wrap_alt=True, 885 num_iters=3, 886 ) 887 888 @skip_if_lt_x_gpu(2) 889 def test_flatten_sharded_optim_state_dict_transformer(self) -> None: 890 """Tests :meth:`flatten_sharded_optim_state_dict` for an FSDP-root 891 transformer model.""" 892 self._test_load_optim_state( 893 _ModelClass.TRANSFORMER, 894 use_multiple_param_groups=False, 895 halve_world_size=False, 896 osd_comm_method=_OSDCommMethod.FLATTEN_SHARDED_OSD, 897 use_diff_optim_inputs=False, 898 use_optim_input=False, 899 num_iters=3, 900 ) 901 902 self._test_load_optim_state_with_optim_state_dict( 903 _ModelClass.TRANSFORMER, 904 state_dict_settings=StateDictSettings( 905 StateDictType.SHARDED_STATE_DICT, 906 ShardedStateDictConfig(), 907 ShardedOptimStateDictConfig(), 908 ), 909 use_multiple_param_groups=False, 910 halve_world_size=False, 911 use_diff_optim_inputs=False, 912 num_iters=3, 913 ) 914 915 @skip_if_lt_x_gpu(2) 916 def test_use_orig_params(self) -> None: 917 """Tests :meth:`optim_state_dict` for an FSDP-root nested model.""" 918 self.run_subtests( 919 { 920 "halve_world_size": [True, False], 921 "wrap_alt": [True, False], 922 }, 923 self._test_load_optim_state_with_optim_state_dict, 924 model_class=_ModelClass.NESTED, 925 state_dict_settings=StateDictSettings( 926 StateDictType.FULL_STATE_DICT, 927 FullStateDictConfig(), 928 FullOptimStateDictConfig(), 929 ), 930 use_multiple_param_groups=False, 931 use_diff_optim_inputs=False, 932 num_iters=3, 933 fsdp_kwargs={"use_orig_params": True}, 934 ) 935 936 self.run_subtests( 937 { 938 "halve_world_size": [True, False], 939 "wrap_alt": [True, False], 940 }, 941 self._test_load_optim_state_with_optim_state_dict, 942 model_class=_ModelClass.NESTED, 943 state_dict_settings=StateDictSettings( 944 StateDictType.FULL_STATE_DICT, 945 FullStateDictConfig(), 946 FullOptimStateDictConfig(rank0_only=True), 947 ), 948 use_multiple_param_groups=False, 949 use_diff_optim_inputs=False, 950 num_iters=3, 951 fsdp_kwargs={"use_orig_params": True}, 952 ) 953 954 self.run_subtests( 955 { 956 "wrap_alt": [True, False], 957 }, 958 self._test_load_optim_state_with_optim_state_dict, 959 model_class=_ModelClass.NESTED, 960 state_dict_settings=StateDictSettings( 961 StateDictType.SHARDED_STATE_DICT, 962 ShardedStateDictConfig(), 963 ShardedOptimStateDictConfig(), 964 ), 965 use_multiple_param_groups=False, 966 # We cannot test halve_world_size with SHARDED_STATE_DICT. 967 halve_world_size=False, 968 use_diff_optim_inputs=False, 969 num_iters=3, 970 fsdp_kwargs={"use_orig_params": True}, 971 ) 972 973 def _test_load_optim_state( 974 self, 975 model_class: _ModelClass, 976 use_multiple_param_groups: bool, 977 halve_world_size: bool, 978 osd_comm_method: _OSDCommMethod, 979 use_diff_optim_inputs: bool, 980 use_optim_input: bool, 981 num_iters: int, 982 **new_model_kwargs, 983 ): 984 """ 985 (1) Runs a model with full world size for K iterations to generate a 986 full/sharded optimizer state dict; 987 (2) initializes a model with halved world size and possibly different 988 FSDP wrapping scheme (based on ``new_model_kwargs``); 989 (3) loads the full/sharded optimizer state dict from (1) according to the 990 halved-world-size model; 991 (4) runs the halved-world-size model for K iterations; and 992 (5) checks that the sharded optimizer state dict from (3) matches the 993 halved-world-size model's local optimizer state dict, meaning that the 994 former could have equivalently been loaded into the local optimizer. 995 """ 996 initializer = self._model_class[model_class] 997 if osd_comm_method == _OSDCommMethod.OPTIM_STATE_DICT: 998 osd_method = FSDP.optim_state_dict 999 elif osd_comm_method == _OSDCommMethod.FLATTEN_SHARDED_OSD: 1000 osd_method = FSDP.sharded_optim_state_dict 1001 else: 1002 osd_method = FSDP.full_optim_state_dict 1003 1004 # First, run a wrapped model with full world size for a few iterations 1005 model1, optim1, optim_input1 = initializer( 1006 wrap=True, 1007 use_multiple_param_groups=use_multiple_param_groups, 1008 ) 1009 self._step_model(model1, optim1, num_iters=num_iters) 1010 fsdp_osd1 = ( 1011 osd_method(model1, optim1, optim_input1) 1012 if use_optim_input 1013 else osd_method(model1, optim1) 1014 ) 1015 if halve_world_size: 1016 # Create a new process group with halved world size 1017 new_group_ranks = [r for r in range(self.world_size) if r % 2 == 0] 1018 new_group = dist.new_group(ranks=new_group_ranks) 1019 if self.rank not in new_group_ranks: 1020 return 1021 else: 1022 # Continue using the same group and hence world size 1023 new_group = dist.distributed_c10d._get_default_group() 1024 # Second, run a wrapped model with (possibly) halved world size and 1025 # (possibly) differing `optim_input` across ranks 1026 model2, optim2, optim_input2 = initializer( 1027 wrap=True, 1028 group=new_group, 1029 use_multiple_param_groups=use_multiple_param_groups, 1030 use_diff_optim_inputs=use_diff_optim_inputs, 1031 **new_model_kwargs, # specify `wrap_alt` to change wrapping 1032 ) 1033 self._step_model(model2, optim2, num_iters=num_iters) 1034 fsdp_osd2 = ( 1035 osd_method(model2, optim2, optim_input2, group=new_group) 1036 if use_optim_input 1037 else osd_method(model2, optim2, group=new_group) 1038 ) 1039 # Compute two sharded optim state dicts: (1) for the first model 1040 # according to the second model and (2) for the second model according 1041 # to the second model 1042 if osd_comm_method == _OSDCommMethod.BROADCAST_OBJECT_LIST: 1043 fsdp_osd1 = self._broadcast_full_osd(fsdp_osd1, group=new_group) 1044 sharded_osd1 = ( 1045 FSDP.shard_full_optim_state_dict( 1046 fsdp_osd1, model2, optim_input=optim_input2 1047 ) 1048 if use_optim_input 1049 else FSDP.shard_full_optim_state_dict(fsdp_osd1, model2, optim=optim2) 1050 ) 1051 fsdp_osd2 = self._broadcast_full_osd(fsdp_osd2, group=new_group) 1052 sharded_osd2 = ( 1053 FSDP.shard_full_optim_state_dict( 1054 fsdp_osd2, model2, optim_input=optim_input2 1055 ) 1056 if use_optim_input 1057 else FSDP.shard_full_optim_state_dict(fsdp_osd2, model2, optim=optim2) 1058 ) 1059 elif osd_comm_method == _OSDCommMethod.SCATTER_FULL_OSD: 1060 sharded_osd1 = ( 1061 FSDP.scatter_full_optim_state_dict( 1062 fsdp_osd1 if self.rank == 0 else None, 1063 model2, 1064 optim_input=optim_input2, 1065 group=new_group, 1066 ) 1067 if use_optim_input 1068 else FSDP.scatter_full_optim_state_dict( 1069 fsdp_osd1 if self.rank == 0 else None, 1070 model2, 1071 optim=optim2, 1072 group=new_group, 1073 ) 1074 ) 1075 sharded_osd2 = ( 1076 FSDP.scatter_full_optim_state_dict( 1077 fsdp_osd2 if self.rank == 0 else None, 1078 model2, 1079 optim_input=optim_input2, 1080 group=new_group, 1081 ) 1082 if use_optim_input 1083 else FSDP.scatter_full_optim_state_dict( 1084 fsdp_osd2 if self.rank == 0 else None, 1085 model2, 1086 optim=optim2, 1087 group=new_group, 1088 ) 1089 ) 1090 elif osd_comm_method == _OSDCommMethod.FLATTEN_SHARDED_OSD: 1091 sharded_osd1 = FSDP.flatten_sharded_optim_state_dict( 1092 fsdp_osd1, 1093 model2, 1094 optim=optim2, 1095 ) 1096 sharded_osd2 = FSDP.flatten_sharded_optim_state_dict( 1097 fsdp_osd2, 1098 model2, 1099 optim=optim2, 1100 ) 1101 elif osd_comm_method == _OSDCommMethod.OPTIM_STATE_DICT: 1102 sharded_osd1 = FSDP.optim_state_dict_to_load(model2, optim2, fsdp_osd1) 1103 sharded_osd2 = FSDP.optim_state_dict_to_load(model2, optim2, fsdp_osd2) 1104 1105 # As a sanity check, check that sharding the second model's full/sharded 1106 # optimizer state dict according to itself is equivalent to its local 1107 # optimizer's state dict 1108 local_osd2 = optim2.state_dict() 1109 check_same_param_keys = True # should all have matching parameter IDs 1110 self._check_same_param_groups( 1111 sharded_osd2, 1112 local_osd2, 1113 check_same_param_keys=check_same_param_keys, 1114 ) 1115 self._check_same_state( 1116 sharded_osd2, 1117 local_osd2, 1118 check_same_param_keys=check_same_param_keys, 1119 ) 1120 # Check that sharding the first model's full/sharded optimizer state dict 1121 # according to the second model is equivalent to the second model's 1122 # local optimizer state dict 1123 self._check_same_param_groups( 1124 sharded_osd1, 1125 local_osd2, 1126 check_same_param_keys=check_same_param_keys, 1127 ) 1128 self._check_same_state( 1129 sharded_osd1, 1130 local_osd2, 1131 check_same_param_keys=check_same_param_keys, 1132 ) 1133 # As a sanity check, check that we can load and run a few iterations 1134 optim2.load_state_dict(sharded_osd2) 1135 self._step_model(model2, optim2, num_iters=num_iters) 1136 1137 @skip_if_lt_x_gpu(2) 1138 @parametrize("state_dict_type", STATE_DICT_TYPES) 1139 @parametrize("add_to_fsdp_module", [False, True]) 1140 def test_shard_full_optim_state_dict_unmanaged_params( 1141 self, 1142 state_dict_type: StateDictType, 1143 add_to_fsdp_module: bool, 1144 ): 1145 """ 1146 Tests :meth:`shard_full_optim_state_dict` when there are unmanaged 1147 parameters. 1148 - If ``add_to_fsdp_module=True``, then the unmanaged parameters are 1149 added to a module to be wrapped with FSDP, in which case there should 1150 be an error since we require that all unflattened parameter 1151 comprising a flat parameter have the same scalar state (e.g. Adam 1152 "step") but the added parameter is missing its entry. 1153 - If ``add_to_fsdp_module=False``, then the unmanaged parameters are 1154 added to a module not to be wrapped with FSDP, in which case there 1155 should be no error (emulating model parallel use cases where some 1156 parameters may be managed externally to FSDP). 1157 We do not separately test unmanaged parameters for 1158 :meth:`scatter_full_optim_state_dict` and `flatten_sharded_optim_state_dict` 1159 to save CI cost since it call into the same subroutine 1160 :meth:`_flatten_optim_state_dict`. 1161 """ 1162 if state_dict_type == StateDictType.SHARDED_STATE_DICT: 1163 use_optim_input = [False] 1164 else: 1165 use_optim_input = [False, True] 1166 self.run_subtests( 1167 {"use_optim_input": use_optim_input}, 1168 self._test_shard_full_optim_state_dict_unmanaged_params, 1169 state_dict_type=state_dict_type, 1170 add_to_fsdp_module=add_to_fsdp_module, 1171 ) 1172 1173 def _test_shard_full_optim_state_dict_unmanaged_params( 1174 self, 1175 state_dict_type: StateDictType, 1176 add_to_fsdp_module: bool, 1177 use_optim_input: bool, 1178 ): 1179 NUM_ITERS = 1 1180 # Create a normal wrapped model 1181 model, optim, optim_input = self._init_nested_model(wrap=True) 1182 self._step_model(model, optim, num_iters=NUM_ITERS) 1183 1184 if state_dict_type == StateDictType.FULL_STATE_DICT: 1185 fsdp_osd = ( 1186 FSDP.full_optim_state_dict(model, optim, optim_input, rank0_only=False) 1187 if use_optim_input 1188 else FSDP.full_optim_state_dict(model, optim, rank0_only=False) 1189 ) # save on all ranks to avoid having to broadcast from rank 0 1190 else: 1191 fsdp_osd = FSDP.sharded_optim_state_dict(model, optim) 1192 # Create a new model with the same structure but additional unmanaged 1193 # parameters, representing the model for which we want to load 1194 device = torch.device("cuda") 1195 model = NestedModel().to(device) 1196 model, unmanaged_params = NestedModel.wrap_with_unmanaged_params( 1197 model, 1198 add_to_fsdp_module, 1199 ) 1200 optim_input = list(model.parameters()) 1201 optim = torch.optim.Adam(optim_input, lr=1e-3) 1202 if add_to_fsdp_module: 1203 # If we add the unmanaged parameters to a module wrapped with FSDP, 1204 # then the flat parameter will be comprised of some unflattened 1205 # parameters with zero-dimensional tensor state (i.e. Adam "step") 1206 # and others without (i.e. the unmanaged parameters), which 1207 # triggers an error that we have to ensure correctness 1208 error_prefix = ( 1209 "^(All unflattened parameters comprising a " 1210 "single flat parameter must have scalar state with the " 1211 "same value and dtype)" 1212 ) 1213 with self.assertRaisesRegex(ValueError, error_prefix): 1214 if state_dict_type == StateDictType.FULL_STATE_DICT: 1215 ( 1216 FSDP.shard_full_optim_state_dict( 1217 fsdp_osd, model, optim_input=optim_input 1218 ) 1219 if use_optim_input 1220 else FSDP.shard_full_optim_state_dict( 1221 fsdp_osd, model, optim=optim 1222 ) 1223 ) 1224 else: 1225 FSDP.flatten_sharded_optim_state_dict(fsdp_osd, model, optim=optim) 1226 else: 1227 # If we add the unmanaged parameters to a module not wrapped with 1228 # FSDP, then we simply ignore them without erroring to enable 1229 # model parallelism use cases, where some parameters are managed 1230 # externally to FSDP 1231 if state_dict_type == StateDictType.FULL_STATE_DICT: 1232 flattened_osd = ( 1233 FSDP.shard_full_optim_state_dict( 1234 fsdp_osd, model, optim_input=optim_input 1235 ) 1236 if use_optim_input 1237 else FSDP.shard_full_optim_state_dict(fsdp_osd, model, optim=optim) 1238 ) 1239 else: 1240 flattened_osd = FSDP.flatten_sharded_optim_state_dict( 1241 fsdp_osd, model, optim=optim 1242 ) 1243 # Add entries for the unmanaged parameters to be able to load 1244 for unmanaged_param in unmanaged_params: 1245 NestedModel.add_unmanaged_param_entry( 1246 flattened_osd, 1247 unmanaged_param, 1248 NUM_ITERS, 1249 ) 1250 # Check that we can load the optimizer state dict 1251 optim.load_state_dict(flattened_osd) 1252 1253 @skip_if_lt_x_gpu(2) 1254 @parametrize("state_dict_type", STATE_DICT_TYPES) 1255 @parametrize("use_multiple_param_groups", [False, True]) 1256 def test_rekey_optim_state_dict_to_ids( 1257 self, 1258 state_dict_type: StateDictType, 1259 use_multiple_param_groups: bool, 1260 ): 1261 """Tests :meth:`rekey_optim_state_dict` with the new keys being 1262 parameter IDs by checking that a wrapped model (i.e. with FSDP modules) 1263 can rekey its optimizer state dict to match that of an equivalent 1264 non-wrapped model (i.e. without FSDP modules).""" 1265 if state_dict_type == StateDictType.SHARDED_STATE_DICT: 1266 use_optim_input = [False] 1267 else: 1268 use_optim_input = [False, True] 1269 self.run_subtests( 1270 {"use_optim_input": use_optim_input}, 1271 self._test_rekey_optim_state_dict_to_ids, 1272 state_dict_type=state_dict_type, 1273 use_multiple_param_groups=use_multiple_param_groups, 1274 ) 1275 1276 @skip_if_lt_x_gpu(2) 1277 def _test_rekey_optim_state_dict_to_ids( 1278 self, 1279 state_dict_type: StateDictType, 1280 use_multiple_param_groups: bool, 1281 use_optim_input: bool, 1282 ): 1283 NUM_ITERS = 3 1284 # Run a wrapped model for a few iterations 1285 model1, optim1, optim_input1 = self._init_nested_model( 1286 wrap=True, 1287 use_multiple_param_groups=use_multiple_param_groups, 1288 ) 1289 self._step_model(model1, optim1, num_iters=NUM_ITERS) 1290 if state_dict_type == StateDictType.FULL_STATE_DICT: 1291 fsdp_osd = ( 1292 FSDP.full_optim_state_dict(model1, optim1, optim_input1) 1293 if use_optim_input 1294 else FSDP.full_optim_state_dict(model1, optim1) 1295 ) 1296 # Broadcast instead of `torch.save()`/`torch.load()` so that all ranks 1297 # have the full state dict 1298 fsdp_osd = self._broadcast_full_osd(fsdp_osd) 1299 else: 1300 fsdp_osd = FSDP.sharded_optim_state_dict(model1, optim1) 1301 # Run a non-wrapped model for a few iterations 1302 model2, optim2, optim_input2 = self._init_nested_model( 1303 wrap=False, 1304 use_multiple_param_groups=use_multiple_param_groups, 1305 ) 1306 self._step_model(model2, optim2, num_iters=NUM_ITERS) 1307 # Re-key the wrapped model's optimizer state dict using parameter IDs 1308 # according to the non-wrapped model 1309 rekeyed_osd = ( 1310 FSDP.rekey_optim_state_dict( 1311 fsdp_osd, 1312 OptimStateKeyType.PARAM_ID, 1313 model2, 1314 optim_input=optim_input2, 1315 ) 1316 if use_optim_input 1317 else FSDP.rekey_optim_state_dict( 1318 fsdp_osd, 1319 OptimStateKeyType.PARAM_ID, 1320 model2, 1321 optim=optim2, 1322 ) 1323 ) 1324 # Check that the re-keyed dict and actual dict are the same 1325 osd = optim2.state_dict() 1326 check_same_param_keys = True 1327 self._check_same_param_groups( 1328 rekeyed_osd, 1329 osd, 1330 check_same_param_keys=check_same_param_keys, 1331 ) 1332 self._check_same_state( 1333 rekeyed_osd, 1334 osd, 1335 check_same_param_keys=check_same_param_keys, 1336 ) 1337 # As a sanity check, check that we can load and run a few iterations 1338 if state_dict_type != StateDictType.SHARDED_STATE_DICT: 1339 optim2.load_state_dict(rekeyed_osd) 1340 self._step_model(model2, optim2, num_iters=NUM_ITERS) 1341 1342 @skip_if_lt_x_gpu(2) 1343 def test_rekey_optim_state_dict_to_names(self): 1344 """Tests :meth:`rekey_optim_state_dict` with the new keys being 1345 parameter names by checking that a non-wrapped model (i.e. without FSDP 1346 modules) can rekey its optimizer state dict to match the expected 1347 output of :meth:`full_optim_state_dict`, hence be sharded using 1348 :meth:`shard_full_optim_state_dict`, and finally match the per-rank 1349 optimizer state dict of a wrapped model (i.e. with FSDP modules).""" 1350 self.run_subtests( 1351 {"use_optim_input": [False, True]}, 1352 self._test_rekey_optim_state_dict_to_names, 1353 use_multiple_param_groups=False, 1354 ) 1355 1356 def _test_rekey_optim_state_dict_to_names( 1357 self, 1358 use_multiple_param_groups: bool, 1359 use_optim_input: bool, 1360 ): 1361 NUM_ITERS = 3 1362 # Run a wrapped model for a few iterations 1363 model1, optim1, optim_input1 = self._init_nested_model( 1364 wrap=True, 1365 use_multiple_param_groups=use_multiple_param_groups, 1366 ) 1367 self._step_model(model1, optim1, num_iters=NUM_ITERS) 1368 # Run a non-wrapped model for a few iterations 1369 model2, optim2, optim_input2 = self._init_nested_model( 1370 wrap=False, 1371 use_multiple_param_groups=use_multiple_param_groups, 1372 ) 1373 self._step_model(model2, optim2, num_iters=NUM_ITERS) 1374 # Re-key the non-wrapped model's optimizer state dict using parameter 1375 # names (still according to itself) 1376 osd2 = optim2.state_dict() 1377 rekeyed_osd = ( 1378 FSDP.rekey_optim_state_dict( 1379 osd2, 1380 OptimStateKeyType.PARAM_NAME, 1381 model2, 1382 optim_input=optim_input2, 1383 ) 1384 if use_optim_input 1385 else FSDP.rekey_optim_state_dict( 1386 osd2, 1387 OptimStateKeyType.PARAM_NAME, 1388 model2, 1389 optim=optim2, 1390 ) 1391 ) 1392 # Shard the non-wrapped model's re-keyed optimizer state dict, which 1393 # maps back to (flattened) parameter IDs 1394 sharded_osd = ( 1395 FSDP.shard_full_optim_state_dict( 1396 rekeyed_osd, 1397 model1, 1398 optim_input=optim_input1, 1399 ) 1400 if use_optim_input 1401 else FSDP.shard_full_optim_state_dict( 1402 rekeyed_osd, 1403 model1, 1404 optim=optim1, 1405 ) 1406 ) 1407 # Check that this sharded optimizer state dict matches the wrapped 1408 # model's per-rank optimizer state dict 1409 osd1 = optim1.state_dict() 1410 check_same_param_keys = True 1411 self._check_same_param_groups( 1412 sharded_osd, 1413 osd1, 1414 check_same_param_keys=check_same_param_keys, 1415 ) 1416 self._check_same_state( 1417 sharded_osd, 1418 osd1, 1419 check_same_param_keys=check_same_param_keys, 1420 ) 1421 # As a sanity check, check that we can load and run a few iterations 1422 optim1.load_state_dict(sharded_osd) 1423 self._step_model(model1, optim1, num_iters=NUM_ITERS) 1424 1425 @skip_if_lt_x_gpu(2) 1426 def test_optim_input_warning(self): 1427 """Tests that passing the ``optim_input`` argument into optimizer state 1428 checkpointing APIs issues a warning.""" 1429 1430 def should_check_method(method_name: str): 1431 # Check every method since they all accept `optim_input` 1432 return method_name not in ( 1433 "sharded_optim_state_dict", 1434 "flatten_sharded_optim_state_dict", 1435 ) 1436 1437 def get_warning_context(): 1438 warning_regex = "`optim_input` argument is deprecated" 1439 return self.assertWarnsRegex( 1440 expected_warning=FutureWarning, expected_regex=warning_regex 1441 ) 1442 1443 self._run_on_all_optim_state_apis( 1444 should_check_method, get_warning_context, fsdp_kwargs=None 1445 ) 1446 1447 def _run_on_all_optim_state_apis( 1448 self, 1449 should_check_method_fn: Callable[[str], bool], 1450 context_fn: Callable, 1451 fsdp_kwargs: Optional[Dict[str, Any]], 1452 ): 1453 """ 1454 Runs through all optimizer state checkpointing APIs with a context 1455 manager instantiated by ``context_fn``. Certain APIs can be skipped 1456 via ``should_check_method_fn``, which gets passed the string name of 1457 the method. 1458 """ 1459 wrapped_model, wrapped_optim, wrapped_optim_input = self._init_nested_model( 1460 wrap=True, 1461 use_multiple_param_groups=False, 1462 fsdp_kwargs=fsdp_kwargs, 1463 ) 1464 self._step_model(wrapped_model, wrapped_optim, num_iters=2) 1465 1466 # Sharded optim state dict 1467 if should_check_method_fn("sharded_optim_state_dict"): 1468 with context_fn(): 1469 fsdp_osd = FSDP.sharded_optim_state_dict( 1470 wrapped_model, 1471 wrapped_optim, 1472 ) 1473 if "fsdp_osd" not in locals(): 1474 fsdp_osd = {} # may not be defined due to previous method erroring 1475 if should_check_method_fn("flatten_sharded_optim_state_dict"): 1476 with context_fn(): 1477 FSDP.flatten_sharded_optim_state_dict( 1478 fsdp_osd, 1479 wrapped_model, 1480 wrapped_optim, 1481 ) 1482 # Full optim state dict 1483 if should_check_method_fn("full_optim_state_dict"): 1484 with context_fn(): 1485 fsdp_osd = FSDP.full_optim_state_dict( 1486 wrapped_model, 1487 wrapped_optim, 1488 optim_input=wrapped_optim_input, 1489 rank0_only=False, 1490 ) 1491 if should_check_method_fn("shard_full_optim_state_dict"): 1492 with context_fn(): 1493 FSDP.shard_full_optim_state_dict( 1494 fsdp_osd, 1495 wrapped_model, 1496 optim_input=wrapped_optim_input, 1497 ) 1498 if should_check_method_fn("scatter_full_optim_state_dict"): 1499 with context_fn(): 1500 FSDP.scatter_full_optim_state_dict( 1501 fsdp_osd, 1502 wrapped_model, 1503 optim_input=wrapped_optim_input, 1504 ) 1505 # Rekey optim state dict 1506 ( 1507 nonwrapped_model, 1508 nonwrapped_optim, 1509 nonwrapped_optim_input, 1510 ) = self._init_nested_model(wrap=False, use_multiple_param_groups=False) 1511 if should_check_method_fn("rekey_optim_state_dict"): 1512 with context_fn(): 1513 rekeyed_osd = FSDP.rekey_optim_state_dict( 1514 fsdp_osd, # from `full_optim_state_dict()` 1515 OptimStateKeyType.PARAM_ID, 1516 nonwrapped_model, 1517 optim_input=nonwrapped_optim_input, 1518 ) 1519 self._step_model(nonwrapped_model, nonwrapped_optim, num_iters=2) 1520 osd = nonwrapped_optim.state_dict() 1521 if should_check_method_fn("rekey_optim_state_dict"): 1522 with context_fn(): 1523 FSDP.rekey_optim_state_dict( 1524 osd, 1525 OptimStateKeyType.PARAM_NAME, 1526 nonwrapped_model, 1527 optim_input=nonwrapped_optim_input, 1528 ) 1529 1530 @skip_if_lt_x_gpu(2) 1531 @parametrize("state_dict_type", STATE_DICT_TYPES) 1532 def test_save_load_without_0th_param_state(self, state_dict_type: StateDictType): 1533 """ 1534 Tests saving and loading an optim state dict for Adam optimizer (i.e. 1535 any optimizer with a "step" key in its state) when the first parameter 1536 does not have optimizer state (e.g. unused or frozen). 1537 """ 1538 1539 class Model(nn.Module): 1540 def __init__(self) -> None: 1541 super().__init__() 1542 self.lin1 = nn.Linear(5, 5) 1543 self.lin2 = nn.Linear(5, 5) 1544 self.relu = nn.ReLU() 1545 1546 def forward(self, x: torch.Tensor) -> torch.Tensor: 1547 # Do not use `lin1`, which is the parameter passed to the 1548 # optimizer and the one checked for "step" state to see if it 1549 # is tensor or float 1550 return self.relu(self.lin2(x)) 1551 1552 model = Model().cuda() 1553 model.lin1 = FSDP(model.lin1) 1554 model.lin2 = FSDP(model.lin2) 1555 fsdp_model = FSDP(model) 1556 optim = torch.optim.Adam( 1557 fsdp_model.parameters(), lr=1e-2 1558 ) # or any optimizer with "step" 1559 1560 # Run an iteration to construct optimizer state 1561 device = torch.device("cuda") 1562 inp = torch.randn((2, 5), device=device) 1563 loss = fsdp_model(inp).sum() 1564 loss.backward() 1565 optim.step() 1566 1567 # Check that save and load does not error 1568 if state_dict_type == StateDictType.FULL_STATE_DICT: 1569 fsdp_osd = FSDP.full_optim_state_dict(fsdp_model, optim, rank0_only=False) 1570 flattened_osd = FSDP.shard_full_optim_state_dict(fsdp_osd, fsdp_model) 1571 elif state_dict_type == StateDictType.SHARDED_STATE_DICT: 1572 fsdp_osd = FSDP.sharded_optim_state_dict(fsdp_model, optim) 1573 flattened_osd = FSDP.flatten_sharded_optim_state_dict( 1574 fsdp_osd, fsdp_model, optim 1575 ) 1576 optim.load_state_dict(flattened_osd) 1577 # `__setstate__()` will check the 0th parameter to see if "step" is 1578 # represented as a tensor or float, so it is imperative that its state 1579 # is non-empty. 1580 1581 # Run an iteration as a sanity check 1582 inp = torch.randn((2, 5), device=device) 1583 loss = fsdp_model(inp).sum() 1584 loss.backward() 1585 optim.step() 1586 1587 @skip_if_lt_x_gpu(2) 1588 def test_compatible_with_trec(self): 1589 class DenseModel(torch.nn.Module): 1590 def __init__(self) -> None: 1591 super().__init__() 1592 self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU()) 1593 self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU()) 1594 self.net3 = nn.Linear(32, 64) 1595 self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8)) 1596 1597 def forward(self, x): 1598 return self.net4(self.net3(self.net2(self.net1(x)))) 1599 1600 class FakeMPModel(torch.nn.Module): 1601 def __init__(self) -> None: 1602 super().__init__() 1603 torch.manual_seed(0) 1604 self.dense = FSDP(DenseModel().cuda(), use_orig_params=True) 1605 if dist.get_rank() == 0: 1606 self.sparse0 = nn.Sequential(nn.Linear(8, 8), nn.ReLU()) 1607 else: 1608 self.sparse1 = nn.Sequential(nn.Linear(8, 8), nn.ReLU()) 1609 1610 def forward(self, x): 1611 if dist.get_rank() == 0: 1612 sparse = self.sparse0(x) 1613 else: 1614 sparse = self.sparse1(x) 1615 dist.all_reduce(sparse) 1616 return self.dense(sparse) 1617 1618 models = [FakeMPModel().cuda(), FakeMPModel().cuda()] 1619 optims = [ 1620 torch.optim.Adam(models[0].parameters(), lr=1e-2), 1621 _NamedOptimizer( 1622 models[1].named_parameters(), 1623 torch.optim.Adam, 1624 [{"params": models[1].parameters()}], 1625 models[1], 1626 lr=1e-2, 1627 ), 1628 ] 1629 state_dicts = [] 1630 1631 # Train one batch and see if optim_state_dict are the same. 1632 batch = torch.rand(5, 8, device=torch.device("cuda")) 1633 for model, optim in zip(models, optims): 1634 # Eagerly initialize the states 1635 for param in model.parameters(): 1636 if param.requires_grad: 1637 t = torch.zeros_like(param) 1638 param.grad = torch.autograd.Variable(t) 1639 optim.step() 1640 loss = model(batch).sum() 1641 loss.backward() 1642 optim.step() 1643 state_dicts.append(deepcopy(FSDP.optim_state_dict(model, optim))) 1644 1645 self._check_same_param_groups( 1646 state_dicts[0], state_dicts[1], check_same_param_keys=False 1647 ) 1648 self._check_same_state( 1649 state_dicts[0], state_dicts[1], check_same_param_keys=True 1650 ) 1651 1652 # Make optim1 has a different state. 1653 for i in range(5): 1654 batch = torch.rand(5, 8).cuda() 1655 loss = models[1](batch).sum() 1656 loss.backward() 1657 optims[1].step() 1658 1659 # Load the state back to see if load_optim_state_dict works. 1660 state_dict_to_load = FSDP.optim_state_dict_to_load( 1661 models[1], optims[1], state_dicts[1], is_named_optimizer=True 1662 ) 1663 optims[1].load_state_dict(state_dict_to_load) 1664 state_dicts[1] = FSDP.optim_state_dict(models[1], optims[1]) 1665 1666 self._check_same_param_groups( 1667 state_dicts[0], state_dicts[1], check_same_param_keys=False 1668 ) 1669 self._check_same_state( 1670 state_dicts[0], state_dicts[1], check_same_param_keys=True 1671 ) 1672 1673 @skip_if_lt_x_gpu(2) 1674 def test_optim_state_without_param_groups(self): 1675 class SimpleModel(torch.nn.Module): 1676 def __init__(self) -> None: 1677 super().__init__() 1678 torch.manual_seed(0) 1679 self.net1 = nn.Sequential(nn.Linear(2, 4), nn.ReLU()) 1680 1681 def forward(self, x): 1682 return self.net1(x) 1683 1684 model = FSDP(SimpleModel().cuda()) 1685 optim = torch.optim.Adam(model.parameters(), lr=1e-2) 1686 1687 # Train one step to save original optimizer state dict and original optimizer param groups. 1688 batch = torch.rand(3, 2, device=torch.device("cuda")) 1689 for param in model.parameters(): 1690 if param.requires_grad: 1691 t = torch.zeros_like(param) 1692 param.grad = torch.autograd.Variable(t) 1693 optim.step() 1694 loss = model(batch).sum() 1695 loss.backward() 1696 1697 original_osd = deepcopy(optim.state_dict()) 1698 original_osd_no_param_groups = deepcopy(original_osd) 1699 # manually remove param_groups from optimizer state dict 1700 original_param_groups = deepcopy( 1701 original_osd_no_param_groups.pop("param_groups") 1702 ) 1703 # passing the osd without param_groups to FSDP 1704 original_fsdp_optim_state_dict = deepcopy( 1705 FSDP.optim_state_dict( 1706 model, optim, optim_state_dict=original_osd_no_param_groups 1707 ) 1708 ) 1709 # check the state_dict sharded by FSDP does not contain param_groups. 1710 self.assertEqual(None, original_fsdp_optim_state_dict.get("param_groups")) 1711 1712 # train another step to make optim a different state. 1713 for param in model.parameters(): 1714 if param.requires_grad: 1715 t = torch.zeros_like(param) 1716 param.grad = torch.autograd.Variable(t) 1717 optim.step() 1718 loss = model(batch).sum() 1719 loss.backward() 1720 1721 state_dict_to_load = FSDP.optim_state_dict_to_load( 1722 model, optim, original_fsdp_optim_state_dict 1723 ) 1724 # manually add param_groups to state_dict_to_load before loading the optimizer state 1725 state_dict_to_load["param_groups"] = original_param_groups 1726 optim.load_state_dict(state_dict_to_load) 1727 self.assertEqual(original_osd, optim.state_dict()) 1728 1729 fsdp_optim_state = FSDP.optim_state_dict(model, optim) 1730 self._check_same_state( 1731 original_fsdp_optim_state_dict, fsdp_optim_state, check_same_param_keys=True 1732 ) 1733 self.assertEqual(original_param_groups, optim.state_dict()["param_groups"]) 1734 1735 @skip_if_lt_x_gpu(2) 1736 def test_with_empty_optimizer_state(self): 1737 model = FSDP(TestDummyModel().cuda()) 1738 optim = torch.optim.Adam(model.parameters(), lr=1e-2) 1739 state_dict = optim.state_dict() 1740 gathered_state_dict = FSDP.optim_state_dict(model, optim) 1741 self.assertEqual(gathered_state_dict["state"], state_dict["state"]) 1742 1743 def _test_load_optim_state_with_optim_state_dict( 1744 self, 1745 model_class: _ModelClass, 1746 state_dict_settings: StateDictSettings, 1747 use_multiple_param_groups: bool, 1748 halve_world_size: bool, 1749 use_diff_optim_inputs: bool, 1750 num_iters: int, 1751 **new_model_kwargs, 1752 ): 1753 """ 1754 (1) Runs a model with full world size for K iterations to generate a 1755 full/sharded optimizer state dict; 1756 (2) initializes a model with halved world size and possibly different 1757 FSDP wrapping scheme (based on ``new_model_kwargs``); 1758 (3) loads the full/sharded optimizer state dict from (1) according to the 1759 halved-world-size model; 1760 (4) runs the halved-world-size model for K iterations; and 1761 (5) checks that the sharded optimizer state dict from (3) matches the 1762 halved-world-size model's local optimizer state dict, meaning that the 1763 former could have equivalently been loaded into the local optimizer. 1764 """ 1765 initializer = self._model_class[model_class] 1766 1767 # First, run a wrapped model with full world size for a few iterations 1768 model1, optim1, optim_input1 = initializer( 1769 wrap=True, 1770 use_multiple_param_groups=use_multiple_param_groups, 1771 ) 1772 FSDP.set_state_dict_type( 1773 model1, 1774 state_dict_settings.state_dict_type, 1775 state_dict_settings.state_dict_config, 1776 state_dict_settings.optim_state_dict_config, 1777 ) 1778 self._step_model(model1, optim1, num_iters=num_iters) 1779 fsdp_osd1 = FSDP.optim_state_dict(model1, optim1) 1780 if halve_world_size: 1781 # Create a new process group with halved world size 1782 new_group_ranks = [r for r in range(self.world_size) if r % 2 == 0] 1783 new_group = dist.new_group(ranks=new_group_ranks) 1784 if self.rank not in new_group_ranks: 1785 return 1786 else: 1787 # Continue using the same group and hence world size 1788 new_group = dist.distributed_c10d._get_default_group() 1789 # Second, run a wrapped model with (possibly) halved world size and 1790 # (possibly) differing `optim_input` across ranks 1791 model2, optim2, optim_input2 = initializer( 1792 wrap=True, 1793 group=new_group, 1794 use_multiple_param_groups=use_multiple_param_groups, 1795 use_diff_optim_inputs=use_diff_optim_inputs, 1796 **new_model_kwargs, # specify `wrap_alt` to change wrapping 1797 ) 1798 FSDP.set_state_dict_type( 1799 model2, 1800 state_dict_settings.state_dict_type, 1801 state_dict_settings.state_dict_config, 1802 state_dict_settings.optim_state_dict_config, 1803 ) 1804 self._step_model(model2, optim2, num_iters=num_iters) 1805 fsdp_osd2 = FSDP.optim_state_dict(model2, optim2, group=new_group) 1806 # Compute two sharded optim state dicts: (1) for the first model 1807 # according to the second model and (2) for the second model according 1808 # to the second model 1809 sharded_osd2 = FSDP.optim_state_dict_to_load( 1810 model2, optim2, fsdp_osd2, group=new_group 1811 ) 1812 1813 # As a sanity check, check that sharding the second model's full/sharded 1814 # optimizer state dict according to itself is equivalent to its local 1815 # optimizer's state dict 1816 local_osd2 = optim2.state_dict() 1817 self._check_same_param_groups( 1818 sharded_osd2, 1819 local_osd2, 1820 check_same_param_keys=True, 1821 ) 1822 self._check_same_state( 1823 sharded_osd2, 1824 local_osd2, 1825 check_same_param_keys=True, 1826 ) 1827 # Check that sharding the first model's full/sharded optimizer state dict 1828 # according to the second model is equivalent to the second model's 1829 # local optimizer state dict 1830 sharded_osd1 = FSDP.optim_state_dict_to_load( 1831 model2, optim2, fsdp_osd1, group=new_group 1832 ) 1833 self._check_same_param_groups( 1834 sharded_osd1, 1835 local_osd2, 1836 check_same_param_keys=True, 1837 ) 1838 self._check_same_state( 1839 sharded_osd1, 1840 local_osd2, 1841 check_same_param_keys=True, 1842 ) 1843 # As a sanity check, check that we can load and run a few iterations 1844 optim2.load_state_dict(sharded_osd2) 1845 self._step_model(model2, optim2, num_iters=num_iters) 1846 1847 @skip_if_lt_x_gpu(2) 1848 def test_interface_arguments(self): 1849 model = FSDP(TestDummyModel().cuda()) 1850 optim = torch.optim.Adam(model.parameters(), lr=1e-2) 1851 1852 def step(): 1853 loss = model(model.get_input()) 1854 loss.backward(loss) 1855 optim.step() 1856 1857 step() 1858 original_osd = deepcopy(optim.state_dict()) 1859 osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd) 1860 self._check_same_state( 1861 FSDP.optim_state_dict(model, optim), osd, check_same_param_keys=True 1862 ) 1863 step() 1864 osd_to_load = FSDP.optim_state_dict_to_load( 1865 model, optim, osd, load_directly=True 1866 ) 1867 self._check_same_state( 1868 optim.state_dict(), original_osd, check_same_param_keys=True 1869 ) 1870 1871 # Test the default setting. 1872 osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd) 1873 for state in osd["state"].values(): 1874 for s in state.values(): 1875 self.assertFalse(isinstance(s, ShardedTensor)) 1876 self.assertFalse(s.is_cuda) 1877 1878 # Test sharded state_dict without offload_to_cpu 1879 with FSDP.state_dict_type( 1880 model, 1881 StateDictType.SHARDED_STATE_DICT, 1882 ShardedStateDictConfig(), 1883 ShardedOptimStateDictConfig(offload_to_cpu=False), 1884 ): 1885 osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd) 1886 for state in osd["state"].values(): 1887 for s in state.values(): 1888 if s.dim() == 0: 1889 continue 1890 self.assertTrue(isinstance(s, ShardedTensor)) 1891 if s._local_shards[0]: 1892 self.assertTrue(s._local_shards[0].tensor.is_cuda) 1893 1894 # Test full state_dict with rank0_only 1895 with FSDP.state_dict_type( 1896 model, 1897 StateDictType.FULL_STATE_DICT, 1898 FullStateDictConfig(), 1899 FullOptimStateDictConfig( 1900 offload_to_cpu=True, 1901 rank0_only=True, 1902 ), 1903 ): 1904 osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd) 1905 if dist.get_rank() > 0: 1906 self.assertEqual(osd, {}) 1907 else: 1908 for state in osd["state"].values(): 1909 for s in state.values(): 1910 if s.dim() == 0: 1911 continue 1912 self.assertFalse(s.is_cuda) 1913 self.assertFalse(isinstance(s, ShardedTensor)) 1914 1915 @skip_if_lt_x_gpu(2) 1916 def test_state_dict_with_none_tensor_state(self): 1917 def _run_test(use_orig_params, optimizer_has_tensor_state): 1918 model = FSDP(TestDummyModel().cuda(), use_orig_params=use_orig_params) 1919 optimizer_cls = ( 1920 torch.optim.Adam if optimizer_has_tensor_state else torch.optim.SGD 1921 ) 1922 optim = optimizer_cls(model.parameters(), lr=1e-2) 1923 1924 def step(): 1925 loss = model(model.get_input()) 1926 loss.backward(loss) 1927 optim.step() 1928 1929 step() 1930 original_osd = deepcopy(optim.state_dict()) 1931 for state in original_osd["state"].values(): 1932 # Add customized value 1933 state["value1"] = 2.74 1934 state["value2"] = None 1935 1936 osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd) 1937 osd_to_load = FSDP.optim_state_dict_to_load(model, optim, osd) 1938 for state in osd_to_load["state"].values(): 1939 self.assertEqual(state["value1"], 2.74) 1940 self.assertEqual(state["value2"], None) 1941 1942 self.run_subtests( 1943 { 1944 "use_orig_params": [False, True], 1945 "optimizer_has_tensor_state": [False, True], 1946 }, 1947 _run_test, 1948 ) 1949 1950 @skip_if_lt_x_gpu(2) 1951 def test_with_no_shard(self): 1952 def _run_test(use_orig_params: bool) -> None: 1953 model = FSDP( 1954 TestDummyModel().cuda(), 1955 sharding_strategy=ShardingStrategy.NO_SHARD, 1956 use_orig_params=use_orig_params, 1957 ) 1958 optim = torch.optim.Adam(model.parameters(), lr=1e-2) 1959 1960 def step(): 1961 loss = model(model.get_input()) 1962 loss.backward(loss) 1963 optim.step() 1964 1965 step() 1966 1967 original_osd = deepcopy(optim.state_dict()) 1968 1969 osd = FSDP.optim_state_dict(model, optim) 1970 osd_to_load = FSDP.optim_state_dict_to_load(model, optim, osd) 1971 optim.load_state_dict(osd_to_load) 1972 1973 new_osd = optim.state_dict() 1974 1975 self.assertEqual(original_osd, new_osd) 1976 1977 self.run_subtests({"use_orig_params": [False, True]}, _run_test) 1978 1979 @skip_if_lt_x_gpu(2) 1980 def test_no_grad(self): 1981 model = TestDummyModel(no_grad=True).cuda() 1982 fsdp_model = FSDP(deepcopy(model), use_orig_params=True) 1983 fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2) 1984 1985 for i in range(5): 1986 if i % 2 == 1: 1987 fsdp_model.net1[0].weight.requires_grad = True 1988 fsdp_model.net1[0].bias.requires_grad = True 1989 else: 1990 fsdp_model.net1[0].weight.requires_grad = False 1991 fsdp_model.net1[0].bias.requires_grad = False 1992 batch = fsdp_model.get_input() 1993 loss = fsdp_model(batch).sum() 1994 loss.backward() 1995 fsdp_optim.step() 1996 orig_state_dict = deepcopy(fsdp_optim.state_dict()) 1997 optim_state_dict = FSDP.optim_state_dict(fsdp_model, fsdp_optim) 1998 FSDP.optim_state_dict_to_load( 1999 fsdp_model, 2000 fsdp_optim, 2001 FSDP.optim_state_dict(fsdp_model, fsdp_optim), 2002 load_directly=True, 2003 ) 2004 2005 self._check_same_state( 2006 fsdp_optim.state_dict(), 2007 orig_state_dict, 2008 check_same_param_keys=True, 2009 ) 2010 2011 2012instantiate_parametrized_tests(TestFSDPOptimState) 2013 2014if __name__ == "__main__": 2015 run_tests() 2016