# Owner(s): ["oncall: distributed"]

import contextlib
import copy
import functools
import unittest
from typing import Iterable, List, Tuple, Type, Union

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._composable import checkpoint, replicate
from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    FSDPModule,
    fully_shard,
    OffloadPolicy,
    register_fsdp_forward_method,
)
from torch.distributed._tensor import DTensor, init_device_mesh
from torch.distributed._tensor.debug.comm_mode import CommDebugMode
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_PREFIX,
    apply_activation_checkpointing,
    CheckpointWrapper,
)
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    check_sharded_parity,
    FSDPTest,
    FSDPTestMultiThread,
    MLP,
    MLPStack,
    patch_all_gather,
    patch_reduce_scatter,
    test_compiled_fsdp,
)
from torch.testing._internal.common_utils import (
    get_cycles_per_ms,
    run_tests,
    skipIfRocm,
    wrapSwapTensorsTest,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
    TransformerBlock,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir

c10d_ops = torch.ops.c10d
funcol = torch.ops.c10d_functional


class TestFullyShardForwardInputs(FSDPTestMultiThread):
    @property
    def world_size(self) -> int:
        return 2

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_root_move_forward_input_to_device(self):
        device = torch.device("cuda", 0)

        class ParamlessModule(nn.Module):
            def forward(self, x: torch.Tensor, ys: Tuple[torch.Tensor, ...]):
                # Check that FSDP moved the inputs to GPU, including recursing
                # into the tuple data structure
                assert x.device == device, f"Expects {device} but got {x.device}"
                assert (
                    ys[0].device == device
                ), f"Expects {device} but got {ys[0].device}"
                assert (
                    ys[1].device == device
                ), f"Expects {device} but got {ys[1].device}"
                y = ys[0] + ys[1]
                return x + y + 1

        model = ParamlessModule()
        fully_shard(model)
        x = torch.randn((3,))
        ys = (torch.randn((3,)), torch.randn((3,)))
        self.assertEqual(x.device, torch.device("cpu"))
        self.assertEqual(ys[0].device, torch.device("cpu"))
        self.assertEqual(ys[1].device, torch.device("cpu"))
        model(x, ys)


class TestFullyShardRegisteredParams(FSDPTestMultiThread):
    @property
    def world_size(self) -> int:
        return 4

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_param_registration_after_forward(self):
        """Tests the parameter registration after forward."""
        device = torch.device("cuda", 0)
        # Single FSDP group
        for reshard_after_forward in (True, False, 2):
            torch.manual_seed(42)
            model = MLP(3, device)
            # Since seed is per process, not per thread, we broadcast to ensure
            # the same parameters across ranks
            for param in model.parameters():
                dist.broadcast(param, src=0)
            ref_model = copy.deepcopy(model)
            fully_shard(model, reshard_after_forward=reshard_after_forward)  # root only
            inp = torch.randn((2, 3), device="cuda")
            self._assert_dtensor_params(model.parameters())
            self._assert_same_params(model.parameters(), ref_model.parameters())
            model(inp)  # root does not reshard after forward
            self._assert_tensor_params(model.parameters())
            self._assert_same_params(model.parameters(), ref_model.parameters())
            model.reshard()  # however, we can manually reshard
            self._assert_dtensor_params(model.parameters())
            self._assert_same_params(model.parameters(), ref_model.parameters())

        # Multiple FSDP groups
        for reshard_after_forward in (True, False, 2):
            torch.manual_seed(42)
            model = nn.Sequential(MLP(3, device), MLP(3, device))
            for param in model.parameters():
                dist.broadcast(param, src=0)
            ref_model = copy.deepcopy(model)
            fully_shard(model[0].in_proj, reshard_after_forward=reshard_after_forward)
            fully_shard(model[0].out_proj, reshard_after_forward=reshard_after_forward)
            fully_shard(model, reshard_after_forward=reshard_after_forward)

            self._assert_dtensor_params(model.parameters())
            self._assert_same_params(model.parameters(), ref_model.parameters())
            model(inp)
            non_root_params = list(model[0].in_proj.parameters()) + list(
                model[0].out_proj.parameters()
            )
            root_params = list(set(model.parameters()) - set(non_root_params))
            if reshard_after_forward is False:
                self._assert_tensor_params(non_root_params)
            else:
                self._assert_dtensor_params(non_root_params)
            self._assert_tensor_params(root_params)
            self._assert_same_params(model.parameters(), ref_model.parameters())
            for module in model.modules():
                if isinstance(module, FSDPModule):
                    module.reshard()  # however, we can manually reshard
            self._assert_dtensor_params(model.parameters())
            self._assert_same_params(model.parameters(), ref_model.parameters())

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_param_registration_after_backward(self):
        """Tests the parameter registration after backward."""
        device = torch.device("cuda", 0)
        # Single FSDP group
        for reshard_after_forward in (True, False, 2):
            model = MLP(8, device)
            fully_shard(model, reshard_after_forward=reshard_after_forward)  # root only
            inp = torch.randn((2, 8), device="cuda")
            self._assert_dtensor_params(model.parameters())
            model(inp).sum().backward()
            self._assert_dtensor_params(model.parameters())

        # Multiple FSDP groups
        for reshard_after_forward in (True, False, 2):
            model = MLP(8, device)
            fully_shard(model.in_proj, reshard_after_forward=reshard_after_forward)
            fully_shard(model.out_proj, reshard_after_forward=reshard_after_forward)
            fully_shard(model, reshard_after_forward=reshard_after_forward)
            self._assert_dtensor_params(model.parameters())
            model(inp).sum().backward()
            self._assert_dtensor_params(model.parameters())

    def _assert_tensor_params(self, params: Iterable[nn.Parameter]):
        # Materialize the iterable first since callers may pass a generator,
        # which the length check below would otherwise exhaust
        params = list(params)
        self.assertGreater(len(params), 0)
        for param in params:
            self.assertNotIsInstance(param, DTensor)
            self.assertIsInstance(param, torch.Tensor)

    def _assert_dtensor_params(self, params: Iterable[nn.Parameter]):
        params = list(params)
        self.assertGreater(len(params), 0)
        for param in params:
            self.assertIsInstance(param, DTensor)

    def _assert_same_params(
        self, params: Iterable[nn.Parameter], ref_params: Iterable[nn.Parameter]
    ):
        params, ref_params = list(params), list(ref_params)
        self.assertEqual(len(params), len(ref_params))
        for param, ref_param in zip(params, ref_params):
            if isinstance(param, DTensor):
                param = param.full_tensor()
            self.assertEqual(param.shape, ref_param.shape)
            self.assertEqual(param, ref_param)


class TestFullyShardCastAfterInit(FSDPTestMultiThread):
    @property
    def world_size(self) -> int:
        return 2

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    @wrapSwapTensorsTest(True)
    def test_to_float64_after_init(self):
        """Tests that the user can cast the module to float64 after init."""
        # NOTE: Test fp64 instead of a lower precision dtype like bf16 for
        # better numerics. The important part is changing the dtype.
        torch.manual_seed(42)
        mlp_dim, device, dtype = 4, torch.device("cuda"), torch.float64
        model = MLP(mlp_dim, device=device)
        for param in model.parameters():
            dist.broadcast(param, src=0)
        ref_model = copy.deepcopy(model).to(dtype)
        replicate(ref_model)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        for module in (model.in_proj, model.out_proj, model):
            fully_shard(module)
        model.to(dtype)
        for param in model.parameters():
            self.assertEqual(param.dtype, dtype)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
        check_sharded_parity(self, ref_model, model)
        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((2, mlp_dim), device="cuda", dtype=dtype)
        for iter_idx in range(10):
            losses: List[torch.Tensor] = []
            for _model in (ref_model, model):
                losses.append(_model(inp).sum())
                losses[-1].backward()
            self.assertEqual(losses[0], losses[1])
            check_sharded_parity(self, ref_model, model)
            for _optim in (ref_optim, optim):
                _optim.step()
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))


