# Owner(s): ["oncall: distributed"]

import copy
import functools
import itertools
import os
import sys
import unittest
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
    StateDictType,
)
from torch.distributed.fsdp._common_utils import clean_tensor_name
from torch.distributed.fsdp._flat_param import (
    _FSDP_SKIP_WRITEBACK_CHECK,
    _FSDP_USE_FULL_PREC_IN_EVAL,
)
from torch.distributed.fsdp._init_utils import NO_RESHARD_AFTER_FORWARD_STRATEGIES
from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)
from torch.utils._triton import has_triton


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
    """Tests multiple parameter groups."""

    @property
    def world_size(self) -> int:
        return 2

    def _get_param_groups(self, model: nn.Module) -> List[Dict[str, Any]]:
        """
        Constructs separate parameter groups for weights, biases, and other
        parameters.
        """
        param_groups = [
            {"params": [], "weight_decay": 0.1, "lr": 1e-2},
            {"params": [], "weight_decay": 0.01, "lr": 1e-3},
            {"params": []},
        ]
        for param_name, param in model.named_parameters():
            if "weight" in param_name:
                param_groups[0]["params"].append(param)
            elif "bias" in param_name:
                param_groups[1]["params"].append(param)
            else:
                param_groups[2]["params"].append(param)
        return param_groups

    def _get_optim(
        self,
        model: nn.Module,
        optim_class: Type[torch.optim.Optimizer],
        multi_tensor: bool,
    ) -> torch.optim.Optimizer:
        """
        Constructs an Adam optimizer with three parameter groups, one for
        weights, one for biases, and one for everything else, each with
        different weight decay and learning rates.
        """
        param_groups = self._get_param_groups(model)
        return optim_class(param_groups, lr=5e-3, foreach=multi_tensor)

    def _get_ddp_transformer(self, find_unused_params: bool) -> DDP:
        """Returns a transformer with shared parameters wrapped with DDP."""
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        ddp_model = DDP(
            model,
            device_ids=[self.rank],
            find_unused_parameters=find_unused_params,
        )
        return ddp_model

    def _get_fsdp_transformer_and_optim(
        self,
        cuda_init_mode: CUDAInitMode,
        init_optim_before_wrap: bool,
        optim_class: Type[torch.optim.Optimizer],
        multi_tensor: bool,
        sharding_strategy: ShardingStrategy,
        backward_prefetch: Optional[BackwardPrefetch],
        cpu_offload: CPUOffload,
    ) -> Tuple[FSDP, torch.optim.Optimizer]:
        """
        Returns a transformer with shared parameters wrapped with FSDP and a
        corresponding optimizer.
        """
        # Each transformer layer has multiple linear layers, so this policy, in
        # combination with the parameter group construction, ensures different
        # hyperparameter settings within one `FlatParameter`
        fsdp_kwargs = {
            "auto_wrap_policy": ModuleWrapPolicy(
                {
                    TransformerEncoderLayer,
                    TransformerDecoderLayer,
                }
            ),
            "use_orig_params": True,
            "sharding_strategy": sharding_strategy,
            "backward_prefetch": backward_prefetch,
            "cpu_offload": cpu_offload,
        }
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            cuda_init_mode,
            deterministic=True,
        )
        if init_optim_before_wrap:
            fsdp_optim = self._get_optim(model, optim_class, multi_tensor)
            fsdp_model = FSDP(model, self.process_group, **fsdp_kwargs)
        else:
            fsdp_model = FSDP(model, self.process_group, **fsdp_kwargs)
            fsdp_optim = self._get_optim(fsdp_model, optim_class, multi_tensor)
        if (
            cuda_init_mode == CUDAInitMode.CUDA_AFTER
            and not fsdp_model.cpu_offload.offload_params
        ):
            fsdp_model = fsdp_model.cuda()
        return fsdp_model, fsdp_optim

    def _check_train_parity(
        self,
        ddp_model: DDP,
        ddp_optim: torch.optim.Optimizer,
        fsdp_model: FSDP,
        fsdp_optim: torch.optim.Optimizer,
        set_to_none: bool,
        num_iters: int = 10,
    ):
        """Checks training parity between DDP and FSDP."""
        device = torch.device("cuda")
        for i in range(num_iters):
            iter_losses = []
            for model, optim in ((ddp_model, ddp_optim), (fsdp_model, fsdp_optim)):
                module = model.module
                # Test two different `zero_grad()` timings
                if i % 2 == 0:
                    optim.zero_grad(set_to_none=set_to_none)  # pre-forward
                inp = module.get_input(device)
                output = model(*inp)
                loss = module.get_loss(inp, output).to(device)
                iter_losses.append(loss)
                if i % 2 == 1:
                    optim.zero_grad(set_to_none=set_to_none)  # pre-backward
                module.run_backward(loss)
                # Perform the DDP optimizer step on CPU to match FSDP if needed
                if model is ddp_model and fsdp_model.cpu_offload.offload_params:
                    model.to(torch.device("cpu"))
                optim.step()
                if model is ddp_model and fsdp_model.cpu_offload.offload_params:
                    model.to(device)
            torch.testing.assert_close(iter_losses[0], iter_losses[1])
            iter_losses.clear()
        self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)

    def _check_ddp_fsdp_param_parity(self, ddp_model: DDP, fsdp_model: FSDP):
        with FSDP.summon_full_params(fsdp_model):
            for (n1, p1), (n2, p2) in zip(
                ddp_model.module.named_parameters(), fsdp_model.named_parameters()
            ):
                # Allow for FSDP prefixes
                self.assertEqual(n1, clean_tensor_name(n2))
                torch.testing.assert_close(p1, p2)

    def _get_sharding_strategy_from_str(
        self, sharding_strategy_str: str
    ) -> ShardingStrategy:
        if sharding_strategy_str == "no_shard":
            sharding_strategy = ShardingStrategy.NO_SHARD
        elif sharding_strategy_str == "shard_grad_op":
            sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
        elif sharding_strategy_str == "full_shard":
            sharding_strategy = ShardingStrategy.FULL_SHARD
        else:
            raise ValueError(f"Invalid string: {sharding_strategy_str}")
        return sharding_strategy

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_fsdp_compile(self):
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
                "skip_fsdp_guards": [True, False],
            },
            self._test_fsdp_compile,
        )

    def _test_fsdp_compile(
        self, sharding_strategy: ShardingStrategy, skip_fsdp_guards: bool
    ):
        torch._dynamo.config.skip_fsdp_guards = skip_fsdp_guards
        fsdp_kwargs = {
            "auto_wrap_policy": ModuleWrapPolicy(
                {
                    TransformerEncoderLayer,
                    TransformerDecoderLayer,
                }
            ),
            "use_orig_params": True,
            "sharding_strategy": sharding_strategy,
            "backward_prefetch": BackwardPrefetch.BACKWARD_PRE,
            "cpu_offload": CPUOffload(False),
        }
        base_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        ref_model = FSDP(copy.deepcopy(base_model), self.process_group, **fsdp_kwargs)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        model = FSDP(copy.deepcopy(base_model), self.process_group, **fsdp_kwargs)
        model = torch.compile(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        for i in range(10):
            losses = []
            inp = ref_model.get_input(torch.device("cuda"))
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad()
                loss = _model(*inp).sum()
                losses.append(loss)
                loss.backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy_str",
        ["no_shard", "shard_grad_op", "full_shard"],
    )
    def test_diff_hyperparams(self, sharding_strategy_str: str):
        """
        Tests FSDP parity with DDP when using multiple parameter groups with
        different hyperparameter settings.
        """
        sharding_strategy = self._get_sharding_strategy_from_str(sharding_strategy_str)
        self.run_subtests(
            {
                "cuda_init_mode": [
                    CUDAInitMode.CUDA_BEFORE,
                    CUDAInitMode.CUDA_AFTER,
                ],
                "init_optim_before_wrap": [False, True],
                "optim_class": [torch.optim.AdamW],
                "multi_tensor": [False, True],
                "set_to_none": [False, True],
                "backward_prefetch": [
                    None,
                    BackwardPrefetch.BACKWARD_PRE,
                    BackwardPrefetch.BACKWARD_POST,
                ],
                "skip_writeback_check": [False, True],
            },
            self._test_diff_hyperparams,
            cpu_offload=CPUOffload(offload_params=False),
            sharding_strategy=sharding_strategy,
        )

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy_str",
        ["no_shard", "shard_grad_op", "full_shard"],
    )
    def test_diff_hyperparams_cpu_offload(self, sharding_strategy_str: str):
        """
        Tests FSDP parity with DDP when using multiple parameter groups with
        different hyperparameter settings with CPU offloading enabled. This is
        separate from :meth:`test_diff_hyperparams` because CPU offloading has
        issues with some specific subtest configurations (e.g., running with
        ``offload_params=False`` followed by ``True`` but not vice versa).
        """
        sharding_strategy = self._get_sharding_strategy_from_str(sharding_strategy_str)
        for skip_writeback_check in (False, True):
            self._test_diff_hyperparams(
                cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
                init_optim_before_wrap=False,
                optim_class=torch.optim.Adam,
                multi_tensor=False,
                set_to_none=False,
                backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
                cpu_offload=CPUOffload(offload_params=True),
                sharding_strategy=sharding_strategy,
                skip_writeback_check=skip_writeback_check,
            )

    def _test_diff_hyperparams(
        self,
        cuda_init_mode: CUDAInitMode,
        init_optim_before_wrap: bool,
        optim_class: Type[torch.optim.Optimizer],
        multi_tensor: bool,
        set_to_none: bool,
        backward_prefetch: Optional[BackwardPrefetch],
        cpu_offload: CPUOffload,
        sharding_strategy: ShardingStrategy,
        skip_writeback_check: bool,
    ):
        """
        Args:
            init_optim_before_wrap (bool): If ``True``, initializes the
                FSDP optimizer before wrapping the model with FSDP; otherwise,
                initializes the FSDP optimizer after wrapping the model with
                FSDP. We permit both forms of initialization to give users
                flexibility.
        """
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER and cpu_offload.offload_params:
            return  # not supported
        if skip_writeback_check:
            os.environ[_FSDP_SKIP_WRITEBACK_CHECK] = "1"
        ddp_model = self._get_ddp_transformer(find_unused_params=False)
        ddp_optim = self._get_optim(ddp_model, optim_class, multi_tensor)
        fsdp_model, fsdp_optim = self._get_fsdp_transformer_and_optim(
            cuda_init_mode=cuda_init_mode,
            init_optim_before_wrap=init_optim_before_wrap,
            optim_class=optim_class,
            multi_tensor=multi_tensor,
            sharding_strategy=sharding_strategy,
            backward_prefetch=backward_prefetch,
            cpu_offload=cpu_offload,
        )
        self._check_train_parity(
            ddp_model, ddp_optim, fsdp_model, fsdp_optim, set_to_none
        )

    @skip_if_lt_x_gpu(2)
    def test_diff_trainability(self):
        """
        Tests FSDP parity with DDP when using multiple parameter groups and
        freezing the parameters in one parameter group.
        """
        self.run_subtests(
            {
                "multi_tensor": [False, True],
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
            },
            self._test_diff_trainability,
        )

    def _test_diff_trainability(
        self,
        multi_tensor: bool,
        sharding_strategy: ShardingStrategy,
    ):
        optim_class = torch.optim.Adam
        ddp_model = self._get_ddp_transformer(find_unused_params=True)
        ddp_optim = self._get_optim(ddp_model, optim_class, multi_tensor)
        fsdp_model, fsdp_optim = self._get_fsdp_transformer_and_optim(
            cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
            init_optim_before_wrap=False,
            optim_class=optim_class,
            multi_tensor=multi_tensor,
            sharding_strategy=sharding_strategy,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            cpu_offload=None,
        )
        # Freeze all biases (which happen to be in the same parameter group)
        for param_name, param in ddp_model.named_parameters():
            if "bias" in param_name:
                param.requires_grad_(False)
        for param_name, param in fsdp_model.named_parameters():
            if "bias" in param_name:
                param.requires_grad_(False)
        self._check_train_parity(ddp_model, ddp_optim, fsdp_model, fsdp_optim, False)

    @skip_if_lt_x_gpu(2)
    def test_multiple_optimizers(self):
        """
        Tests using two optimizers where only one sets gradients to ``None``.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                ]
            },
            self._test_multiple_optimizers,
        )

    def _test_multiple_optimizers(self, sharding_strategy: ShardingStrategy):
        ddp_model = self._get_ddp_transformer(find_unused_params=True)
        ddp_param_groups = self._get_param_groups(ddp_model)
        assert len(ddp_param_groups) == 3, f"{len(ddp_param_groups)}"
        (
            fsdp_model,
            _,
        ) = self._get_fsdp_transformer_and_optim(  # ignore returned optimizer
            cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
            init_optim_before_wrap=False,
            optim_class=torch.optim.Adam,  # ignored
            multi_tensor=False,  # ignored
            sharding_strategy=sharding_strategy,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            cpu_offload=None,
        )
        fsdp_param_groups = self._get_param_groups(fsdp_model)
        assert len(fsdp_param_groups) == 3, f"{len(fsdp_param_groups)}"
        ddp_optims = []
        fsdp_optims = []
        # For the transformer model, every parameter is either a weight or a
        # bias, so we only use the first two parameter groups. Moreover, we use
        # Adam and AdamW in particular since they both use bias correction
        # dependent on the step, which is incremented even if a parameter has a
        # zero gradient but not if the gradient is `None`. This is to test that
        # we are differentiating between a zero and `None` gradient correctly.
        optim_ctors = [
            functools.partial(torch.optim.Adam, lr=5e-3),
            functools.partial(torch.optim.AdamW, lr=1e-2),
        ]

        for optim_ctor, ddp_param_group, fsdp_param_group in zip(
            optim_ctors,
            ddp_param_groups[:2],
            fsdp_param_groups[:2],
        ):
            ddp_optims.append(optim_ctor(ddp_param_group["params"]))
            fsdp_optims.append(optim_ctor(fsdp_param_group["params"]))
        device = torch.device("cuda")

        # Check that there exists a `FlatParameter` that has both a weight and
        # a bias in this rank's shard
        has_both = False
        for fsdp_module in FSDP.fsdp_modules(fsdp_model):
            handle = fsdp_module._handle
            if not handle:
                continue
            flat_param = handle.flat_param
            assert flat_param._params is not None
            has_weight = False
            has_bias = False
            for param, fqn in zip(flat_param._params, flat_param._fqns):
                if "weight" in fqn and param.numel() > 0:
                    has_weight = True
                elif "bias" in fqn and param.numel() > 0:
                    has_bias = True
            has_both |= has_weight and has_bias
        assert has_both, (
            f"Rank {self.rank} does not have a `FlatParameter` with both a "
            "weight and a bias in its shard, meaning that this test is vacuous"
        )

        # Run one iteration to generate gradients
        def run_iter():
            iter_losses = []
            for model, optims in ((ddp_model, ddp_optims), (fsdp_model, fsdp_optims)):
                module = model.module
                inp = module.get_input(device)
                output = model(*inp)
                loss = module.get_loss(inp, output).to(device)
                iter_losses.append(loss)
                module.run_backward(loss)
                for optim in optims:
                    optim.step()
            torch.testing.assert_close(iter_losses[0], iter_losses[1])
            iter_losses.clear()
            self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)

        run_iter()

        # Only set the weights' gradients to None
        ddp_optims[0].zero_grad(set_to_none=True)
        fsdp_optims[0].zero_grad(set_to_none=True)
        inp = ddp_model.module.get_input(device)
        ddp_output = ddp_model(*inp)
        fsdp_output = fsdp_model(*inp)

        # Check that FSDP correctly exposes gradients even after forward
        # (namely, `None` for weights and non-`None` for biases)
        if sharding_strategy in NO_RESHARD_AFTER_FORWARD_STRATEGIES:
            # Skip the check since we do not expose the gradients after forward
            # for these strategies
            return
        for (ddp_n, ddp_p), (fsdp_n, fsdp_p) in zip(
            ddp_model.module.named_parameters(),
            fsdp_model.named_parameters(),
        ):
            self.assertEqual(ddp_n, clean_tensor_name(fsdp_n))
            if fsdp_p.numel() == 0:
                # Not in this rank's shard
                self.assertTrue(fsdp_p.grad is None)
                continue
            if ddp_p.grad is None:
                self.assertTrue(fsdp_p.grad is None)
            else:
                # The FSDP parameter is a (flattened) shard here, so only check
                # that the gradient is exposed (non-`None`), matching the
                # comment above; a direct value comparison against the full
                # DDP parameter/gradient would shape-mismatch
                self.assertTrue(fsdp_p.grad is not None)
        self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)

        # Finish the iteration (backward pass and optimizer step)
        ddp_loss = ddp_model.module.get_loss(inp, ddp_output).to(device)
        fsdp_loss = fsdp_model.module.get_loss(inp, fsdp_output).to(device)
        ddp_model.module.run_backward(ddp_loss)
        fsdp_model.module.run_backward(fsdp_loss)
        for optim in itertools.chain(ddp_optims, fsdp_optims):
            optim.step()
        self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)

        # Run one more iteration to confirm bias corrections are correct
        run_iter()
        self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)


class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
    """Tests the unshard/reshard flow."""

    @property
    def world_size(self) -> int:
        return 2

    def _get_fsdp_models_and_optims(
        self,
        sharding_strategy: ShardingStrategy,
        cpu_offload: CPUOffload,
    ) -> Tuple[FSDP, torch.optim.Optimizer, FSDP, torch.optim.Optimizer]:
        """
        Returns one (FSDP model, optimizer) pair for ``use_orig_params=False``
        and one for ``True``, in that order.
        """
        LR = 1e-2
        fsdp_kwargs = {
            "sharding_strategy": sharding_strategy,
            "cpu_offload": cpu_offload,
            "use_orig_params": False,
        }
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs=fsdp_kwargs,
            deterministic=True,
        )
        optim = torch.optim.Adam(fsdp_model.parameters(), foreach=False, lr=LR)
        fsdp_kwargs["use_orig_params"] = True
        fsdp_model_orig_params = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs=fsdp_kwargs,
            deterministic=True,
        )
        optim_orig_params = torch.optim.Adam(
            fsdp_model_orig_params.parameters(), foreach=False, lr=LR
        )
        return fsdp_model, optim, fsdp_model_orig_params, optim_orig_params

    def _check_fsdp_parameter_parity(self, fsdp1: FSDP, fsdp2: FSDP) -> None:
        """Checks that two FSDP instances have the same model parameters."""
        with FSDP.summon_full_params(fsdp1), FSDP.summon_full_params(fsdp2):
            for (n1, p1), (n2, p2) in zip(
                fsdp1.named_parameters(),
                fsdp2.named_parameters(),
            ):
                self.assertEqual(n1, n2)
                torch.testing.assert_close(p1, p2)

    def _get_fsdp_parity_subtest_config(self):
        return {
            "sharding_strategy": [
                ShardingStrategy.NO_SHARD,
                ShardingStrategy.SHARD_GRAD_OP,
                ShardingStrategy.FULL_SHARD,
            ],
        }

    @skip_if_lt_x_gpu(2)
    @parametrize("offload_params", [False, True])
    def test_multiple_forward(self, offload_params: bool):
        """
        Tests that ``use_orig_params=True`` has parity with ``False`` when
        running multiple forward passes before a backward pass.
        """
        cpu_offload = CPUOffload(offload_params=offload_params)
        self.run_subtests(
            self._get_fsdp_parity_subtest_config(),
            self._test_multiple_forward,
            cpu_offload=cpu_offload,
        )

    @skip_if_lt_x_gpu(2)
    def _test_multiple_forward(
        self,
        sharding_strategy: ShardingStrategy,
        cpu_offload: CPUOffload,
    ):
        (
            fsdp_model,
            optim,
            fsdp_model_orig_params,
            optim_orig_params,
        ) = self._get_fsdp_models_and_optims(sharding_strategy, cpu_offload)
        device = torch.device("cuda")
        for _ in range(3):
            inp1 = fsdp_model.get_input(device)
            _inp2 = fsdp_model.get_input(device)
            inp2 = tuple(
                t + torch.ones_like(t) for t in _inp2
            )  # make different from `inp1`
            # For these loss lists: elem 0 is baseline; elem 1 is test
            losses1 = []
            losses2 = []
            losses = []
            for _model, _optim in (fsdp_model, optim), (
                fsdp_model_orig_params,
                optim_orig_params,
            ):
                _optim.zero_grad()
                loss1 = _model(*inp1)
                losses1.append(loss1)
                loss2 = _model(*inp2)
                losses2.append(loss2)
                loss = (loss1 + loss2).sum()
                losses.append(loss)
                _model.run_backward(loss)
                _optim.step()
            self.assertEqual(losses1[0], losses1[1])
            self.assertEqual(losses2[0], losses2[1])
            self.assertEqual(losses[0], losses[1])
        self._check_fsdp_parameter_parity(fsdp_model, fsdp_model_orig_params)

    @skip_if_lt_x_gpu(2)
    @parametrize("offload_params", [False, True])
    def test_summon_between_two_forwards(self, offload_params: bool):
        """
        Tests that ``use_orig_params=True`` has parity with ``False`` when
        running a forward pass, :meth:`summon_full_params()`, and another
        forward pass before a backward pass.
        """
        cpu_offload = CPUOffload(offload_params=offload_params)
        self.run_subtests(
            self._get_fsdp_parity_subtest_config(),
            self._test_summon_between_two_forwards,
            cpu_offload=cpu_offload,
        )

    def _test_summon_between_two_forwards(
        self,
        sharding_strategy: ShardingStrategy,
        cpu_offload: CPUOffload,
    ):
        (
            fsdp_model,
            optim,
            fsdp_model_orig_params,
            optim_orig_params,
        ) = self._get_fsdp_models_and_optims(sharding_strategy, cpu_offload)
        device = torch.device("cuda")
        for _ in range(3):
            optim.zero_grad()
            optim_orig_params.zero_grad()

            inp1 = fsdp_model.get_input(device)
            loss1 = fsdp_model(*inp1)
            loss_orig_params1 = fsdp_model_orig_params(*inp1)
            self.assertEqual(loss1, loss_orig_params1)

            # Calls into `summon_full_params()`
            self._check_fsdp_parameter_parity(fsdp_model, fsdp_model_orig_params)

            inp2 = fsdp_model.get_input(device)
            loss2 = fsdp_model(*inp2)
            loss_orig_params2 = fsdp_model_orig_params(*inp2)
            self.assertEqual(loss2, loss_orig_params2)

            loss = (loss1 + loss2).sum()
            loss_orig_params = (loss_orig_params1 + loss_orig_params2).sum()
            fsdp_model.run_backward(loss)
            fsdp_model_orig_params.run_backward(loss_orig_params)
            optim.step()
            optim_orig_params.step()
        self._check_fsdp_parameter_parity(fsdp_model, fsdp_model_orig_params)


class TestFSDPUseOrigParamsParamAccess(FSDPTest):
    """Tests original parameter access."""

    @property
    def world_size(self):
        # Force a world size of 2 since the tests hard code to the FSDP
        # sharding strategy to check sharded parameter parity
        return 2

    @skip_if_lt_x_gpu(2)
    def test_access_params_after_forward(self):
        """
        Tests accessing the original parameters after the forward but
        before the backward. Notably, this is not supported when
        ``use_orig_params=False``. However, for ``True``, FSDP exposes the
        (flattened) sharded original parameters, making it possible.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.NO_SHARD,
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                ],
            },
            self._test_access_params_after_forward,
        )

    def _test_access_params_after_forward(
        self,
        sharding_strategy: ShardingStrategy,
    ):
        # NOTE: This test needs to be changed if the FSDP sharding algorithm
        # changes. It is still valuable until such a change to sanity check the
        # `use_orig_params=True` implementation.
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                torch.manual_seed(42)
                # 5 * 5 = 25 numel -> pad to 26 -> 13 on each rank
                self.lin1 = nn.Linear(5, 5, bias=False)
                # 5 * 7 + (1) + 7 = 43 numel -> pad to 44 -> 22 on each rank,
                # where the (1) is from intra-`FlatParameter` alignment padding
                # 22 of weight on rank 0; 13 of weight, 1 alignment padding,
                # and 7 of bias on rank 1
                self.lin2 = nn.Linear(5, 7)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                z = self.lin1(x)
                z = nn.functional.relu(z)
                z = self.lin2(z)
                return z

            def get_input(self, device: torch.device) -> Tuple[torch.Tensor, ...]:
                return (torch.randn((2, 5)).to(device),)

            def get_loss(self, inp, out):
                return out.sum()

        def check_parameter_parity(
            ddp_model: DDP, fsdp_model: FSDP, between_fwd_and_bwd: bool
        ):
            assert self.rank in (
                0,
                1,
            ), f"Expects world size of 2 but got {self.world_size}"
            for (n1, p1), (n2, p2) in zip(
                ddp_model.module.named_parameters(),
                fsdp_model.named_parameters(),
            ):
                self.assertEqual(n1, clean_tensor_name(n2))
                if sharding_strategy == ShardingStrategy.NO_SHARD:
                    # For `NO_SHARD`, do nothing since the original parameters
                    # are unflattened
                    pass
                elif (
                    between_fwd_and_bwd
                    and sharding_strategy in NO_RESHARD_AFTER_FORWARD_STRATEGIES
                ):
                    # For no reshard after forward strategies, do nothing since
                    # FSDP did not use sharded views after forward
                    pass
                # Otherwise, case on the parameter (see the model definition)
                elif n1 == "lin1.weight":
                    if self.rank == 0:
                        p1 = p1.flatten()[:13]
                    elif self.rank == 1:
                        p1 = p1.flatten()[13:]
                elif n1 == "lin2.weight":
                    if self.rank == 0:
                        p1 = p1.flatten()[:22]
                    elif self.rank == 1:
                        p1 = p1.flatten()[22:]
                elif n1 == "lin2.bias":
                    if self.rank == 0:
                        p1 = torch.empty(0, device=p1.device)
                    elif self.rank == 1:
                        p1 = p1.flatten()
                torch.testing.assert_close(p1, p2)

        ddp_model = DDP(Model().cuda(), device_ids=[self.rank])
        fsdp_model = FSDP(
            Model().cuda(),
            sharding_strategy=sharding_strategy,
            auto_wrap_policy=always_wrap_policy,
            use_orig_params=True,
        )
        LR = 1e-2
        ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR)
        device = torch.device("cuda")

        inp = fsdp_model.get_input(device)
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        check_parameter_parity(ddp_model, fsdp_model, True)

        ddp_loss = ddp_model.module.get_loss(inp, ddp_out)
        fsdp_loss = fsdp_model.get_loss(inp, fsdp_out)
        ddp_loss.backward()
        fsdp_loss.backward()
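        # Step both optimizers and verify parameter parity again now that the
        # backward pass has completed (i.e. not between forward and backward)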
        ddp_optim.step()
        fsdp_optim.step()
        check_parameter_parity(ddp_model, fsdp_model, False)

        inp = fsdp_model.get_input(device)
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        check_parameter_parity(ddp_model, fsdp_model, True)


class TestFSDPUseOrigParamsWriteback(FSDPTest):
    """Tests parameter and gradient writeback."""

    class Model(nn.Module):
        def __init__(self, device: torch.device):
            super().__init__()
            torch.manual_seed(42)
            self.lin1 = nn.Linear(5, 5, bias=True, device=device)
            self.lin2 = nn.Linear(5, 7, bias=True, device=device)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            z = self.lin1(x)
            z = nn.functional.relu(z)
            z = self.lin2(z)
            return z

        def get_input(self, device: torch.device) -> Tuple[torch.Tensor, ...]:
            return (torch.randn((2, 5)).to(device),)

        def get_loss(self, inp, out):
            return out.sum()

    @property
    def world_size(self):
        # Force a world size of 2 since the tests hard code to the FSDP
        # sharding strategy
        return 2

    def _check_param_parity(self, ddp_model: DDP, fsdp_model: FSDP):
        with FSDP.summon_full_params(fsdp_model):
            for (n1, p1), (n2, p2) in zip(
                ddp_model.module.named_parameters(),
                fsdp_model.named_parameters(),
            ):
                self.assertEqual(n1, n2)
                torch.testing.assert_close(p1, p2)

    @skip_if_lt_x_gpu(2)
    def test_param_writeback(self):
        """Tests that changes to the original parameters are written back."""
        self.run_subtests(
            {
                "change_first_weight": [True, False],  # first vs. second `weight`
                "change_data": [True, False],  # change `.data` vs. variable itself
            },
            self._test_param_writeback,
        )

    def _test_param_writeback(self, change_first_weight: bool, change_data: bool):
        def transform_param(param: nn.Parameter) -> nn.Parameter:
            return nn.Parameter(torch.ones_like(param) * 2)

        # Check that the writeback propagates
        ddp_model = DDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            device_ids=[self.rank],
        )
        fsdp_model = FSDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            use_orig_params=True,
        )
        ddp = ddp_model.module  # for brevity
        fsdp = fsdp_model.module
        if change_first_weight:
            if change_data:
                ddp.lin1.weight.data = transform_param(ddp.lin1.weight)
                fsdp.lin1.weight.data = transform_param(fsdp.lin1.weight)
            else:
                ddp.lin1.weight = transform_param(ddp.lin1.weight)
                fsdp.lin1.weight = transform_param(fsdp.lin1.weight)
        else:
            if change_data:
                ddp.lin2.weight.data = transform_param(ddp.lin2.weight)
                fsdp.lin2.weight.data = transform_param(fsdp.lin2.weight)
            else:
                ddp.lin2.weight = transform_param(ddp.lin2.weight)
                fsdp.lin2.weight = transform_param(fsdp.lin2.weight)
        self._check_param_parity(ddp_model, fsdp_model)  # triggers a writeback

    @skip_if_lt_x_gpu(2)
    def test_grad_writeback(self):
        """
        Tests that changes to the original parameters' gradients are written
        back.
        """
        self.run_subtests(
            {
                "change_first_weight_grad": [False, True],
                "change_data": [False, True],  # change `.data` vs. variable itself
                "set_to_none": [False, True],
            },
            self._test_grad_writeback,
        )

    def _test_grad_writeback(
        self,
        change_first_weight_grad: bool,
        change_data: bool,
        set_to_none: bool,
    ):
        if change_data and set_to_none:
            return  # not well-defined

        def transform_grad(param: nn.Parameter) -> Optional[torch.Tensor]:
            return None if set_to_none else torch.ones_like(param) * 2

        ddp_model = DDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            device_ids=[self.rank],
        )
        fsdp_model = FSDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            use_orig_params=True,
        )
        LR = 1e-2
        # TODO: If we add `summon_full_params(with_grads=True)`, then replace
        # the following. For now, we use the optimizer step as a surrogate for
        # checking that gradients were written back.
        ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR)

        # Generate an initial gradient
        inp = fsdp_model.get_input(torch.device("cuda"))
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        ddp_out.sum().backward()
        fsdp_out.sum().backward()

        # Change the gradient through the original parameters
        ddp = ddp_model.module  # for brevity
        fsdp = fsdp_model.module
        if change_first_weight_grad:
            if change_data:
                ddp.lin1.weight.grad.data = transform_grad(ddp.lin1.weight)
                if fsdp.lin1.weight.grad is not None:
                    fsdp.lin1.weight.grad.data = transform_grad(fsdp.lin1.weight)
            else:
                ddp.lin1.weight.grad = transform_grad(ddp.lin1.weight)
                fsdp.lin1.weight.grad = transform_grad(fsdp.lin1.weight)
        else:
            if change_data:
                ddp.lin2.weight.grad.data = transform_grad(ddp.lin2.weight)
                if fsdp.lin2.weight.grad is not None:
                    fsdp.lin2.weight.grad.data = transform_grad(fsdp.lin2.weight)
            else:
                ddp.lin2.weight.grad = transform_grad(ddp.lin2.weight)
                fsdp.lin2.weight.grad = transform_grad(fsdp.lin2.weight)
        ddp_optim.step()
        fsdp_optim.step()
        self._check_param_parity(ddp_model, fsdp_model)  # triggers a writeback

        # Intentionally do not zero the gradient to check writeback
        inp = fsdp_model.get_input(torch.device("cuda"))
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        ddp_out.sum().backward()
        fsdp_out.sum().backward()
        ddp_optim.step()
        fsdp_optim.step()
        self._check_param_parity(ddp_model, fsdp_model)  # triggers a writeback

    @skip_if_lt_x_gpu(2)
    def test_writeback_shape_mismatch(self):
        fsdp_model = FSDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            use_orig_params=True,
        )
        # Check that writing back with mismatched shape errors
        fsdp = fsdp_model.module  # for brevity
        assert self.rank in (0, 1), f"Expects world size of 2 but got {self.world_size}"
        with self.assertRaisesRegex(RuntimeError, "Cannot writeback"):
            # Change the gradient to a new one with 1 added to each dimension
            # to force a shape mismatch when writing back
            if self.rank == 0:
                # Change `lin1.weight.grad` since it exists on rank 0
                lin1_weight_shape = list(fsdp.lin1.weight.shape)
                for dim_index in range(len(lin1_weight_shape)):
                    lin1_weight_shape[dim_index] += 1
                fsdp.lin1.weight = nn.Parameter(
                    torch.randn(
                        torch.Size(lin1_weight_shape), device=fsdp.lin1.weight.device
                    )
                )
                fsdp.lin1.weight.grad = torch.randn(
                    torch.Size(lin1_weight_shape), device=fsdp.lin1.weight.device
                )
            elif self.rank == 1:
                # Change `lin2.weight.grad` since it exists (partially) on rank 1
                lin2_weight_shape = list(fsdp.lin2.weight.shape)
                for dim_index in range(len(lin2_weight_shape)):
                    lin2_weight_shape[dim_index] += 1
                fsdp.lin2.weight = nn.Parameter(
                    torch.randn(
                        torch.Size(lin2_weight_shape), device=fsdp.lin2.weight.device
                    )
                )
                fsdp.lin2.weight.grad = torch.randn(
                    torch.Size(lin2_weight_shape), device=fsdp.lin2.weight.device
                )
            with FSDP.summon_full_params(fsdp_model):  # triggers a writeback
                ...

    @skip_if_lt_x_gpu(2)
    def test_writeback_between_fwd_and_bwd_for_no_reshard_raises(self):
        fsdp_kwargs = {
            "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
            "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
            "use_orig_params": True,
        }
        fsdp_wrapper = functools.partial(FSDP, **fsdp_kwargs)

        # Test changing the parameter storage to no longer be a view into the
        # flat parameter
        fsdp_model = fsdp_wrapper(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda"))
        )
        inp = fsdp_model.get_input(torch.device("cuda"))
        loss = fsdp_model(*inp).sum()
        fsdp_model.lin1.weight.data = fsdp_model.lin1.weight.clone()
        assert_msg = (
            "FSDP does not support changing the parameters between forward and backward"
        )
        with self.assertRaisesRegex(AssertionError, assert_msg):
            loss.backward()

        # Test changing the parameter variable itself
        fsdp_model = fsdp_wrapper(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda"))
        )
        inp = fsdp_model.get_input(torch.device("cuda"))
        loss = fsdp_model(*inp).sum()
        fsdp_model.lin1._fsdp_wrapped_module.weight = nn.Parameter(
            fsdp_model.lin1.weight.clone()
        )
        with self.assertRaisesRegex(AssertionError, assert_msg):
            loss.backward()

    @skip_if_lt_x_gpu(2)
    def test_no_reshard_and_mixed_precision(self):
        """
        Tests that writeback does not falsely get triggered for a few
        configurations (exercising the sharded view skipping logic):
        - Train forward -> full-precision unshard -> train forward
        - Train forward -> eval forward
        - Train forward/backward -> eval forward -> model checkpoint
        """
        self.run_subtests(
            {"use_full_prec_in_eval": [False, True]},
            self._test_no_reshard_and_mixed_precision,
        )

    def _test_no_reshard_and_mixed_precision(self, use_full_prec_in_eval: bool):
        if use_full_prec_in_eval:
            os.environ[_FSDP_USE_FULL_PREC_IN_EVAL] = "1"
        fsdp_kwargs = {
            "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
            "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
            "mixed_precision": MixedPrecision(param_dtype=torch.float16),
            "use_orig_params": True,
        }

        # Train forward -> full-precision unshard -> train forward
        fsdp_model = FSDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")), **fsdp_kwargs
        )
        inp = fsdp_model.get_input(torch.device("cuda"))
        fsdp_model(*inp)
        with FSDP.summon_full_params(fsdp_model):
            ...
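        # Second train forward after the full-precision unshard above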
        fsdp_model(*inp).sum()

        # Train forward -> eval forward
        fsdp_model.train()
        fsdp_model(*inp)
        fsdp_model.eval()
        fsdp_model(*inp)

        # Train forward/backward -> eval forward -> model checkpoint
        fsdp_model.train()
        fsdp_model(*inp).sum().backward()
        fsdp_model.eval()
        fsdp_model(*inp)
        with FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT):
            sd = fsdp_model.state_dict()
            fsdp_model.load_state_dict(sd)
        fsdp_model(*inp).sum().backward()


class TestFSDPUseOrigParamsFQNs(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_named_parameters_in_forward(self):
        """
        Tests that calling ``named_parameters()`` during forward returns FQNs
        and ``Tensor`` s corresponding to the original parameters.
        """
        param_shapes = [None, None]
        assert_equal_fn = self.assertEqual

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(5, 5)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                nonlocal param_shapes
                # Allow for FSDP prefixes
                param_names = [
                    clean_tensor_name(tup[0]) for tup in self.named_parameters()
                ]
                params = [tup[1] for tup in self.named_parameters()]
                assert (
                    param_shapes[0] is not None and param_shapes[1] is not None
                ), "`param_sizes` should be set"
                assert_equal_fn(
                    param_names,
                    [
                        "lin.weight",
                        "lin.bias",
                    ],
                )
                assert_equal_fn(params[0].shape, param_shapes[0])
                assert_equal_fn(params[1].shape, param_shapes[1])
                return self.lin(x)

        model = Model().cuda()
        # Save the *unsharded* original parameter shapes and check the shapes
        # match in the forward pass
        param_shapes[0] = model.lin.weight.shape
        param_shapes[1] = model.lin.bias.shape
        fsdp_model = FSDP(model, use_orig_params=True)
        inp = torch.randn((2, 5), device=torch.device("cuda"))
        fsdp_model(inp)


class TestFSDPUseOrigParamsNoSync(FSDPTest):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(2)
    def test_no_sync_correctness(self):
        """
        Tests a basic ``no_sync()`` setup by comparing ``use_orig_params=True``
        against ``use_orig_params=False``.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
            },
            self._test_no_sync_correctness,
        )

    def _test_no_sync_correctness(self, sharding_strategy: ShardingStrategy):
        model = nn.Linear(7, 1, bias=False, device="cuda")
        fsdp_kwargs = {
            "sharding_strategy": sharding_strategy,
        }
        model_use_flat_params = FSDP(
            copy.deepcopy(model), use_orig_params=False, **fsdp_kwargs
        )
        model_use_orig_params = FSDP(model, use_orig_params=True, **fsdp_kwargs)
        optim_use_flat_params = torch.optim.AdamW(
            model_use_flat_params.parameters(), foreach=True
        )
        optim_use_orig_params = torch.optim.AdamW(
            model_use_orig_params.parameters(), foreach=True
        )

        def _check_param_grad_parity(
            _baseline_model: nn.Module,
            _test_model: nn.Module,
        ):
            """
            This assumes that the model is ``nn.Linear(7, 1, bias=False)``
            (i.e. with a single 1D weight parameter) to be able to directly
            compare the baseline and test models. On rank 1, the baseline
            includes 1 element of padding.
            """
            self.assertEqual(len(list(_baseline_model.parameters())), 1)
            self.assertEqual(len(list(_test_model.parameters())), 1)
            for flat_param, orig_param in zip(
                _baseline_model.parameters(), _test_model.parameters()
            ):
                # Baseline is permitted to have padding
                self.assertGreaterEqual(flat_param.numel(), orig_param.numel())
                unpadded_param_numel = orig_param.numel()
                # For `NO_SHARD`, `use_orig_params=True` presents unflattened
                # parameters, while `False` presents flattened ones
                torch.testing.assert_close(
                    flat_param[:unpadded_param_numel], orig_param.flatten()
                )
                # Gradient numel is different if right after `no_sync()` since
                # the gradient is unsharded, while the parameter is sharded
                unpadded_grad_numel = orig_param.grad.numel()
                # For `use_orig_params=False`, the unsharded gradient is
                # flattened, while for `True`, it is unflattened
                torch.testing.assert_close(
                    flat_param.grad[:unpadded_grad_numel].reshape(
                        orig_param.grad.shape
                    ),
                    orig_param.grad,
                )

        inp = torch.randn((2, 7), device="cuda")
        grad = torch.randn((2, 1), device="cuda")

        # Compute some reference gradients using one forward/backward
        out_use_flat_params = model_use_flat_params(inp)
        out_use_orig_params = model_use_orig_params(inp)
        torch.testing.assert_close(out_use_flat_params, out_use_orig_params)
        out_use_flat_params.backward(grad)
        out_use_orig_params.backward(grad)
        _check_param_grad_parity(model_use_flat_params, model_use_orig_params)
        ref_grads_use_flat_params = [
            param.grad.detach().clone() for param in model_use_flat_params.parameters()
        ]
        ref_grads_use_orig_params = [
            param.grad.detach().clone()
            for param in model_use_orig_params.parameters()
            if param.grad is not None
        ]

        # Run a forward/backward in `no_sync()`
        optim_use_flat_params.zero_grad(set_to_none=True)
        optim_use_orig_params.zero_grad(set_to_none=True)
        for model in (model_use_flat_params, model_use_orig_params):
            with model.no_sync():
                out = model(inp)
                out.backward(grad)
        _check_param_grad_parity(model_use_flat_params, model_use_orig_params)

        # Run a forward/backward outside `no_sync()`
        for model in (model_use_flat_params, model_use_orig_params):
            out = model(inp)
            out.backward(grad)
        _check_param_grad_parity(model_use_flat_params, model_use_orig_params)

        # Check that, since we accumulated gradients across 2 iterations, the
        # new gradients are 2x the reference gradients
        grads_use_flat_params = [
            param.grad.detach().clone() for param in model_use_flat_params.parameters()
        ]
        grads_use_orig_params = [
            param.grad.detach().clone()
            for param in model_use_orig_params.parameters()
            if param.grad is not None
        ]
        for grad, ref_grad in zip(grads_use_flat_params, ref_grads_use_flat_params):
            torch.testing.assert_close(grad, 2 * ref_grad)
        for grad, ref_grad in zip(grads_use_orig_params, ref_grads_use_orig_params):
            torch.testing.assert_close(grad, 2 * ref_grad)

    @skip_if_lt_x_gpu(2)
    def test_no_sync_mixed_precision(self):
        """
        Tests that dtypes are as expected when using ``no_sync()`` with
        ``use_orig_params=True`` and parameter mixed precision.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ]
            },
            self._test_no_sync_mixed_precision,
        )

    def _test_no_sync_mixed_precision(self, sharding_strategy: ShardingStrategy):
        model = nn.Linear(3, 3, device="cuda")
        mixed_precision = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float32,
        )
        fsdp_kwargs = {
            "sharding_strategy": sharding_strategy,
            "mixed_precision": mixed_precision,
            "use_orig_params": True,
        }
        fsdp_model = FSDP(model, **fsdp_kwargs)
        inp = torch.randn((2, 3), device="cuda")
        with fsdp_model.no_sync():
            # For each of these `no_sync()` backward passes, check that the
            # gradients are in the low precision parameter dtype (FP16)
            fsdp_model(inp).sum().backward()
            for param in fsdp_model.parameters():
                if param.grad is not None:
                    self.assertEqual(param.grad.dtype, torch.float16)
            fsdp_model(inp).sum().backward()
            for param in fsdp_model.parameters():
                if param.grad is not None:
                    self.assertEqual(param.grad.dtype, torch.float16)
        # For the backward pass outside `no_sync()`, check that the gradients
        # are cast to the full precision in preparation for the optimizer step
        fsdp_model(inp).sum().backward()
        for param in fsdp_model.parameters():
            if param.grad is not None:
                self.assertEqual(param.grad.dtype, torch.float32)


class TestFSDPUseOrigParamsInit(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_non_uniform_requires_grad(self):
        model = nn.Sequential(
            nn.Linear(3, 3, device="cuda"),
            nn.Linear(3, 3, device="cuda"),
        )
        # Freeze biases only and flatten both weights and biases into the same
        # `FlatParameter` to exercise non-uniform `requires_grad`
        model[0].bias.requires_grad = False
        model[1].bias.requires_grad = False
        fsdp_model = FSDP(model, use_orig_params=True)
        self.assertTrue(fsdp_model[0].weight.requires_grad)
        self.assertFalse(fsdp_model[0].bias.requires_grad)
        self.assertTrue(fsdp_model[1].weight.requires_grad)
        self.assertFalse(fsdp_model[1].bias.requires_grad)


# Define this to be large enough to trigger stack corruption
NUM_SIZE0_TENSORS = 1000


class TestMultiTensorApply(TestCase):
    def test_multi_tensor_apply_size0_tensors_cpu(self):
        size0_tensors = [torch.empty(0, device="cpu") for _ in range(NUM_SIZE0_TENSORS)]
        # Check that this does not segfault
        torch._foreach_mul_(size0_tensors, 0.1)

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_multi_tensor_apply_size0_tensors_cuda(self):
        size0_tensors = [
            torch.empty(0, device="cuda") for _ in range(NUM_SIZE0_TENSORS)
        ]
        # Check that this does not segfault
        torch._foreach_mul_(size0_tensors, 0.1)


instantiate_parametrized_tests(TestFSDPUseOrigParamsMultipleParamGroups)
instantiate_parametrized_tests(TestFSDPUseOrigParamsUnshardReshard)
instantiate_parametrized_tests(TestFSDPUseOrigParamsParamAccess)
instantiate_parametrized_tests(TestFSDPUseOrigParamsFQNs)
instantiate_parametrized_tests(TestFSDPUseOrigParamsNoSync)

if __name__ == "__main__":
    run_tests()