1# Owner(s): ["oncall: distributed"] 2 3import copy 4import functools 5import itertools 6import unittest 7from typing import Callable, List, Optional, Tuple, Union 8 9import torch 10import torch.distributed as dist 11import torch.nn as nn 12import torch.nn.functional as F 13from torch.distributed._composable import checkpoint, replicate 14from torch.distributed._composable.fsdp import ( 15 FSDPModule, 16 fully_shard, 17 MixedPrecisionPolicy, 18 OffloadPolicy, 19) 20from torch.distributed._composable.fsdp._fsdp_collectives import ( 21 _div_if_needed, 22 _get_gradient_divide_factors, 23 foreach_all_gather, 24 foreach_all_gather_copy_out, 25 foreach_reduce, 26) 27from torch.distributed._composable.fsdp._fsdp_common import FSDPMeshInfo, TrainingState 28from torch.distributed._composable.fsdp._fsdp_init import ( 29 _get_post_forward_mesh_info, 30 _init_default_fully_shard_mesh, 31) 32from torch.distributed._composable.fsdp._fsdp_param import ShardedState 33from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup 34from torch.distributed._tensor import DTensor 35from torch.distributed._tensor.experimental import implicit_replication 36from torch.distributed.device_mesh import DeviceMesh, init_device_mesh 37from torch.distributed.tensor.debug import CommDebugMode 38from torch.testing._internal.common_cuda import TEST_CUDA 39from torch.testing._internal.common_distributed import skip_if_lt_x_gpu 40from torch.testing._internal.common_fsdp import ( 41 check_sharded_parity, 42 DoubleLinear, 43 FSDPTest, 44 FSDPTestMultiThread, 45 MLP, 46 patch_post_backward, 47 patch_reshard, 48 patch_unshard, 49) 50from torch.testing._internal.common_utils import run_tests 51from torch.testing._internal.distributed._tensor.common_dtensor import ( 52 ModelArgs, 53 Transformer, 54 TransformerBlock, 55) 56 57 58c10d_ops = torch.ops.c10d 59 60# For recording FSDP events like unshard or post-backward 61EventType = Tuple[str, str, TrainingState] 62 63 64class TestFullyShardCollectiveOps(FSDPTestMultiThread): 65 @property 66 def world_size(self) -> int: 67 return 128 68 69 @property 70 def device(self) -> torch.device: 71 return torch.device("cuda:0") 72 73 def _get_param_sizes(self) -> List[torch.Size]: 74 # For world size 128, the fp32 all-gather and reduce-scatter testing 75 # requires ~0.22 GB 76 return [ 77 torch.Size([17, 257]), 78 torch.Size([17]), 79 torch.Size([64, 312]), 80 torch.Size([64]), 81 torch.Size([64, 64]), 82 torch.Size([512, 64]), 83 torch.Size([256]), 84 torch.Size([64, 297]), 85 ] 86 87 def _init_params(self, param_sizes: List[torch.Size]) -> List[nn.Parameter]: 88 torch.manual_seed(42) 89 orig_params = [ 90 nn.Parameter(torch.randn(size, device=self.device)) for size in param_sizes 91 ] 92 # Since seed is per process, not per thread, we broadcast to ensure the 93 # same original parameters across ranks 94 for orig_param in orig_params: 95 dist.broadcast(orig_param, src=0) 96 return orig_params 97 98 def _init_fsdp_param_group( 99 self, params: List[nn.Parameter], reshard_after_forward: Union[bool, int] 100 ): 101 module = nn.ParameterList([param.detach().clone() for param in params]) 102 mesh_info = FSDPMeshInfo(_init_default_fully_shard_mesh(), shard_mesh_dim=0) 103 post_forward_mesh_info = _get_post_forward_mesh_info( 104 reshard_after_forward, mesh_info 105 ) 106 fsdp_param_group = FSDPParamGroup( 107 list(module.parameters()), 108 (module,), 109 mesh_info, 110 post_forward_mesh_info, 111 self.device, 112 MixedPrecisionPolicy(), 113 OffloadPolicy(), 114 ) 115 fsdp_param_group.lazy_init() 116 return fsdp_param_group 117 118 @unittest.skipIf(not TEST_CUDA, "no cuda") 119 def test_all_gather_fp32(self): 120 param_sizes = self._get_param_sizes() 121 default_stream = torch.cuda.current_stream() 122 stream1, stream2 = torch.cuda.Stream(), torch.cuda.Stream() 123 for async_op, streams, reshard_after_forward in itertools.product( 124 (False, True), 125 ((default_stream, default_stream), (stream1, stream2)), 126 (True, 8), 127 ): 128 all_gather_copy_in_stream, all_gather_stream = streams 129 # Save test time by only testing reshard after forward as an int 130 # for non-async and non-default streams (like in pre-backward) 131 if type(reshard_after_forward) is int and ( 132 async_op or all_gather_stream is default_stream 133 ): 134 continue 135 self._test_all_gather( 136 param_sizes, 137 reshard_after_forward=reshard_after_forward, 138 async_op=async_op, 139 all_gather_copy_in_stream=all_gather_copy_in_stream, 140 all_gather_stream=all_gather_stream, 141 ) 142 143 def _test_all_gather( 144 self, 145 param_sizes: List[torch.Size], 146 reshard_after_forward: Union[bool, int], 147 async_op: bool, 148 all_gather_copy_in_stream: torch.cuda.Stream, 149 all_gather_stream: torch.cuda.Stream, 150 ): 151 def all_gather(fsdp_param_group: FSDPParamGroup, group: dist.ProcessGroup): 152 all_gather_result = foreach_all_gather( 153 fsdp_param_group.fsdp_params, 154 group, 155 async_op=async_op, 156 all_gather_copy_in_stream=all_gather_copy_in_stream, 157 all_gather_stream=all_gather_stream, 158 device=self.device, 159 ) 160 foreach_all_gather_copy_out(all_gather_result, fsdp_params, group) 161 # Transition to unsharded state to register unsharded parameters 162 for fsdp_param in fsdp_param_group.fsdp_params: 163 fsdp_param.init_unsharded_param() 164 fsdp_param_group._to_unsharded() 165 166 def check_all_gathered_params( 167 orig_params: List[nn.Parameter], module: nn.Module 168 ): 169 for orig_param, param in zip(orig_params, module.parameters()): 170 self.assertIsInstance(param, torch.Tensor) 171 self.assertIsInstance(param, nn.Parameter) 172 self.assertEqual(param, orig_param.to(param.dtype)) 173 174 # Set up the reference parameters and construct the FSDP group 175 orig_params = self._init_params(param_sizes) 176 fsdp_param_group = self._init_fsdp_param_group( 177 orig_params, reshard_after_forward 178 ) 179 fsdp_params = fsdp_param_group.fsdp_params 180 module = fsdp_param_group.modules[0] 181 182 # Sanity check that the parameter sharding is as expected 183 for orig_param, param in zip(orig_params, module.parameters()): 184 self.assertTrue(isinstance(param, DTensor)) 185 self.assertEqual(param.full_tensor(), orig_param) 186 187 # Run the foreach all-gather (including copy-in and copy-out) 188 all_gather(fsdp_param_group, fsdp_param_group.mesh_info.shard_process_group) 189 190 # Check all-gather correctness 191 check_all_gathered_params(orig_params, module) 192 193 # For reshard after after forward as an int, further test emulating the 194 # pre-backward all-gather 195 if type(reshard_after_forward) is not int: 196 return 197 fsdp_param_group._to_sharded_post_forward() 198 all_gather( 199 fsdp_param_group, 200 fsdp_param_group.post_forward_mesh_info.shard_process_group, 201 ) 202 check_all_gathered_params(orig_params, module) 203 204 @unittest.skipIf(not TEST_CUDA, "no cuda") 205 def test_reduce_scatter_fp32(self): 206 param_sizes = self._get_param_sizes() 207 default_stream = torch.cuda.current_stream() 208 stream = torch.cuda.Stream() 209 for reduce_scatter_stream in (default_stream, stream): 210 self._test_reduce_scatter( 211 param_sizes, 212 reduce_scatter_stream=reduce_scatter_stream, 213 reduce_scatter_dtype=torch.float32, 214 ) 215 216 @unittest.skipIf(not TEST_CUDA, "no cuda") 217 def test_reduce_scatter_fp16(self): 218 param_sizes = self._get_param_sizes() 219 default_stream = torch.cuda.current_stream() 220 stream = torch.cuda.Stream() 221 for reduce_scatter_stream in (default_stream, stream): 222 self._test_reduce_scatter( 223 param_sizes, 224 reduce_scatter_stream=reduce_scatter_stream, 225 reduce_scatter_dtype=torch.float16, 226 ) 227 228 def _test_reduce_scatter( 229 self, 230 param_sizes: List[torch.Size], 231 reduce_scatter_stream: torch.cuda.Stream, 232 reduce_scatter_dtype: torch.dtype, 233 ): 234 # Set up the reference parameters and construct the FSDP group 235 orig_params = self._init_params(param_sizes) 236 fsdp_param_group = self._init_fsdp_param_group(orig_params, True) 237 fsdp_params = fsdp_param_group.fsdp_params 238 fsdp_param_group.comm_ctx.lazy_init() 239 240 # Run one unshard to initialize metadata 241 fsdp_param_group.unshard() 242 fsdp_param_group.wait_for_unshard() 243 fsdp_param_group.reshard() 244 245 # Run the foreach reduce-scatter (including copy-in and view-out) 246 torch.manual_seed(42) 247 unsharded_grads = [torch.ones_like(param) * self.rank for param in orig_params] 248 group = fsdp_param_group.mesh_info.shard_process_group 249 self.assertEqual(group.size(), self.world_size) 250 all_reduce_stream = torch.cuda.Stream() 251 ( 252 reduce_scatter_input, 253 reduce_scatter_event, 254 post_reduce_event, 255 _, 256 ) = foreach_reduce( 257 fsdp_params, 258 unsharded_grads, 259 group, 260 reduce_scatter_stream, 261 orig_dtype=orig_params[0].dtype, 262 reduce_dtype=reduce_scatter_dtype, 263 device=self.device, 264 reduce_scatter_reduce_op=None, 265 all_reduce_group=None, 266 all_reduce_stream=all_reduce_stream, 267 all_reduce_grads=True, 268 partial_reduce_output=None, 269 ) 270 torch.cuda.current_stream().wait_event(post_reduce_event) 271 272 # Check reduce-scatter correctness 273 predivide_factor, postdivide_factor = _get_gradient_divide_factors( 274 group, None, reduce_scatter_dtype 275 ) 276 reduced_grads = [grad.detach().clone() for grad in unsharded_grads] 277 for grad in reduced_grads: 278 _div_if_needed(grad, predivide_factor) 279 dist.all_reduce( 280 grad, 281 group=group, 282 op=dist.ReduceOp.AVG if predivide_factor is None else dist.ReduceOp.SUM, 283 ) 284 _div_if_needed(grad, postdivide_factor) 285 for fsdp_param, reduced_grad in zip(fsdp_params, reduced_grads): 286 sharded_grad = fsdp_param.sharded_param.grad 287 self.assertIsInstance(sharded_grad, DTensor) 288 self.assertEqual(sharded_grad.full_tensor(), reduced_grad) 289 290 291class TestFullyShardCommunication(FSDPTest): 292 @property 293 def world_size(self) -> int: 294 return min(4, torch.cuda.device_count()) 295 296 @skip_if_lt_x_gpu(2) 297 def test_fully_shard_communication_count(self): 298 """ 299 Tests that FSDP issues the expected number of all-gathers and 300 reduce-scatters during forward and backward. 301 """ 302 self.run_subtests( 303 {"reshard_after_forward": [True, False, 2]}, 304 self._test_communication_count, 305 ) 306 307 def _test_communication_count( 308 self, 309 reshard_after_forward: Union[bool, int], 310 ): 311 torch.manual_seed(42) 312 model_args = ModelArgs() 313 model = Transformer(model_args) 314 fully_shard_fn = functools.partial( 315 fully_shard, reshard_after_forward=reshard_after_forward 316 ) 317 num_blocks = 0 318 for module in model.modules(): 319 if isinstance(module, TransformerBlock): 320 fully_shard_fn(module) 321 num_blocks += 1 322 fully_shard_fn(model) 323 # We construct `num_blocks` plus 1 FSDP states/communication groups 324 325 torch.manual_seed(42 + self.rank) 326 inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda") 327 with CommDebugMode() as fwd_comm_mode: 328 loss = model(inp) 329 fwd_comm_counts = fwd_comm_mode.get_comm_counts() 330 self.assertEqual(len(fwd_comm_counts), 1) 331 self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_blocks + 1) 332 with CommDebugMode() as bwd_comm_mode: 333 loss.sum().backward() 334 bwd_comm_counts = bwd_comm_mode.get_comm_counts() 335 if reshard_after_forward is False: 336 self.assertEqual(len(bwd_comm_counts), 1) 337 else: 338 # The root always does not reshard after forward 339 self.assertEqual(len(bwd_comm_counts), 2) 340 self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_blocks) 341 self.assertEqual( 342 bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_blocks + 1 343 ) 344 345 @skip_if_lt_x_gpu(2) 346 def test_manual_reshard_with_reshard_after_forward_false(self): 347 """ 348 Tests that we can manually call ``reshard`` on FSDP modules that were 349 initialized with ``reshard_after_forward=False`` and still run unshard. 350 """ 351 torch.manual_seed(42) 352 model_args = ModelArgs() 353 model = Transformer(model_args) 354 for module in model.modules(): 355 if isinstance(module, TransformerBlock): 356 fully_shard(module, reshard_after_forward=False) 357 model = fully_shard(model, reshard_after_forward=False) 358 num_fsdp_modules = sum( 359 isinstance(module, FSDPModule) for module in model.modules() 360 ) 361 362 torch.manual_seed(42 + self.rank) 363 inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda") 364 with CommDebugMode() as fwd_comm_mode: 365 loss = model(inp) 366 fwd_comm_counts = fwd_comm_mode.get_comm_counts() 367 self.assertEqual(len(fwd_comm_counts), 1) 368 self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_fsdp_modules) 369 370 for module in model.modules(): 371 if isinstance(module, FSDPModule): 372 module.reshard() 373 374 with CommDebugMode() as bwd_comm_mode: 375 loss.sum().backward() 376 bwd_comm_counts = bwd_comm_mode.get_comm_counts() 377 self.assertEqual(len(bwd_comm_counts), 2) 378 self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_fsdp_modules) 379 self.assertEqual( 380 bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_fsdp_modules 381 ) 382 383 @skip_if_lt_x_gpu(2) 384 def test_set_reduce_scatter_divide_factor(self): 385 self.run_subtests( 386 {"divide_factor": [self.world_size * 2, self.world_size]}, 387 self._test_set_reduce_scatter_divide_factor, 388 ) 389 390 def _test_set_reduce_scatter_divide_factor(self, divide_factor: float): 391 torch.manual_seed(42) 392 model_args = ModelArgs(dropout_p=0.0, weight_tying=False) 393 model = Transformer(model_args) 394 ref_model = copy.deepcopy(model).cuda() 395 ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2) 396 for module in model.modules(): 397 if isinstance(module, TransformerBlock): 398 fully_shard(module, reshard_after_forward=False) 399 model = fully_shard(model, reshard_after_forward=False) 400 optim = torch.optim.AdamW(model.parameters(), lr=1e-2) 401 model.set_reduce_scatter_divide_factor(divide_factor) 402 403 torch.manual_seed(42 + self.rank) 404 inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda") 405 406 for iter_idx in range(10): 407 ref_loss = ref_model(inp).sum() 408 ref_loss.backward() 409 for param in ref_model.parameters(): 410 param.grad.mul_(1.0 / divide_factor) 411 dist.all_reduce(param.grad) 412 loss = model(inp).sum() 413 loss.backward() 414 ref_optim.step() 415 optim.step() 416 ref_optim.zero_grad() 417 optim.zero_grad() 418 self.assertEqual(ref_loss, loss) 419 check_sharded_parity(self, ref_model, model) 420 421 422class TestFullyShardPrefetch(FSDPTest): 423 @property 424 def world_size(self) -> int: 425 return min(4, torch.cuda.device_count()) 426 427 @skip_if_lt_x_gpu(2) 428 def test_fully_shard_backward_prefetch(self): 429 # Activation checkpointing should not affect the expected FSDP events 430 self.run_subtests( 431 { 432 "reshard_after_forward": [True, False, 2], 433 "checkpoint_impl": [None, "utils", "composable"], 434 }, 435 self._test_backward_prefetch_forward_backward, 436 ) 437 self.run_subtests( 438 { 439 "reshard_after_forward": [True, False, 2], 440 "checkpoint_impl": [None, "utils", "composable"], 441 }, 442 self._test_backward_prefetch_multi_forward, 443 ) 444 self._test_backward_prefetch_unused_in_backward(True) 445 446 def _test_backward_prefetch_forward_backward( 447 self, reshard_after_forward: Union[bool, int], checkpoint_impl: Optional[str] 448 ): 449 n_layers = 3 450 model, optim, inp = self._init_transformer( 451 n_layers, reshard_after_forward, checkpoint_impl 452 ) 453 events: List[EventType] = [] 454 unshard_with_record = self._get_unshard_with_record( 455 FSDPParamGroup.unshard, events 456 ) 457 post_backward_with_record = self._get_post_backward_with_record( 458 FSDPParamGroup.post_backward, events 459 ) 460 # Check the order for normal 1 forward, 1 backward, 1 optimizer step 461 with patch_unshard(unshard_with_record), patch_post_backward( 462 post_backward_with_record 463 ): 464 for iter_idx in range(3): 465 loss = model(inp) 466 expected_events = [ 467 ("unshard", "", TrainingState.FORWARD), # root 468 ("unshard", "layers.0", TrainingState.FORWARD), 469 ("unshard", "layers.1", TrainingState.FORWARD), 470 ("unshard", "layers.2", TrainingState.FORWARD), 471 ] 472 self.assertEqual(events, expected_events) 473 events.clear() 474 loss.sum().backward() 475 expected_events = [ 476 # Root does not reshard after forward so there is no 477 # unshard event for it in backward 478 ("unshard", "layers.2", TrainingState.PRE_BACKWARD), 479 # Explicit backward prefetching moves the unshards early 480 # by one module (note how swapping each unshard down one 481 # event would give the natural event order) 482 ("unshard", "layers.1", TrainingState.PRE_BACKWARD), 483 ("post_backward", "layers.2", TrainingState.POST_BACKWARD), 484 ("unshard", "layers.0", TrainingState.PRE_BACKWARD), 485 ("post_backward", "layers.1", TrainingState.POST_BACKWARD), 486 ("post_backward", "layers.0", TrainingState.POST_BACKWARD), 487 ("post_backward", "", TrainingState.POST_BACKWARD), 488 ] 489 if reshard_after_forward is False: 490 # No reshard after forward means no backward unshards 491 expected_events = [e for e in expected_events if e[0] != "unshard"] 492 self.assertEqual(events, expected_events) 493 events.clear() 494 optim.step() 495 optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) 496 497 def _test_backward_prefetch_multi_forward( 498 self, reshard_after_forward: Union[bool, int], checkpoint_impl: Optional[str] 499 ): 500 n_layers = 3 501 model, optim, inp = self._init_transformer( 502 n_layers, reshard_after_forward, checkpoint_impl 503 ) 504 events: List[EventType] = [] 505 unshard_with_record = self._get_unshard_with_record( 506 FSDPParamGroup.unshard, events 507 ) 508 post_backward_with_record = self._get_post_backward_with_record( 509 FSDPParamGroup.post_backward, events 510 ) 511 # Check the order for multiple forwards before 1 backward 512 with patch_unshard(unshard_with_record), patch_post_backward( 513 post_backward_with_record 514 ): 515 loss1 = model(inp) 516 loss2 = model(inp) 517 expected_events = [ 518 ("unshard", "", TrainingState.FORWARD), # root 519 ("unshard", "layers.0", TrainingState.FORWARD), 520 ("unshard", "layers.1", TrainingState.FORWARD), 521 ("unshard", "layers.2", TrainingState.FORWARD), 522 # Root does not reshard after forward so there is not another 523 # unshard event for it 524 ("unshard", "layers.0", TrainingState.FORWARD), 525 ("unshard", "layers.1", TrainingState.FORWARD), 526 ("unshard", "layers.2", TrainingState.FORWARD), 527 ] 528 if reshard_after_forward is False: 529 # No reshard after forward means no second set of unshards 530 expected_events = expected_events[:-3] 531 self.assertEqual(events, expected_events) 532 events.clear() 533 (loss1 + loss2).sum().backward() 534 expected_events = [ 535 # Same as the single forward/backward case except the root's 536 # post-backward does not run until the end of backward in the 537 # final callback (since the input not requiring gradient means 538 # that we do not have a tensor on which to hook for 539 # post-backward) 540 ("unshard", "layers.2", TrainingState.PRE_BACKWARD), 541 ("unshard", "layers.1", TrainingState.PRE_BACKWARD), 542 ("post_backward", "layers.2", TrainingState.POST_BACKWARD), 543 ("unshard", "layers.0", TrainingState.PRE_BACKWARD), 544 ("post_backward", "layers.1", TrainingState.POST_BACKWARD), 545 ("post_backward", "layers.0", TrainingState.POST_BACKWARD), 546 ] 547 if reshard_after_forward is False: 548 # No reshard after forward means no backward unshards 549 expected_events = [e for e in expected_events if e[0] != "unshard"] 550 # However, the post-backward reshards, so the second set of 551 # unshards will run as real ops 552 expected_events += [ 553 # Repeat the same pattern except with the root's post-backward 554 # at the end since the final callback runs 555 ("unshard", "layers.2", TrainingState.PRE_BACKWARD), 556 ("unshard", "layers.1", TrainingState.PRE_BACKWARD), 557 ("post_backward", "layers.2", TrainingState.POST_BACKWARD), 558 ("unshard", "layers.0", TrainingState.PRE_BACKWARD), 559 ("post_backward", "layers.1", TrainingState.POST_BACKWARD), 560 ("post_backward", "layers.0", TrainingState.POST_BACKWARD), 561 ("post_backward", "", TrainingState.POST_BACKWARD), 562 ] 563 self.assertEqual(events, expected_events) 564 events.clear() 565 566 def _test_backward_prefetch_unused_in_backward( 567 self, reshard_after_forward: Union[bool, int] 568 ): 569 """ 570 Test a model with a linear module then a split into two linear modules, 571 where we run backward through one path first before the other, meaning 572 that (1) only one linear of the two split is used per backward and (2) 573 the initial shared linear is used in both backwards. 574 """ 575 dim = 8 576 model = nn.Sequential(nn.Linear(dim, dim), DoubleLinear(dim)) 577 fully_shard(model[0], reshard_after_forward=reshard_after_forward) 578 fully_shard(model[1].lin1, reshard_after_forward=reshard_after_forward) 579 fully_shard(model[1].lin2, reshard_after_forward=reshard_after_forward) 580 fully_shard(model, reshard_after_forward=reshard_after_forward) 581 inp = torch.randn((4, dim), device="cuda") 582 events: List[EventType] = [] 583 unshard_with_record = self._get_unshard_with_record( 584 FSDPParamGroup.unshard, events 585 ) 586 post_backward_with_record = self._get_post_backward_with_record( 587 FSDPParamGroup.post_backward, events 588 ) 589 with patch_unshard(unshard_with_record), patch_post_backward( 590 post_backward_with_record 591 ): 592 loss1, loss2 = model(inp) 593 expected_events = [ 594 # Root has no parameters, so it does not have an unshard 595 ("unshard", "0", TrainingState.FORWARD), 596 ("unshard", "1.lin1", TrainingState.FORWARD), 597 ("unshard", "1.lin2", TrainingState.FORWARD), 598 ] 599 self.assertEqual(events, expected_events) 600 events.clear() 601 602 model.set_is_last_backward(False) 603 loss2.sum().backward(retain_graph=True) 604 expected_events = [ 605 ("unshard", "1.lin2", TrainingState.PRE_BACKWARD), 606 # NOTE: This `1.lin1` unshard is a mistargeted prefetch. 607 ("unshard", "1.lin1", TrainingState.PRE_BACKWARD), 608 ("post_backward", "1.lin2", TrainingState.POST_BACKWARD), 609 ("unshard", "0", TrainingState.PRE_BACKWARD), 610 ("post_backward", "0", TrainingState.POST_BACKWARD), 611 ] 612 self.assertEqual(events, expected_events) 613 events.clear() 614 615 model.set_is_last_backward(True) 616 loss1.sum().backward() 617 expected_events = [ 618 # NOTE: `1.lin1` is already unsharded from the mistargeted 619 # prefetch in the first backward. 620 # Prefetch `0` 621 ("unshard", "0", TrainingState.PRE_BACKWARD), 622 ("post_backward", "1.lin1", TrainingState.POST_BACKWARD), 623 ("post_backward", "0", TrainingState.POST_BACKWARD), 624 ] 625 self.assertEqual(events, expected_events) 626 events.clear() 627 628 @skip_if_lt_x_gpu(2) 629 def test_set_modules_to_forward_prefetch(self): 630 n_layers = 4 631 reshard_after_forward = True 632 checkpoint_impl = "utils" 633 model, _, inp = self._init_transformer( 634 n_layers, reshard_after_forward, checkpoint_impl 635 ) 636 637 def set_forward_prefetch(model: Transformer, num_to_prefetch: int) -> None: 638 # Use model-specific knowledge to configure forward prefetching: 639 # each transformer block (layer) prefetches for the next few 640 for i, layer in enumerate(model.layers): 641 if i >= len(model.layers) - num_to_prefetch: 642 break 643 layers_to_prefetch = [ 644 model.layers[i + j] for j in range(1, num_to_prefetch + 1) 645 ] 646 layer.set_modules_to_forward_prefetch(layers_to_prefetch) 647 648 events: List[EventType] = [] 649 unshard_with_record = self._get_unshard_with_record( 650 FSDPParamGroup.unshard, events 651 ) 652 reshard_with_record = self._get_reshard_with_record( 653 FSDPParamGroup.reshard, events 654 ) 655 post_backward_with_record = self._get_post_backward_with_record( 656 FSDPParamGroup.post_backward, events 657 ) 658 expected_backward_events = [ 659 # Default backward prefetching 660 ("unshard", "layers.3", TrainingState.PRE_BACKWARD), 661 ("unshard", "layers.2", TrainingState.PRE_BACKWARD), 662 ("reshard", "layers.3", TrainingState.POST_BACKWARD), 663 ("post_backward", "layers.3", TrainingState.POST_BACKWARD), 664 ("unshard", "layers.1", TrainingState.PRE_BACKWARD), 665 ("reshard", "layers.2", TrainingState.POST_BACKWARD), 666 ("post_backward", "layers.2", TrainingState.POST_BACKWARD), 667 ("unshard", "layers.0", TrainingState.PRE_BACKWARD), 668 ("reshard", "layers.1", TrainingState.POST_BACKWARD), 669 ("post_backward", "layers.1", TrainingState.POST_BACKWARD), 670 ("reshard", "layers.0", TrainingState.POST_BACKWARD), 671 ("post_backward", "layers.0", TrainingState.POST_BACKWARD), 672 ("reshard", "", TrainingState.POST_BACKWARD), 673 ("post_backward", "", TrainingState.POST_BACKWARD), 674 ] 675 with patch_unshard(unshard_with_record), patch_reshard( 676 reshard_with_record 677 ), patch_post_backward(post_backward_with_record): 678 set_forward_prefetch(model, num_to_prefetch=1) 679 loss = model(inp) 680 expected_forward_events = [ 681 ("unshard", "", TrainingState.FORWARD), 682 # `layers.i` prefetches `layers.i+1` 683 ("unshard", "layers.0", TrainingState.FORWARD), 684 ("unshard", "layers.1", TrainingState.FORWARD), 685 ("reshard", "layers.0", TrainingState.FORWARD), 686 ("unshard", "layers.2", TrainingState.FORWARD), 687 ("reshard", "layers.1", TrainingState.FORWARD), 688 ("unshard", "layers.3", TrainingState.FORWARD), 689 ("reshard", "layers.2", TrainingState.FORWARD), 690 ("reshard", "layers.3", TrainingState.FORWARD), 691 ] 692 self.assertEqual(events, expected_forward_events) 693 events.clear() 694 loss.sum().backward() 695 self.assertEqual(events, expected_backward_events) 696 events.clear() 697 698 set_forward_prefetch(model, num_to_prefetch=2) 699 loss = model(inp) 700 expected_forward_events = [ 701 ("unshard", "", TrainingState.FORWARD), 702 # `layers.i` prefetches `layers.i+1` and `layers.i+2` 703 ("unshard", "layers.0", TrainingState.FORWARD), 704 ("unshard", "layers.1", TrainingState.FORWARD), 705 ("unshard", "layers.2", TrainingState.FORWARD), 706 ("reshard", "layers.0", TrainingState.FORWARD), 707 ("unshard", "layers.3", TrainingState.FORWARD), 708 ("reshard", "layers.1", TrainingState.FORWARD), 709 ("reshard", "layers.2", TrainingState.FORWARD), 710 ("reshard", "layers.3", TrainingState.FORWARD), 711 ] 712 self.assertEqual(events, expected_forward_events) 713 events.clear() 714 loss.sum().backward() 715 self.assertEqual(events, expected_backward_events) 716 events.clear() 717 718 @skip_if_lt_x_gpu(2) 719 def test_set_modules_to_backward_prefetch(self): 720 n_layers = 4 721 reshard_after_forward = True 722 checkpoint_impl = "utils" 723 model, _, inp = self._init_transformer( 724 n_layers, reshard_after_forward, checkpoint_impl 725 ) 726 727 def set_backward_prefetch(model: Transformer, num_to_prefetch: int) -> None: 728 # Use model-specific knowledge to configure backward prefetching: 729 # each transformer block (layer) prefetches for the previous few 730 for i, layer in enumerate(model.layers): 731 if i < num_to_prefetch: 732 continue 733 layers_to_prefetch = [ 734 model.layers[i - j] for j in range(1, num_to_prefetch + 1) 735 ] 736 layer.set_modules_to_backward_prefetch(layers_to_prefetch) 737 738 events: List[EventType] = [] 739 unshard_with_record = self._get_unshard_with_record( 740 FSDPParamGroup.unshard, events 741 ) 742 reshard_with_record = self._get_reshard_with_record( 743 FSDPParamGroup.reshard, events 744 ) 745 post_backward_with_record = self._get_post_backward_with_record( 746 FSDPParamGroup.post_backward, events 747 ) 748 expected_forward_events = [ 749 # Default forward prefetching 750 ("unshard", "", TrainingState.FORWARD), # root 751 ("unshard", "layers.0", TrainingState.FORWARD), 752 ("reshard", "layers.0", TrainingState.FORWARD), 753 ("unshard", "layers.1", TrainingState.FORWARD), 754 ("reshard", "layers.1", TrainingState.FORWARD), 755 ("unshard", "layers.2", TrainingState.FORWARD), 756 ("reshard", "layers.2", TrainingState.FORWARD), 757 ("unshard", "layers.3", TrainingState.FORWARD), 758 ("reshard", "layers.3", TrainingState.FORWARD), 759 ] 760 with patch_unshard(unshard_with_record), patch_reshard( 761 reshard_with_record 762 ), patch_post_backward(post_backward_with_record): 763 set_backward_prefetch(model, num_to_prefetch=1) 764 loss = model(inp) 765 self.assertEqual(events, expected_forward_events) 766 events.clear() 767 loss.sum().backward() 768 expected_backward_events = [ 769 # Root prefetches `layers.3` per default 770 ("unshard", "layers.3", TrainingState.PRE_BACKWARD), 771 # `layers.i` prefetches for `layers.i-1` (same as default) 772 ("unshard", "layers.2", TrainingState.PRE_BACKWARD), 773 ("reshard", "layers.3", TrainingState.POST_BACKWARD), 774 ("post_backward", "layers.3", TrainingState.POST_BACKWARD), 775 ("unshard", "layers.1", TrainingState.PRE_BACKWARD), 776 ("reshard", "layers.2", TrainingState.POST_BACKWARD), 777 ("post_backward", "layers.2", TrainingState.POST_BACKWARD), 778 ("unshard", "layers.0", TrainingState.PRE_BACKWARD), 779 ("reshard", "layers.1", TrainingState.POST_BACKWARD), 780 ("post_backward", "layers.1", TrainingState.POST_BACKWARD), 781 ("reshard", "layers.0", TrainingState.POST_BACKWARD), 782 ("post_backward", "layers.0", TrainingState.POST_BACKWARD), 783 ("reshard", "", TrainingState.POST_BACKWARD), 784 ("post_backward", "", TrainingState.POST_BACKWARD), 785 ] 786 self.assertEqual(events, expected_backward_events) 787 events.clear() 788 789 set_backward_prefetch(model, num_to_prefetch=2) 790 loss = model(inp) 791 self.assertEqual(events, expected_forward_events) 792 events.clear() 793 loss.sum().backward() 794 expected_backward_events = [ 795 # Root prefetches `layers.3` per default 796 ("unshard", "layers.3", TrainingState.PRE_BACKWARD), 797 # `layers.i` prefetches for `layers.i-1` and `layers.i-2` 798 ("unshard", "layers.2", TrainingState.PRE_BACKWARD), 799 ("unshard", "layers.1", TrainingState.PRE_BACKWARD), 800 ("reshard", "layers.3", TrainingState.POST_BACKWARD), 801 ("post_backward", "layers.3", TrainingState.POST_BACKWARD), 802 ("unshard", "layers.0", TrainingState.PRE_BACKWARD), 803 ("reshard", "layers.2", TrainingState.POST_BACKWARD), 804 ("post_backward", "layers.2", TrainingState.POST_BACKWARD), 805 ("reshard", "layers.1", TrainingState.POST_BACKWARD), 806 ("post_backward", "layers.1", TrainingState.POST_BACKWARD), 807 ("reshard", "layers.0", TrainingState.POST_BACKWARD), 808 ("post_backward", "layers.0", TrainingState.POST_BACKWARD), 809 ("reshard", "", TrainingState.POST_BACKWARD), 810 ("post_backward", "", TrainingState.POST_BACKWARD), 811 ] 812 self.assertEqual(events, expected_backward_events) 813 events.clear() 814 815 @skip_if_lt_x_gpu(2) 816 def test_fully_shard_multi_module_backward_prefetch(self): 817 n_layers = 5 818 model_args = ModelArgs(n_layers=n_layers, checkpoint_activations=True) 819 model = Transformer(model_args) 820 for i in range(n_layers): 821 if i == 0: 822 fully_shard(model.layers[i]) 823 elif i % 2 == 1: 824 fully_shard([model.layers[i], model.layers[i + 1]]) 825 fully_shard([model.tok_embeddings, model.pos_embeddings]) 826 fully_shard([model.norm, model.output], reshard_after_forward=False) 827 fully_shard(model) 828 optim = torch.optim.AdamW(model.parameters(), lr=1e-2) 829 830 events: List[EventType] = [] 831 unshard_with_record = self._get_unshard_with_record( 832 FSDPParamGroup.unshard, events 833 ) 834 post_backward_with_record = self._get_post_backward_with_record( 835 FSDPParamGroup.post_backward, events 836 ) 837 inp = torch.randint( 838 0, model_args.vocab_size, (2, model_args.max_seq_len), device="cuda" 839 ) 840 with patch_unshard(unshard_with_record), patch_post_backward( 841 post_backward_with_record 842 ): 843 for iter_idx in range(3): 844 loss = model(inp) 845 expected_events = [ 846 ( 847 "unshard", 848 "tok_embeddings, pos_embeddings", 849 TrainingState.FORWARD, 850 ), 851 ("unshard", "layers.0", TrainingState.FORWARD), 852 ("unshard", "layers.1, layers.2", TrainingState.FORWARD), 853 ("unshard", "layers.3, layers.4", TrainingState.FORWARD), 854 ("unshard", "norm, output", TrainingState.FORWARD), 855 ] 856 self.assertEqual(events, expected_events) 857 events.clear() 858 loss.sum().backward() 859 expected_events = [ 860 # (norm, output) does not reshard after forward, so there is 861 # no unshard to begin backward 862 ("unshard", "layers.3, layers.4", TrainingState.PRE_BACKWARD), 863 ("post_backward", "norm, output", TrainingState.POST_BACKWARD), 864 ("unshard", "layers.1, layers.2", TrainingState.PRE_BACKWARD), 865 ( 866 "post_backward", 867 "layers.3, layers.4", 868 TrainingState.POST_BACKWARD, 869 ), 870 ("unshard", "layers.0", TrainingState.PRE_BACKWARD), 871 ( 872 "post_backward", 873 "layers.1, layers.2", 874 TrainingState.POST_BACKWARD, 875 ), 876 ( 877 "unshard", 878 "tok_embeddings, pos_embeddings", 879 TrainingState.PRE_BACKWARD, 880 ), 881 ("post_backward", "layers.0", TrainingState.POST_BACKWARD), 882 ( 883 "post_backward", 884 "tok_embeddings, pos_embeddings", 885 TrainingState.POST_BACKWARD, 886 ), 887 ] 888 events.clear() 889 optim.step() 890 optim.zero_grad() 891 892 @skip_if_lt_x_gpu(2) 893 def test_fully_shard_multi_module_unused_module(self): 894 class ModuleWithUnusedLinear(nn.Module): 895 def __init__(self) -> None: 896 super().__init__() 897 self.unused_lin = nn.Linear(1, 1) 898 self.lin = nn.Linear(16, 16) 899 900 def forward(self, x: torch.Tensor) -> torch.Tensor: 901 return nn.functional.relu(self.lin(x)) 902 903 model = nn.Sequential( 904 ModuleWithUnusedLinear(), ModuleWithUnusedLinear(), nn.Linear(16, 16) 905 ) 906 fully_shard([model[0].unused_lin, model[0].lin], reshard_after_forward=True) 907 fully_shard([model[1].unused_lin, model[1].lin], reshard_after_forward=True) 908 fully_shard(model) 909 optim = torch.optim.AdamW(model.parameters(), lr=1e-2) 910 911 events: List[EventType] = [] 912 unshard_with_record = self._get_unshard_with_record( 913 FSDPParamGroup.unshard, events 914 ) 915 post_backward_with_record = self._get_post_backward_with_record( 916 FSDPParamGroup.post_backward, events 917 ) 918 inp = torch.randn((2, 16), device="cuda") 919 with patch_unshard(unshard_with_record), patch_post_backward( 920 post_backward_with_record 921 ): 922 for iter_idx in range(3): 923 loss = model(inp) 924 expected_events = [ 925 ("unshard", "", TrainingState.FORWARD), 926 ("unshard", "0.unused_lin, 0.lin", TrainingState.FORWARD), 927 ("unshard", "1.unused_lin, 1.lin", TrainingState.FORWARD), 928 ] 929 self.assertEqual(events, expected_events) 930 events.clear() 931 loss.sum().backward() 932 expected_events = [ 933 # Since both `model[0]` and `model[1]` have unused modules 934 # that never ran forward, they do not reshard after forward 935 # despite setting it to `True`. Check that there are no 936 # unshards in backward. 937 ( 938 "post_backward", 939 "1.unused_lin, 1.lin", 940 TrainingState.POST_BACKWARD, 941 ), 942 ( 943 "post_backward", 944 "0.unused_lin, 0.lin", 945 TrainingState.POST_BACKWARD, 946 ), 947 ("post_backward", "", TrainingState.POST_BACKWARD), 948 ] 949 events.clear() 950 optim.step() 951 optim.zero_grad() 952 953 def _init_transformer( 954 self, 955 n_layers: int, 956 reshard_after_forward: Union[bool, int], 957 checkpoint_impl: Optional[str], 958 ): 959 model_args = ModelArgs( 960 n_layers=n_layers, checkpoint_activations=(checkpoint_impl == "utils") 961 ) 962 model = Transformer(model_args) 963 for module in model.modules(): 964 if isinstance(module, TransformerBlock): 965 if checkpoint_impl == "composable": 966 checkpoint(module) 967 fully_shard(module, reshard_after_forward=reshard_after_forward) 968 fully_shard(model, reshard_after_forward=reshard_after_forward) 969 optim = torch.optim.Adam(model.parameters(), lr=1e-2) 970 inp = torch.randint( 971 0, model_args.vocab_size, (2, model_args.max_seq_len), device="cuda" 972 ) 973 return model, optim, inp 974 975 def _get_unshard_with_record( 976 self, orig_unshard: Callable, events: List[EventType] 977 ) -> Callable: 978 def unshard_with_record(self, *args, **kwargs): 979 nonlocal events 980 if ( 981 self._all_gather_result is None 982 and self._sharded_state != ShardedState.UNSHARDED 983 ): # skip no-ops 984 events.append(("unshard", self._module_fqn, self._training_state)) 985 return orig_unshard(self, *args, **kwargs) 986 987 return unshard_with_record 988 989 def _get_reshard_with_record( 990 self, orig_reshard: Callable, events: List[EventType] 991 ) -> Callable: 992 def reshard_with_record(self, *args, **kwargs): 993 nonlocal events 994 if ( 995 self._training_state == TrainingState.FORWARD 996 and not self._reshard_after_forward 997 ): # skip no-ops 998 return 999 events.append(("reshard", self._module_fqn, self._training_state)) 1000 return orig_reshard(self, *args, **kwargs) 1001 1002 return reshard_with_record 1003 1004 def _get_post_backward_with_record( 1005 self, orig_post_backward: Callable, events: List[EventType] 1006 ) -> Callable: 1007 def post_backward_with_record(self, *args, **kwargs): 1008 nonlocal events 1009 ret = orig_post_backward(self, *args, **kwargs) 1010 # Use training state after running post-backward to check that the 1011 # state is transitioned to `POST_BACKWARD` as expected 1012 events.append(("post_backward", self._module_fqn, self._training_state)) 1013 return ret 1014 1015 return post_backward_with_record 1016 1017 1018class TestFullyShardUnshardMultiProcess(FSDPTest): 1019 @property 1020 def world_size(self) -> int: 1021 return min(torch.cuda.device_count(), 2) 1022 1023 @skip_if_lt_x_gpu(2) 1024 def test_unshard_async(self): 1025 class ReduceModule(nn.Module): 1026 def __init__(self, dim: int, mesh: DeviceMesh): 1027 super().__init__() 1028 self.mesh = mesh 1029 self.weight = nn.Parameter(torch.randn(dim, dim)) 1030 1031 def forward(self, x: torch.Tensor): 1032 y = F.relu(x @ self.weight) 1033 # NOTE: This all-reduce is not differentiable and is included 1034 # to exercise the overlap. 1035 work = dist.all_reduce(y, group=self.mesh.get_group(), async_op=True) 1036 return y, work 1037 1038 class MLPs(nn.Module): 1039 def __init__(self, dim: int): 1040 super().__init__() 1041 self.mlp1 = MLP(dim) 1042 self.mlp2 = MLP(dim) 1043 self.mlp3 = MLP(dim) 1044 1045 def forward(self, ys: List[torch.Tensor], works: List[dist.Work]): 1046 (y1, y2, y3), (work1, work2, work3) = ys, works 1047 work1.wait() 1048 z1 = self.mlp1(y1) 1049 work2.wait() 1050 z2 = self.mlp2(y2) 1051 work3.wait() 1052 z3 = self.mlp3(y3) 1053 return z1 + z2 + z3 1054 1055 class ReduceModel(nn.Module): 1056 def __init__(self, dim: int, mesh: DeviceMesh): 1057 super().__init__() 1058 self.reduce_module1 = ReduceModule(dim, mesh) 1059 self.reduce_module2 = ReduceModule(dim, mesh) 1060 self.reduce_module3 = ReduceModule(dim, mesh) 1061 self.mlps = MLPs(dim) 1062 1063 def forward(self, x: torch.Tensor): 1064 y1, work1 = self.reduce_module1(x) 1065 if isinstance(self.mlps.mlp1, FSDPModule): 1066 self.mlps.mlp1.unshard(async_op=True) 1067 y2, work2 = self.reduce_module2(x) 1068 if isinstance(self.mlps.mlp2, FSDPModule): 1069 self.mlps.mlp2.unshard(async_op=True) 1070 y3, work3 = self.reduce_module3(x) 1071 if isinstance(self.mlps.mlp3, FSDPModule): 1072 self.mlps.mlp3.unshard(async_op=True) 1073 return self.mlps([y1, y2, y3], [work1, work2, work3]) 1074 1075 mesh = init_device_mesh("cuda", (self.world_size,)) 1076 batch_size, dim = 2, 8 1077 torch.manual_seed(42) 1078 ref_model = replicate(ReduceModel(dim, mesh).cuda()) 1079 ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) 1080 torch.manual_seed(42) 1081 model = ReduceModel(dim, mesh) 1082 fully_shard(model.mlps.mlp1, reshard_after_forward=False) 1083 fully_shard(model.mlps.mlp2, reshard_after_forward=False) 1084 fully_shard(model.mlps.mlp3, reshard_after_forward=False) 1085 fully_shard(model.mlps) 1086 replicate(model.cuda()) 1087 optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True) 1088 torch.manual_seed(42 + self.rank + 1) 1089 inp = torch.randn((batch_size, dim), device="cuda") 1090 for _ in range(10): 1091 losses: List[torch.Tensor] = [] 1092 for _model, _optim in ((ref_model, ref_optim), (model, optim)): 1093 losses.append(_model(inp).sum()) 1094 losses[-1].backward() 1095 with implicit_replication(): 1096 _optim.step() 1097 _optim.zero_grad() 1098 self.assertEqual(losses[0], losses[1]) 1099 1100 1101class TestFullyShardUnshardMultiThread(FSDPTestMultiThread): 1102 @property 1103 def world_size(self) -> int: 1104 return 2 1105 1106 @unittest.skipIf(not TEST_CUDA, "no cuda") 1107 def test_unshard_no_param_group(self): 1108 # Check that we can call `unshard()` on a module with no parameter 1109 # group / no managed parameters without erroring 1110 model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)) 1111 for lin in model: 1112 fully_shard(lin) 1113 fully_shard(model) 1114 handle = model.unshard(async_op=True) 1115 handle.wait() 1116 1117 @unittest.skipIf(not TEST_CUDA, "no cuda") 1118 def test_unshard_without_lazy_init(self): 1119 torch.manual_seed(42) 1120 model = MLP(4) 1121 for param in model.parameters(): 1122 dist.broadcast(param, src=0) 1123 ref_model = copy.deepcopy(model) 1124 fully_shard(model) 1125 model.unshard() # no lazy init yet 1126 for ref_param, param in zip(ref_model.parameters(), model.parameters()): 1127 self.assertEqual(ref_param, param) 1128 1129 1130if __name__ == "__main__": 1131 run_tests() 1132