class TestFullyShard1DTrainingCore(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(8, torch.cuda.device_count())

    @skip_if_lt_x_gpu(2)
    def test_train_parity_single_group(self):
        """Tests train parity with DDP for a single FSDP group."""
        self.run_subtests(
            {
                "lin_shapes": [[(16, 15), (15, 8)], [(7, 15), (15, 3)]],
            },
            self._test_train_parity_single_group,
        )

    def _test_train_parity_single_group(self, lin_shapes: List[Tuple[int, int]]):
        torch.manual_seed(42)
        model = nn.Sequential(
            nn.Linear(*lin_shapes[0]), nn.ReLU(), nn.Linear(*lin_shapes[1])
        )
        ref_model = copy.deepcopy(model).cuda()
        replicate(ref_model, device_ids=[self.rank])
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        fully_shard(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        torch.manual_seed(42 + self.rank + 1)
        inp = (torch.randn((4, lin_shapes[0][0]), device="cuda"),)
        for iter_idx in range(10):
            losses: List[torch.Tensor] = []
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                losses.append(_model(*inp).sum())
                losses[-1].backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    @test_compiled_fsdp(compile_compute_on_module=Transformer)
    def test_train_parity_multi_group(self):
        """
        Tests train parity against DDP when using multiple parameter groups for
        communication (for communication and computation overlap plus memory
        reduction).
        """
        self.run_subtests(
            {
                "reshard_after_forward": [True, False, 2],
                "device_type": ["cuda"],
                "offload_policy": [OffloadPolicy()],
                "delay_after_forward": [False, True],
                "delay_before_all_gather": [False, True],
                "delay_before_reduce_scatter": [False, True],
                "delay_before_optim": [False, True],
            },
            self._test_train_parity_multi_group,
        )

    @skip_if_lt_x_gpu(2)
    def test_train_parity_multi_group_cpu_offload_eager(self):
        """
        Tests train parity against DDP when using multiple parameter groups for
        communication and CPU offloading.
        """
        self.run_subtests(
            {
                "reshard_after_forward": [True],  # save CI time
                "offload_policy": [
                    CPUOffloadPolicy(pin_memory=True),
                    CPUOffloadPolicy(pin_memory=False),
                ],
                "device_type": ["cuda"],
                "delay_after_forward": [False, True],
                "delay_before_all_gather": [False, True],
                "delay_before_reduce_scatter": [False, True],
                "delay_before_optim": [False, True],
            },
            self._test_train_parity_multi_group,
        )

    def _test_train_parity_multi_group(
        self,
        reshard_after_forward: Union[bool, int],
        offload_policy: OffloadPolicy,
        device_type: str,
        delay_after_forward: bool,
        delay_before_all_gather: bool,
        delay_before_reduce_scatter: bool,
        delay_before_optim: bool,
    ):
        # Only test individual delays or all four delays to save test time
        if (
            delay_after_forward
            + delay_before_all_gather
            + delay_before_reduce_scatter
            + delay_before_optim
            in (2, 3)
        ):
            return
        assert device_type in ("cuda", "cpu"), f"{device_type}"
        torch.manual_seed(42)
        lin_dim = 32
        vocab_size = 1024
        model_args = ModelArgs(
            n_layers=3,
            n_heads=4,
            vocab_size=vocab_size,
            max_seq_len=64,
            dropout_p=0,
        )
        model = Transformer(model_args)
        ref_model = copy.deepcopy(model)
        if device_type == "cuda":
            replicate(ref_model.cuda(), device_ids=[self.rank])
        else:
            gloo_pg = dist.new_group(backend="gloo")
            replicate(ref_model, process_group=gloo_pg)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        mesh = init_device_mesh(device_type, (self.world_size,))
        fully_shard_fn = functools.partial(
            fully_shard,
            mesh=mesh,
            reshard_after_forward=reshard_after_forward,
            offload_policy=offload_policy,
        )
        for module in model.modules():
            if isinstance(module, TransformerBlock):
                fully_shard_fn(module)
        fully_shard_fn(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        delay_in_ms = 100
        orig_all_gather = dist.all_gather_into_tensor
        orig_reduce_scatter = dist.reduce_scatter_tensor

        def delayed_all_gather(*args, **kwargs):
            torch.cuda._sleep(int(delay_in_ms * get_cycles_per_ms()))
            return orig_all_gather(*args, **kwargs)

        def delayed_reduce_scatter(*args, **kwargs):
            torch.cuda._sleep(int(delay_in_ms * get_cycles_per_ms()))
            return orig_reduce_scatter(*args, **kwargs)

        torch.manual_seed(42 + self.rank + 1)
        patch_all_gather_ctx = (
            patch_all_gather(delayed_all_gather)
            if delay_before_all_gather
            else contextlib.nullcontext()
        )
        patch_reduce_scatter_ctx = (
            patch_reduce_scatter(delayed_reduce_scatter)
            if delay_before_reduce_scatter
            else contextlib.nullcontext()
        )
        with patch_all_gather_ctx, patch_reduce_scatter_ctx:
            for iter_idx in range(10):
                inp = torch.randint(0, vocab_size, (3, 64), device=device_type)
                losses: List[torch.Tensor] = []
                for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                    _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                    losses.append(_model(inp).sum())
                    if _model is model and delay_after_forward:
                        torch.cuda._sleep(int(delay_in_ms * get_cycles_per_ms()))
                    losses[-1].backward()
                    if _model is model and delay_before_optim:
                        torch.cuda._sleep(int(delay_in_ms * get_cycles_per_ms()))
                    _optim.step()
                self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    def test_non_root_forward_backward(self):
        """
        Tests running forward/backward through the root and then through a
        non-root. The non-root needs to synchronize streams/queue the callback.
        """
        torch.manual_seed(42)
        lin_dim = 32
        model = nn.Sequential(*[MLP(lin_dim, torch.device("cpu")) for _ in range(3)])
        ref_model = copy.deepcopy(model).cuda()
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        for mlp in model:
            fully_shard(mlp)
        fully_shard(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
        torch.manual_seed(42 + self.rank)
        inp = torch.randn((8, lin_dim), device=torch.device("cuda"))

        ref_root_loss = ref_model(inp).sum()
        ref_root_loss.backward()
        for param in ref_model.parameters():
            dist.all_reduce(param.grad)
            param.grad.detach().div_(self.world_size)
        ref_optim.step()
        ref_optim.zero_grad()
        ref_nonroot_loss = ref_model[0](inp).sum()
        ref_nonroot_loss.backward()
        for param in ref_model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad)
                param.grad.detach().div_(self.world_size)
        ref_optim.step()

        root_loss = model(inp).sum()
        root_loss.backward()
        torch.cuda._sleep(int(100 * get_cycles_per_ms()))
        optim.step()
        optim.zero_grad()
        nonroot_loss = model[0](inp).sum()
        nonroot_loss.backward()
        optim.step()

        self.assertEqual(ref_root_loss, root_loss)
        self.assertEqual(ref_nonroot_loss, nonroot_loss)
        self.assertEqual(ref_model(inp).sum(), model(inp).sum())

    @skip_if_lt_x_gpu(2)
    def test_multi_forward_module(self):
        """
        Tests parity with DDP when running a module that participates multiple
        times in forward.
        """
        self.run_subtests(
            {"reshard_after_forward": [True, False, 2]},
            self._test_multi_forward_module,
        )

    def _test_multi_forward_module(self, reshard_after_forward: Union[bool, int]):
        class MultiForwardModule(nn.Module):
            def __init__(self, device: torch.device):
                super().__init__()
                self.inner = nn.Linear(4, 4, device=device)
                self.outer = nn.Linear(4, 5, device=device)

            def forward(self, x):
                i = self.inner(x)
                j = self.inner(x)
                return self.outer(i + j)

        torch.manual_seed(42)
        model = MultiForwardModule(device="cuda")
        ref_model = copy.deepcopy(model)
        replicate(ref_model, device_ids=[self.rank])
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        fully_shard(model.inner)
        fully_shard(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        torch.manual_seed(42 + self.rank)
        inp = torch.randn((32, 4), device="cuda")
        for iter_idx in range(10):
            losses: List[torch.Tensor] = []
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                losses.append(_model(inp).sum())
                losses[-1].backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])


class TestFullyShard1DTrainingCompose(FSDPTest):
    @property
    def world_size(self) -> int:
        # Since these tests run with a larger transformer model, they may see
        # some numeric drift with >2 GPUs
        return min(torch.cuda.device_count(), 2)

    @skip_if_lt_x_gpu(2)
    @test_compiled_fsdp(compile_compute_on_module=Transformer)
    def test_train_parity_with_activation_checkpointing(self):
        """
        Tests train parity against DDP when composing with activation
        checkpointing.
        """
        self.run_subtests(
            {
                "reshard_after_forward": [True, False, 2],
                "checkpoint_impl": ["composable", "utils", "wrapper"],
            },
            self._test_train_parity_with_activation_checkpointing,
        )

    def _test_train_parity_with_activation_checkpointing(
        self, reshard_after_forward: Union[bool, int], checkpoint_impl: str
    ):
        assert checkpoint_impl in ("composable", "utils", "wrapper")
        testing_compile = fully_shard != torch.distributed._composable.fsdp.fully_shard
        if testing_compile and checkpoint_impl == "composable":
            return
        torch.manual_seed(42)
        vocab_size = 1024
        with torch.device(torch.device("cuda")):
            model_args = ModelArgs(
                n_layers=3,
                n_heads=4,
                vocab_size=vocab_size,
                max_seq_len=64,
                dropout_p=0,
                checkpoint_activations=(checkpoint_impl == "utils"),
            )
            model = Transformer(model_args)
        ref_model = replicate(copy.deepcopy(model), device_ids=[self.rank])
        foreach = True
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=foreach)
        fully_shard_fn = functools.partial(
            fully_shard,
            reshard_after_forward=reshard_after_forward,
        )
        if checkpoint_impl == "wrapper":
            prefixes_to_ignore = (_CHECKPOINT_PREFIX,)
            apply_activation_checkpointing(
                model, check_fn=lambda m: isinstance(m, TransformerBlock)
            )
            for module in model.modules():
                # Apply to `CheckpointWrapper`, which wraps `TransformerBlock`
                if isinstance(module, CheckpointWrapper):
                    fully_shard_fn(module)
        else:
            prefixes_to_ignore = ()
            for module in model.modules():
                if isinstance(module, TransformerBlock):
                    if checkpoint_impl == "composable":
                        checkpoint(module)
                    fully_shard_fn(module)
        fully_shard_fn(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)

        torch.manual_seed(42 + self.rank)
        # Reuse the same input across iterations to avoid loss explosion from
        # trying to learn from random inputs
        inp = torch.randint(0, vocab_size, (3, 64), device="cuda")
        check_sharded_parity(
            self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore
        )
        for iter_idx in range(10):
            losses: List[torch.Tensor] = []
            for _model in (ref_model, model):
                torch.manual_seed(iter_idx + 1)  # for dropout determinism
                losses.append(_model(inp).sum())
                losses[-1].backward()
            if not testing_compile:
                check_sharded_parity(
                    self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore
                )
            self.assertEqual(losses[0], losses[1])
            for _optim in (ref_optim, optim):
                _optim.step()
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            if not testing_compile:
                check_sharded_parity(
                    self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore
                )


class TestFullyShardSharedParams(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    @skip_if_lt_x_gpu(2)
    def test_train_parity_with_shared_params(self):
        self.run_subtests(
            {
                "reshard_after_forward": [False, True],
                "use_activation_checkpointing": [False, True],
            },
            self._test_train_shared_params,
        )

    def _test_train_shared_params(
        self,
        reshard_after_forward: bool,
        use_activation_checkpointing: bool,
    ):
        torch.manual_seed(42)
        model_args = ModelArgs(n_layers=3, dropout_p=0.0, weight_tying=True)
        model = Transformer(model_args)
        ref_model = copy.deepcopy(model).cuda()
        replicate(ref_model, device_ids=[self.rank])
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        for module in model.modules():
            if isinstance(module, TransformerBlock):
                if use_activation_checkpointing:
                    checkpoint(module)
                fully_shard(module, reshard_after_forward=reshard_after_forward)
        fully_shard(model, reshard_after_forward=reshard_after_forward)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        torch.manual_seed(42 + self.rank + 1)
        for iter_idx in range(10):
            inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
            losses: List[torch.Tensor] = []
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                losses.append(_model(inp).sum())
                losses[-1].backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])


class TestFullyShardGradientAccumulation(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    @skip_if_lt_x_gpu(2)
    def test_gradient_accumulation(self):
        """
        Tests gradient accumulation with/without gradient reduction and
        with/without resharding after backward.
        """
        meshes = [init_device_mesh("cuda", (self.world_size,))]  # always test FSDP
        if self.world_size == 4:  # test HSDP too if enough GPUs
            shard_size, replicate_size = 2, 2
            meshes.append(init_device_mesh("cuda", (replicate_size, shard_size)))
        self.run_subtests(
            {
                "mesh": meshes,
                "reshard_after_forward": [True, False, 2],
                # "all": disable reduce-scatter for all modules
                # "root_only": disable reduce-scatter for root's linear only
                # "some_mlps": disable reduce-scatter for some MLPs
                "mode": ["all", "root_only", "some_mlps"],
                "reshard_after_backward": [False, True],
                "offload_policy": [OffloadPolicy(), CPUOffloadPolicy()],
                # For HSDP only:
                # `True`: reduce-scatter only (no all-reduce) each microbatch
                # until the last microbatch
                # `False`: neither reduce-scatter nor all-reduce each
                # microbatch until the last microbatch
                "reduce_scatter_only": [False, True],
            },
            self._test_gradient_accumulation,
        )

    def _test_gradient_accumulation(
        self,
        mesh: DeviceMesh,
        reshard_after_forward: Union[bool, int],
        mode: str,
        reshard_after_backward: bool,
        offload_policy: OffloadPolicy,
        reduce_scatter_only: bool,  # for HSDP
    ):
        if (
            (
                not reshard_after_backward
                and (reshard_after_forward is not False or mode == "some_mlps")
            )
            or (
                isinstance(offload_policy, CPUOffloadPolicy)
                and reshard_after_forward is not True
            )
            or (mesh.ndim != 2 and reduce_scatter_only)
        ):
            return  # skip since not common or applicable

        torch.manual_seed(42)
        batch_size, lin_dim, num_mlps, num_microbatches = (2, 32, 3, 3)
        if mode == "some_mlps":
            num_mlps_to_disable_reduce_scatter = 2
        modules = [nn.Linear(lin_dim, lin_dim)]
        modules.extend(MLP(lin_dim) for _ in range(num_mlps))
        model = nn.Sequential(*modules)
        ref_model = copy.deepcopy(model).cuda()
        fully_shard_fn = functools.partial(
            fully_shard,
            mesh=mesh,
            reshard_after_forward=reshard_after_forward,
            offload_policy=offload_policy,
        )
        for mlp in model[1:]:
            fully_shard_fn(mlp)
        fully_shard_fn(model)  # root gets the 1st linear
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        def set_grad_sync_flag(
            module: nn.Module, is_last_microbatch: bool, recurse: bool = True
        ):
            if reduce_scatter_only:
                module.set_requires_all_reduce(is_last_microbatch, recurse=recurse)
            else:
                module.set_requires_gradient_sync(is_last_microbatch, recurse=recurse)

        def set_backward_flags(_model: nn.Module, is_last_microbatch: bool):
            if mode == "all":
                set_grad_sync_flag(_model, is_last_microbatch)
                if not reshard_after_backward:
                    _model.set_reshard_after_backward(is_last_microbatch)
            elif mode == "some_mlps":
                for mlp in model[1 : 1 + num_mlps_to_disable_reduce_scatter]:
                    set_grad_sync_flag(mlp, is_last_microbatch)
                    if not reshard_after_backward:
                        mlp.set_reshard_after_backward(is_last_microbatch)
            elif mode == "root_only":
                set_grad_sync_flag(model, is_last_microbatch, recurse=False)
                if not reshard_after_backward:
                    model.set_reshard_after_backward(is_last_microbatch, recurse=False)

        torch.manual_seed(42 + self.rank + 1)
        for iter_idx in range(5):
            with CommDebugMode() as comm_mode:
                for microbatch_idx in range(num_microbatches):
                    is_last_microbatch = microbatch_idx == num_microbatches - 1
                    set_backward_flags(model, is_last_microbatch)
                    inp = torch.randn(batch_size, lin_dim, device="cuda")
                    losses: List[torch.Tensor] = []
                    for _model in (ref_model, model):
                        losses.append(_model(inp).sum())
                        losses[-1].backward()
                    self.assertEqual(losses[0], losses[1])

            comm_counts = comm_mode.get_comm_counts()
            all_gather_count = comm_counts[c10d_ops._allgather_base_]
            reduce_scatter_count = comm_counts[c10d_ops._reduce_scatter_base_]
            all_reduce_count = comm_counts[c10d_ops.allreduce_]

            # Expect one reduce-scatter per MLP plus one for the root's linear
            # on the last microbatch
            expected_reduce_scatter_count = num_mlps + 1
            if mode == "some_mlps":
                # Expect additional reduce-scatters for non-disabled MLPs and
                # the root's linear
                expected_reduce_scatter_count += (
                    num_mlps - num_mlps_to_disable_reduce_scatter + 1
                ) * (num_microbatches - 1)
            elif mode == "root_only":
                # Expect additional reduce-scatters for all MLPs
                expected_reduce_scatter_count += (num_mlps) * (num_microbatches - 1)
            expected_all_reduce_count = (
                expected_reduce_scatter_count if mesh.ndim == 2 else 0
            )
            if reduce_scatter_only:
                # Specially for HSDP if only reduce-scattering but not
                # all-reducing until the last microbatch, expect one
                # reduce-scatter per MLP plus for the root per microbatch
                expected_reduce_scatter_count = (num_mlps + 1) * num_microbatches
            self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count)
            self.assertEqual(all_reduce_count, expected_all_reduce_count)

            # Expect one all-gather per MLP plus one for the root's linear in
            # the first microbatch's forward
            expected_all_gather_count = num_mlps + 1
            if reshard_after_forward is not False:  # `True` or `2`
                # Add the number of MLPs without the +1 for the backward
                # all-gathers since the root does not reshard after forward
                expected_all_gather_count += num_mlps
                # Multiply by the number of microbatches since these
                # all-gathers run every microbatch
                expected_all_gather_count *= num_microbatches
            elif reshard_after_backward:  # `reshard_after_forward=False`
                expected_all_gather_count *= num_microbatches
            elif mode == "all":  # `reshard_after_forward/backward=False`
                # Only reshard parameters after the last microbatch's backward,
                # so there should not be any more all-gathers
                pass
            elif mode == "root_only":  # `reshard_after_forward/backward=False`
                # The MLPs should still contribute all-gathers in each
                # microbatch forward
                expected_all_gather_count += num_mlps * (num_microbatches - 1)
            self.assertEqual(all_gather_count, expected_all_gather_count)

            for param in ref_model.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
            check_sharded_parity(self, ref_model, model)
            for _optim in (optim, ref_optim):
                _optim.step()
                # When `set_to_none=False`, we are exercising mixing
                # gradient accumulation with and without communication
                _optim.zero_grad(set_to_none=(iter_idx % 2))

    @skip_if_lt_x_gpu(2)
    def test_1f1b_microbatching(self):
        self.run_subtests(
            {
                "use_explicit_unshard": [False, True],
                "reshard_after_backward": [False, True],
            },
            self._test_1f1b_microbatching,
        )

    def _test_1f1b_microbatching(
        self, use_explicit_unshard: bool, reshard_after_backward: bool
    ):
        torch.manual_seed(42)
        model_args = ModelArgs(dropout_p=0.0)
        model = Transformer(model_args)
        ref_model = copy.deepcopy(model).cuda()
        ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
        for module in model.modules():
            if isinstance(module, TransformerBlock):
                fully_shard(module, reshard_after_forward=False)
        fully_shard(model, reshard_after_forward=False)
        optim = torch.optim.AdamW(model.parameters(), lr=1e-2)

        num_microbatches = 3
        local_batch_size = 2
        torch.manual_seed(42 + self.rank + 1)
        inps = [
            torch.randint(
                0, model_args.vocab_size, (local_batch_size, 16), device="cuda"
            )
            for _ in range(num_microbatches)
        ]

        # Before pipelining, we may prefer to issue all all-gathers ahead of
        # time to increase overlap opportunity at no difference in parameter
        # memory usage since we do not reshard after forward
        if use_explicit_unshard:
            for module in model.modules():
                if isinstance(module, FSDPModule):
                    module.unshard(async_op=True)

        # Emulate the 1f1b pipeline schedule and only reduce gradients on the
        # last microbatch
        losses: List[torch.Tensor] = []
        ref_losses: List[torch.Tensor] = []
        for inp_idx, inp in enumerate(inps):
            is_last_microbatch = inp_idx == num_microbatches - 1
            model.set_requires_gradient_sync(is_last_microbatch)
            model.set_is_last_backward(is_last_microbatch)
            if not reshard_after_backward:
                model.set_reshard_after_backward(is_last_microbatch)
            losses.append(model(inp).sum())
            losses[-1].backward()
            ref_losses.append(ref_model(inp).sum())
            ref_losses[-1].backward()
        for param in ref_model.parameters():
            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)

        for loss, ref_loss in zip(losses, ref_losses):
            self.assertEqual(loss, ref_loss)
        optim.step()
        ref_optim.step()
        check_sharded_parity(self, ref_model, model)


class TestFullyShard2DTraining(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    def init_global_mesh(self) -> DeviceMesh:
        # Prefer to test with >=4 GPUs, but for 2 GPUs, use 2-way TP
        dp_size = 2 if self.world_size > 2 else 1
        return init_device_mesh(
            "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
        )

    @skip_if_lt_x_gpu(2)
    @skipIfRocm
    def test_train_parity_2d_mlp(self):
        global_mesh = self.init_global_mesh()
        self.run_subtests(
            {
                "reshard_after_forward": [False, True],
                "use_activation_checkpointing": [False, True],
                "mlp_dim": [3, 16, 17],
            },
            functools.partial(self._test_train_parity_2d_mlp, global_mesh),
        )

    def _test_train_parity_2d_mlp(
        self,
        global_mesh: DeviceMesh,
        reshard_after_forward: bool,
        use_activation_checkpointing: bool,
        mlp_dim: int,
    ):
        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
        dp_pg = dp_mesh.get_group()  # used for `replicate()`

        torch.manual_seed(42)
        model = MLPStack(mlp_dim)
        ref_model = copy.deepcopy(model).cuda()
        replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
        model.parallelize(
            tp_mesh,
            dp_mesh,
            use_activation_checkpointing,
            reshard_after_forward=reshard_after_forward,
        )
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)

        torch.manual_seed(42 + dp_pg.rank() + 1)
        device = torch.device("cuda")
        for iter_idx in range(10):
            inp = torch.randn((8, mlp_dim), device=device)
            losses: List[torch.Tensor] = []
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                losses.append(_model(inp).sum())
                losses[-1].backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    @skipIfRocm
    def test_tp_with_fsdp_offloading(self):
        global_mesh = init_device_mesh(
            "cuda", (1, self.world_size), mesh_dim_names=("dp", "tp")
        )
        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
        torch.manual_seed(42)
        mlp_dim = 16
        model = MLPStack(mlp_dim)
        ref_model = copy.deepcopy(model).cuda()
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
        # Parallelize with N-way TP and 1-way FSDP
        model.parallelize(
            tp_mesh,
            dp_mesh,
            use_activation_checkpointing=False,
            reshard_after_forward=True,
            offload_policy=CPUOffloadPolicy(),
        )
        for param in model.parameters():
            self.assertEqual(param.device.type, "cpu")
        num_mlps = sum(isinstance(module, MLP) for module in model.modules())
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)

        # NOTE: We still see the FSDP all-gather/reduce-scatter c10d ops
        # called, but they will just be no-ops without issuing any kernels.
        # We prefer to keep the no-op check at the c10d level, not in FSDP.
        inp = torch.randn((4, mlp_dim), device="cuda")  # same on all ranks
        for iter_idx in range(10):
            ref_optim.zero_grad()
            optim.zero_grad()

            with CommDebugMode() as fwd_comm_mode:
                loss = model(inp).sum()

            fwd_comm_counts = fwd_comm_mode.get_comm_counts()
            self.assertEqual(len(fwd_comm_counts), 2)
            self.assertEqual(fwd_comm_counts[funcol.all_reduce], num_mlps)
            self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
            ref_loss = ref_model(inp).sum()
            self.assertEqual(loss, ref_loss)

            with CommDebugMode() as bwd_comm_mode:
                loss.backward()
            bwd_comm_counts = bwd_comm_mode.get_comm_counts()
            self.assertEqual(len(bwd_comm_counts), 3)
            # First MLP's input gradient does not need to be all-reduced
            self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1)
            self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
            self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_mlps)
            ref_loss.backward()

            optim.step()
            ref_optim.step()

    # TODO: remove this test when 2d state_dict is ready.
    @skip_if_lt_x_gpu(2)
    @skipIfRocm
    def test_raise_not_implemented_state_dict_if_2d(self):
        def parallelize(_model: Transformer, mesh: DeviceMesh, use_seq_parallel: bool):
            _model = Transformer.parallelize(_model, mesh["tp"], use_seq_parallel)
            for layer in _model.layers:
                fully_shard(layer, mesh=mesh["dp"])
            fully_shard(_model, mesh=mesh["dp"])
            return _model

        global_mesh = self.init_global_mesh()
        seed = 42
        torch.manual_seed(seed)
        model_args = ModelArgs(dropout_p=0.0)
        model = parallelize(Transformer(model_args), global_mesh, True)

        with self.assertRaisesRegex(NotImplementedError, "2D"):
            get_model_state_dict(model)

    # Temporarily disable 2D state dict test, while strided sharding is being developed.
    # TODO: re-enable this test once 2d state_dict is ready.
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def _temp_disable_test_train_parity_2d_transformer_checkpoint_resume(self):
        """
        Tests train parity of a 2D transformer without checkpointing against a
        2D transformer with a checkpoint save/load.
        """
        self.run_subtests(
            {
                "use_seq_parallel": [False, True],
                # If reusing, then load into the same model/optimizer instance
                # else construct new ones (requiring eager optim state init)
                "reuse_model_optim": [False, True],
                "optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
                # TODO: need to update `parallelize` before including foreach=True for testing
                "foreach": [False],
            },
            self._test_train_parity_2d_transformer_checkpoint_resume,
        )

    def _test_train_parity_2d_transformer_checkpoint_resume(
        self,
        use_seq_parallel: bool,
        reuse_model_optim: bool,
        optimizer_class: Type[torch.optim.Optimizer],
        foreach: bool,
    ):
        def train_step(
            _model: nn.Module, _optim: torch.optim.Optimizer, _inp: torch.Tensor
        ) -> torch.Tensor:
            loss = _model(_inp).sum()
            loss.backward()
            _optim.step()
            _optim.zero_grad()
            return loss

        def parallelize(_model: Transformer, mesh: DeviceMesh, use_seq_parallel: bool):
            _model = Transformer.parallelize(_model, mesh["tp"], use_seq_parallel)
            for layer in _model.layers:
                fully_shard(layer, mesh=mesh["dp"])
            fully_shard(_model, mesh=mesh["dp"])
            return _model

        global_mesh = self.init_global_mesh()
        # Baseline: run two iterations without checkpointing
        seed = 42
        torch.manual_seed(seed)
        model_args = ModelArgs(dropout_p=0.0)
        model_no_cp = parallelize(
            Transformer(model_args), global_mesh, use_seq_parallel
        )
        optim_no_cp = optimizer_class(
            model_no_cp.parameters(), lr=1e-2, foreach=foreach
        )

        torch.manual_seed(42 + global_mesh["dp"].get_local_rank() + 1)
        inp = torch.randint(0, model_args.vocab_size, (3, 16), device="cuda")
        loss_no_cp1 = train_step(model_no_cp, optim_no_cp, inp)
        loss_no_cp2 = train_step(model_no_cp, optim_no_cp, inp)

        # Test: run one iteration, save checkpoint, zero states or init new
        # model/optimizer, load checkpoint, and run another iteration
        torch.manual_seed(seed)
        model_cp = parallelize(Transformer(model_args), global_mesh, use_seq_parallel)
        optim_cp = optimizer_class(model_cp.parameters(), lr=1e-2, foreach=foreach)

        loss_cp1 = train_step(model_cp, optim_cp, inp)
        self.assertEqual(loss_no_cp1, loss_cp1)

        sharded_sd = {
            "model": get_model_state_dict(model_cp),
            # Use `get_optimizer_state_dict` to handle eager optim state init
            # when constructing a new optimizer instance
            "optim": get_optimizer_state_dict(model_cp, optim_cp),
        }
        dcp.save(
            state_dict=sharded_sd,
            storage_writer=dcp.FileSystemWriter(self.temp_dir),
        )
        if reuse_model_optim:
            with torch.no_grad():
                for param in model_cp.parameters():
                    param.zero_()
                optim_sd = optim_cp.state_dict()
                for param_states in optim_sd["state"].values():
                    for state_value in param_states.values():
                        if torch.is_tensor(state_value):
                            state_value.zero_()
        else:
            torch.manual_seed(seed + 1)  # different seed
            model_cp = parallelize(
                Transformer(model_args), global_mesh, use_seq_parallel
            )
            optim_cp = optimizer_class(model_cp.parameters(), lr=1e-2, foreach=foreach)
        self.assertNotEqual(loss_no_cp2, train_step(model_cp, optim_cp, inp))

        sharded_sd = {
            "model": get_model_state_dict(model_cp),
            "optim": get_optimizer_state_dict(model_cp, optim_cp),
        }
        dcp.load(
            state_dict=sharded_sd,
            storage_reader=dcp.FileSystemReader(self.temp_dir),
        )
        self.assertGreater(len(optim_cp.state_dict()["state"]), 0)

        loss_cp2 = train_step(model_cp, optim_cp, inp)
        self.assertEqual(loss_no_cp2, loss_cp2)


class TestFullyShardNDTraining(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(8, torch.cuda.device_count())

    def init_global_mesh(self) -> DeviceMesh:
        # Prefer to test with >=8 GPUs, but for 2 GPUs, use 2-way TP
        dp_size = 2 if self.world_size > 2 else 1
        pp_size = 2 if self.world_size > 4 else 1
        return init_device_mesh(
            "cuda",
            (pp_size, dp_size, self.world_size // (dp_size * pp_size)),
            mesh_dim_names=("pp", "dp", "tp"),
        )

    @skip_if_lt_x_gpu(4)
    def test_2d_mlp_with_nd_mesh(self):
        global_mesh = self.init_global_mesh()
        self.run_subtests(
            {
                "reshard_after_forward": [False, True],
                "use_activation_checkpointing": [False, True],
                "mlp_dim": [3, 16, 17],
                "foreach": [False],
            },
            functools.partial(self._test_2d_mlp_with_nd_mesh, global_mesh),
        )

    def _test_2d_mlp_with_nd_mesh(
        self,
        global_mesh: DeviceMesh,
        reshard_after_forward: bool,
        use_activation_checkpointing: bool,
        mlp_dim: int,
        foreach: bool,
    ):
        global_mesh = self.init_global_mesh()
        pp_mesh, dp_mesh, tp_mesh = (
            global_mesh["pp"],
            global_mesh["dp"],
            global_mesh["tp"],
        )
        dp_pg = dp_mesh.get_group()  # used for `replicate()`

        torch.manual_seed(42)
        model = MLPStack(mlp_dim)
        ref_model = copy.deepcopy(model).cuda()
        replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=foreach)
        model.parallelize(
            tp_mesh,
            dp_mesh,
            use_activation_checkpointing,
            reshard_after_forward=reshard_after_forward,
        )
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)

        torch.manual_seed(42 + dp_pg.rank() + 1)
        device = torch.device("cuda")
        for iter_idx in range(10):
            inp = torch.randn((8, mlp_dim), device=device)
            losses: List[torch.Tensor] = []
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                losses.append(_model(inp).sum())
                losses[-1].backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])

        for n, p in model.named_parameters():
            self.assertIsInstance(p, DTensor)
            self.assertEqual(p.device_mesh.ndim, 2)
            self.assertEqual(len(p.placements), 2)
            self.assertEqual(p.device_mesh.mesh_dim_names, ("dp", "tp"))


class TestFullyShardHSDPTraining(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    @skip_if_lt_x_gpu(2)
    def test_train_parity_hsdp(self):
        shard_size = 2 if self.world_size > 2 else 1
        replicate_size = self.world_size // shard_size
        global_mesh = init_device_mesh(
            "cuda", (replicate_size, shard_size), mesh_dim_names=("replicate", "shard")
        )
        self.run_subtests(
            {
                "reshard_after_forward": [False, True],
                "use_activation_checkpointing": [False, True],
                "mlp_dim": [3, 16, 17],
                "sync_gradients_at_last_batch": [True, False],
            },
            functools.partial(self._test_train_parity_hsdp, global_mesh),
        )

    def _test_train_parity_hsdp(
        self,
        global_mesh: DeviceMesh,
        reshard_after_forward: bool,
        use_activation_checkpointing: bool,
        mlp_dim: int,
        sync_gradients_at_last_batch: bool,
    ):
        torch.manual_seed(42)
        model = nn.Sequential(
            nn.LayerNorm(mlp_dim, bias=False),
            MLP(mlp_dim, dim_multiplier=3),
            MLP(mlp_dim),
            MLP(mlp_dim, dim_multiplier=3),
        )
        ref_model = copy.deepcopy(model).cuda()
        replicate(ref_model, device_ids=[self.rank])
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        for mlp in model:
            if use_activation_checkpointing:
                checkpoint(mlp)
            fully_shard(
                mlp, mesh=global_mesh, reshard_after_forward=reshard_after_forward
            )
        fully_shard(
            model, mesh=global_mesh, reshard_after_forward=reshard_after_forward
        )
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        check_sharded_parity(self, ref_model, model)
        torch.manual_seed(42 + self.rank + 1)
        device = torch.device("cuda")
        num_microbatches = 3
        for iter_idx in range(5):
            for microbatch_idx in range(num_microbatches):
                is_last_microbatch = microbatch_idx == num_microbatches - 1
                if sync_gradients_at_last_batch:
                    model.set_requires_gradient_sync(is_last_microbatch)
                inp = torch.randn((8, mlp_dim), device=device)
                losses: List[torch.Tensor] = []
                for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                    losses.append(_model(inp).sum())
                    losses[-1].backward()
                self.assertEqual(losses[0], losses[1])
            check_sharded_parity(self, ref_model, model)
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.step()
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            check_sharded_parity(self, ref_model, model)


class TestFullyShardCustomForwardMethod(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 2)

    @skip_if_lt_x_gpu(2)
    def test_register_fsdp_forward_method(self):
        """Based on https://github.com/pytorch/pytorch/issues/109385"""

        class VisionTransformer(nn.Module):
            def __init__(self):
                super().__init__()
                self.patch_proj = nn.Conv2d(3, 1024, kernel_size=14, stride=14)

            def forward_features(self, imgs: torch.Tensor) -> torch.Tensor:
                return self.patch_proj(imgs).flatten(2).transpose(1, 2)

            def forward(self, imgs: torch.Tensor) -> torch.Tensor:
                return self.forward_features(imgs).sum(dim=1)

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.vit, self.projector = VisionTransformer(), nn.Linear(1024, 256)

            def forward(self, imgs: torch.Tensor) -> torch.Tensor:
                # Run `vit.forward_features`, which is not `forward`!
                patch_embeddings = self.vit.forward_features(imgs)
                return self.projector(patch_embeddings)

        torch.manual_seed(42)
        model = Model()
        ref_model = copy.deepcopy(model).cuda()
        fully_shard(model.vit)
        fully_shard(model.projector)
        fully_shard(model)
        register_fsdp_forward_method(model.vit, "forward_features")

        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn(4, 3, 224, 224, device="cuda")
        ref_loss = ref_model(inp).sum()
        loss = model(inp).sum()
        self.assertEqual(ref_loss, loss)
        ref_loss.backward()
        loss.backward()
        for param in ref_model.parameters():
            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
        check_sharded_parity(self, ref_model, model)


if __name__ == "__main__":
    run_tests()