# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import copy
import logging
import os
import sys
import tempfile

from model_registry import ModelWithKwargs, MultiMLP, MultiMLPWithDw
from schedule_registry import ScheduleUnbalanced, ScheduleVShaped, ScheduleWithW

import torch
import torch.distributed as dist
from torch.distributed.pipelining import (
    _ScheduleForwardOnly,
    pipeline,
    PipelineStage,
    Schedule1F1B,
    ScheduleFlexibleInterleaved1F1B,
    ScheduleGPipe,
    ScheduleInterleaved1F1B,
    ScheduleInterleavedZeroBubble,
    ScheduleLoopedBFS,
)
from torch.distributed.pipelining.schedules import _PipelineScheduleRuntime
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
    MultiProcContinousTest,
    requires_nccl,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skip_but_pass_in_sandcastle_if,
)


logger = logging.getLogger(__name__)

d_hid = 512
batch_size = 256

torch.manual_seed(0)


class ScheduleTest(MultiProcContinousTest):
    @classmethod
    def backend_str(cls) -> str:
        # Testing with NCCL backend
        return "nccl"

    @classmethod
    def setUpClass(cls):
        """
        Class-scope test fixture. Run once for entire test class, before any test starts.
        Set up the device.
        """
        super().setUpClass()
        dev_id = cls.rank % torch.cuda.device_count()
        cls.device = torch.device(f"cuda:{dev_id}")

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [_ScheduleForwardOnly])
    def test_forward_only(self, ScheduleClass):
        mod = MultiMLP(d_hid, n_layers=self.world_size)
        mod.to(self.device)

        mod_ref = copy.deepcopy(mod)

        x = torch.randn(batch_size, d_hid, device=self.device)
        x_clone = x.clone()

        num_microbatches = 4
        x_mb = x.chunk(num_microbatches)[0]

        # Create a pipeline
        split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
        pipe = pipeline(
            mod,
            mb_args=(x_mb,),
            split_spec=split_spec,
        )

        stage = pipe.build_stage(
            self.rank,
            self.device,
        )

        # Attach to a schedule
        schedule = ScheduleClass(stage, num_microbatches)

        # Run
        num_iters = 20
        for _ in range(num_iters):
            if self.rank == 0:
                schedule.step(x)
                dist.recv(x, src=self.world_size - 1)
            elif self.rank == self.world_size - 1:
                out = schedule.step()
                dist.send(out, dst=0)
            else:
                schedule.step()

        # Validate pipelined output is the same as reference model
        if self.rank == self.world_size - 1:
            for _ in range(num_iters):
                x_clone = mod_ref(x_clone)

            torch.testing.assert_close(x_clone, out)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
    def test_multi_iter(self, ScheduleClass):
        mod = MultiMLP(d_hid, n_layers=self.world_size)
        mod.to(self.device)

        x = torch.randn(batch_size, d_hid, device=self.device)
        target = torch.randn(batch_size, d_hid, device=self.device)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        chunks = 4
        x_mb = x.chunk(chunks)[0]

        # Create a pipeline
        split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
        pipe = pipeline(
            mod,
            mb_args=(x_mb,),
            split_spec=split_spec,
        )

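        # Build the stage owned by this rank from the traced pipe; with one stage per
        # rank, the stage index is simply the rank, and the stage is placed on this
        # rank's device so the schedule below can drive it.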
        stage = pipe.build_stage(
            self.rank,
            self.device,
        )

        # Attach to a schedule
        schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn)

        # Run
        for _ in range(20):
            if self.rank == 0:
                schedule.step(x)
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
    def test_kwargs_with_tracer(self, ScheduleClass):
        mod = ModelWithKwargs(d_hid)
        mod.to(self.device)

        x = torch.randn(batch_size, d_hid, device=self.device)
        y = torch.randn(batch_size, d_hid, device=self.device)
        target = torch.randn(batch_size, d_hid, device=self.device)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        chunks = 4
        x_mb = x.chunk(chunks)[0]
        y_mb = y.chunk(chunks)[0]

        pipe = pipeline(
            mod,
            mb_args=(x_mb,),
            mb_kwargs={"y": y_mb},
        )

        stage = pipe.build_stage(
            self.rank,
            self.device,
        )

        # Attach to a schedule
        schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn)

        # Run
        if self.rank == 0:
            schedule.step(x, y=y)
        elif self.rank == self.world_size - 1:
            losses = []
            out = schedule.step(target=target, losses=losses)
        else:
            schedule.step()

        dist.barrier()

        # Last rank checks result
        if self.rank == self.world_size - 1:
            ref_out = mod(x, y=y)
            ref_loss = loss_fn(ref_out, target)
            pipe_loss = sum(losses)
            torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=5e-3)
            torch.testing.assert_close(pipe_loss, ref_loss)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
    @parametrize("ModelClass", [MultiMLP])
    def test_grad_with_tracer(self, ScheduleClass, ModelClass):
        mod = ModelClass(d_hid)
        mod.to(self.device)

        ref_mod = copy.deepcopy(mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Run reference
        for _ in range(2):
            ref_mod.zero_grad()
            ref_out = ref_mod(x)
            ref_loss = loss_fn(ref_out, target)
            ref_loss.backward()

        # Create a pipeline
        chunks = 4
        x_mb = x.chunk(chunks)[0]
        split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
        pipe = pipeline(
            mod,
            mb_args=(x_mb,),
            split_spec=split_spec,
        )

        stage = pipe.build_stage(
            self.rank,
            self.device,
        )

        # Attach to a schedule
        schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn)

        # Run
        stage_module = pipe.get_stage_module(self.rank)
        for _ in range(2):
            # Zero gradients
            stage_module.zero_grad()
            if self.rank == 0:
                schedule.step(x)
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

        dist.barrier()

        # Last rank checks result
        if self.rank == self.world_size - 1:
            # Check output
            torch.testing.assert_close(out, ref_out)
            # Check loss
            # Since the reduction used in the loss function above is "sum", we use
            # "sum" here to reduce microbatch losses into a single value too.
            pipe_loss = sum(losses)
            torch.testing.assert_close(pipe_loss, ref_loss)

        # Every rank checks gradients
        for name, p in stage_module.named_parameters():
            ref_p = ref_mod.get_parameter(name)
            try:
                torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
            except AssertionError:
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
    def test_grad_with_manual(self, ScheduleClass):
        full_mod = MultiMLP(d_hid, n_layers=self.world_size)
        full_mod.to(self.device)

        ref_mod = copy.deepcopy(full_mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Run reference
        for _ in range(2):
            ref_mod.zero_grad()
            ref_out = ref_mod(x)
            ref_loss = loss_fn(ref_out, target)
            ref_loss.backward()

        # Get a submodule, e.g. `layers.0` or `layers.1`
        submod_name = f"layers.{self.rank}"
        stage_module = full_mod.get_submodule(submod_name)
        chunks = 4
        # Create a pipeline stage to wrap that submodule
        stage = PipelineStage(
            stage_module,
            self.rank,
            self.world_size,
            self.device,
            input_args=x.chunk(chunks)[0],
        )

        # Attach to a schedule
        schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn)

        # Run
        for _ in range(2):
            # Zero gradients
            stage_module.zero_grad()
            if self.rank == 0:
                schedule.step(x)
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

        dist.barrier()

        # Last rank checks result
        if self.rank == self.world_size - 1:
            # Check output
            torch.testing.assert_close(out, ref_out)
            # Check loss
            # Since the reduction used in the loss function above is "sum", we use
            # "sum" here to reduce microbatch losses into a single value too.
            pipe_loss = sum(losses)
            torch.testing.assert_close(pipe_loss, ref_loss)

        # Every rank checks gradients
        ref_submod = ref_mod.get_submodule(submod_name)
        for name, p in stage_module.named_parameters():
            ref_p = ref_submod.get_parameter(name)
            try:
                torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
            except AssertionError:
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize(
        "ScheduleClass",
        [ScheduleInterleaved1F1B, ScheduleLoopedBFS, ScheduleInterleavedZeroBubble],
    )
    @parametrize("use_new_runtime", [False, True])
    def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime):
        stages_per_rank = 2
        n_stages = stages_per_rank * self.world_size
        full_mod = MultiMLP(d_hid, n_layers=n_stages)
        full_mod.to(self.device)

        ref_mod = copy.deepcopy(full_mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Run reference
        for _ in range(2):
            ref_mod.zero_grad()
            ref_out = ref_mod(x)
            ref_loss = loss_fn(ref_out, target)
            ref_loss.backward()

        # Get a submodule, e.g. `layers.0` or `layers.1`
        stage_indices = [
            self.rank + i * self.world_size for i in range(stages_per_rank)
        ]
        print(f"Rank {self.rank} stages: {stage_indices}")
        submod_names = [f"layers.{i}" for i in stage_indices]
        stage_modules = [
            full_mod.get_submodule(submod_name) for submod_name in submod_names
        ]
        # Create a pipeline stage to wrap that submodule
        num_microbatches = (
            ScheduleClass.num_microbatches
            if hasattr(ScheduleClass, "num_microbatches")
            else 8
        )
        input_args = x.chunk(num_microbatches)[0]
        stages = [
            PipelineStage(
                stage_module,
                stage_idx,
                n_stages,
                self.device,
                input_args=input_args,
            )
            for stage_module, stage_idx in zip(stage_modules, stage_indices)
        ]

        # Attach to a schedule
        schedule = ScheduleClass(stages, num_microbatches, loss_fn=loss_fn)
        if use_new_runtime:
            old_schedule = schedule
            tmp_schedule = _PipelineScheduleRuntime(
                stages,
                num_microbatches,
                loss_fn=loss_fn,
                stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
                use_full_backward=old_schedule.use_full_backward,
            )
            tmp_schedule._load_actions(old_schedule.pipeline_order)
            # test that csv round-trip works for compute_comms schedule
            schedule = _PipelineScheduleRuntime(
                stages,
                num_microbatches,
                loss_fn=loss_fn,
                stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
                use_full_backward=old_schedule.use_full_backward,
            )
            with tempfile.NamedTemporaryFile() as f:
                tmp_schedule._dump_csv(f.name)
                f.seek(0)
                schedule._load_csv(f.name, format="compute_comms")
            one_more_schedule = _PipelineScheduleRuntime(
                stages,
                num_microbatches,
                loss_fn=loss_fn,
                stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
                use_full_backward=old_schedule.use_full_backward,
            )
            one_more_schedule._load_actions(
                schedule.pipeline_order_with_comms, format="compute_comms"
            )
            self.assertEqual(
                len(schedule.pipeline_order_with_comms),
                len(
                    one_more_schedule.pipeline_order_with_comms,
                ),
            )
            for rank in schedule.pipeline_order_with_comms:
                self.assertEqual(
                    len(schedule.pipeline_order_with_comms[rank]),
                    len(
                        one_more_schedule.pipeline_order_with_comms[rank],
                    ),
                )
                for a, b in zip(
                    schedule.pipeline_order_with_comms[rank],
                    one_more_schedule.pipeline_order_with_comms[rank],
                ):
                    self.assertEqual(a, b)

        # Run
        for _ in range(2):
            # Zero gradients
            for stage_module in stage_modules:
                stage_module.zero_grad()
            if self.rank == 0:
                schedule.step(x)
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

        dist.barrier()

        # Last rank checks result
        if self.rank == self.world_size - 1:
            # Check output
            torch.testing.assert_close(out, ref_out)
            # Check loss
            # Since the reduction used in the loss function above is "sum", we use
            # "sum" here to reduce microbatch losses into a single value too.
            pipe_loss = sum(losses)
            torch.testing.assert_close(pipe_loss, ref_loss)

        # Every rank checks gradients
        for stage_module, submod_name in zip(stage_modules, submod_names):
            # Get corresponding submodule from reference model
            ref_submod = ref_mod.get_submodule(submod_name)
            # Check gradients per parameter
            for name, p in stage_module.named_parameters():
                ref_p = ref_submod.get_parameter(name)
                try:
                    torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
                except AssertionError:
                    print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                    raise

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleWithW, ScheduleFlexibleInterleaved1F1B])
    def test_schedule_with_native_zero_bubble(self, ScheduleClass):
        print(ScheduleClass)
        if ScheduleClass is ScheduleFlexibleInterleaved1F1B:
            n_stages = 4
            num_microbatches = 8
            rank_stages = {
                0: [0, 2],
                1: [1, 3],
            }
        else:
            n_stages = ScheduleClass.n_stages
            num_microbatches = ScheduleClass.num_microbatches
            rank_stages = ScheduleClass.rank_stages

        num_steps = 4
        full_mod = MultiMLP(d_hid, n_layers=n_stages)
        full_mod.to(self.device)

        ref_mod = copy.deepcopy(full_mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        # x = torch.randn(batch_size, d_hid, device=self.device, requires_grad=True)
        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Create a pipeline stage to wrap that submodule
        input_args = x.chunk(num_microbatches)[0]
        stage_indices = rank_stages[self.rank]
        print(f"Rank {self.rank} stages: {stage_indices}")
        submod_names = [f"layers.{i}" for i in stage_indices]
        stage_modules = [
            full_mod.get_submodule(submod_name) for submod_name in submod_names
        ]
        stages = [
            PipelineStage(
                stage_module,
                stage_idx,
                n_stages,
                self.device,
                input_args=input_args,
            )
            for stage_module, stage_idx in zip(stage_modules, rank_stages[self.rank])
        ]

        schedule = ScheduleClass(
            stages, num_microbatches, loss_fn=loss_fn, enable_zero_bubble=True
        )

        # Run reference
        ref_x = x.clone().detach().requires_grad_(x.requires_grad)
        torch.testing.assert_close(x, ref_x)
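        # Neither this reference loop nor the pipelined loop below calls zero_grad,
        # so gradients accumulate across all num_steps iterations on both sides and
        # the final per-parameter grads can be compared directly.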
        for _ in range(num_steps):
            ref_out = ref_mod(ref_x)
            ref_loss = loss_fn(ref_out, target)
            ref_loss.backward()

        # Run pipelined stages
        for _ in range(num_steps):
            if self.rank == 0:
                schedule.step(x)
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

        # Every rank checks parameters compared with the reference model
        for stage_module, submod_name in zip(stage_modules, submod_names):
            # Get corresponding submodule from reference model
            ref_submod = ref_mod.get_submodule(submod_name)
            # Check gradients per parameter
            for name, p in stage_module.named_parameters():
                ref_p = ref_submod.get_parameter(name)
                try:
                    torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
                except AssertionError:
                    print(
                        f"Parameter test failed for {submod_name}.{name}: {p.grad} vs {ref_p.grad}"
                    )
                    raise

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleVShaped, ScheduleUnbalanced])
    def test_non_symmetric_stage_ids(self, ScheduleClass):
        n_stages = ScheduleClass.n_stages
        full_mod = MultiMLP(d_hid, n_layers=n_stages)
        full_mod.to(self.device)

        ref_mod = copy.deepcopy(full_mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Run reference
        for _ in range(2):
            ref_mod.zero_grad()
            ref_out = ref_mod(x)
            ref_loss = loss_fn(ref_out, target)
            ref_loss.backward()

        # Create a pipeline stage to wrap that submodule
        chunks = 1
        input_args = x.chunk(chunks)[0]
        rank_stages = ScheduleClass.rank_stages
        stage_indices = rank_stages[self.rank]
        print(f"Rank {self.rank} stages: {stage_indices}")
        submod_names = [f"layers.{i}" for i in stage_indices]
        stage_modules = [
            full_mod.get_submodule(submod_name) for submod_name in submod_names
        ]
        stages = [
            PipelineStage(
                stage_module,
                stage_idx,
                n_stages,
                self.device,
                input_args=input_args,
            )
            for stage_module, stage_idx in zip(stage_modules, rank_stages[self.rank])
        ]

        # Attach to a schedule
        stage_index_to_group_rank = {
            value: key for key, values in rank_stages.items() for value in values
        }
        schedule = ScheduleClass(
            stages, chunks, stage_index_to_group_rank, loss_fn=loss_fn
        )

        # Run
        # TODO how to better specify .step() when first and last stage are on rank 0...
        for _ in range(2):
            # Zero gradients
            for stage_module in stage_modules:
                stage_module.zero_grad()
            if self.rank == 0:
                losses = []
                out = schedule.step(x, target=target, losses=losses)
            else:
                schedule.step()

        dist.barrier()

        # Rank 0 holds the last stage in these schedules, so it checks the result
        if self.rank == 0:
            # Check output
            torch.testing.assert_close(out, ref_out)
            # Check loss
            # Since the reduction used in the loss function above is "sum", we use
            # "sum" here to reduce microbatch losses into a single value too.
            pipe_loss = sum(losses)
            torch.testing.assert_close(pipe_loss, ref_loss)

        # Every rank checks gradients
        for stage_module, submod_name in zip(stage_modules, submod_names):
            # Get corresponding submodule from reference model
            ref_submod = ref_mod.get_submodule(submod_name)
            # Check gradients per parameter
            for name, p in stage_module.named_parameters():
                ref_p = ref_submod.get_parameter(name)
                try:
                    torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
                except AssertionError:
                    print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                    raise

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleFlexibleInterleaved1F1B])
    def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass):
        stages_per_rank = 2
        n_stages = stages_per_rank * self.world_size
        full_mod = MultiMLPWithDw(d_hid, n_layers=n_stages)
        full_mod.to(self.device)

        ref_mod = copy.deepcopy(full_mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        ref_loss_fn = torch.nn.MSELoss(reduction="sum")
        full_loss_fn = torch.nn.MSELoss(reduction="sum")

        full_mod.toggle()

        # Get a submodule, e.g. `layers.0` or `layers.1`
        stage_indices = [
            self.rank + i * self.world_size for i in range(stages_per_rank)
        ]
        submod_names = [f"layers.{i}" for i in stage_indices]
        stage_modules = [
            full_mod.get_submodule(submod_name) for submod_name in submod_names
        ]

        # Run reference
        for _ in range(2):
            ref_stage_modules = [
                ref_mod.get_submodule(submod_name) for submod_name in submod_names
            ]
            for stage_module in ref_stage_modules:
                stage_module.zero_grad()

            ref_mod.zero_grad()
            ref_out = ref_mod(x)
            ref_loss = ref_loss_fn(ref_out, target)
            ref_loss.backward()

        class CustomState:
            def __init__(self, stage_module, stage_idx, rank):
                self.i = 0
                self.stage_module = stage_module
                self.stage_idx = stage_idx
                self.rank = rank

            def dw_builder(self):
                def dw_runner():
                    # This inner function would be called by PipelineStage during `backward_weight_one_chunk`
                    self.i += 1
                    print(
                        f"[Rank {self.rank}] dw_count={self.i} stage={self.stage_idx}"
                    )
                    self.stage_module.compute_dW()

                return dw_runner

        cs = {}
        for stage_module, stage_idx in zip(stage_modules, stage_indices):
            cs[stage_idx] = CustomState(stage_module, stage_idx, self.rank)

        # Create a pipeline stage to wrap that submodule
        chunks = 2
        input_args = x.chunk(chunks)[0]
        stages = [
            PipelineStage(
                stage_module,
                stage_idx,
                n_stages,
                self.device,
                input_args=input_args,
                dw_builder=cs[stage_idx].dw_builder,
            )
            for stage_module, stage_idx in zip(stage_modules, stage_indices)
        ]

        # Attach to a schedule
        schedule = ScheduleClass(
            stages, chunks, loss_fn=full_loss_fn, enable_zero_bubble=True
        )

        for _ in range(2):
            # Zero gradients
            for stage_module in stage_modules:
                stage_module.zero_grad()
            if self.rank == 0:
                schedule.step(x)
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

        dist.barrier()
        # Last rank checks result
        if self.rank == self.world_size - 1:
            # Check output
            torch.testing.assert_close(out, ref_out)

            # Check loss
            # Since the reduction used in the loss function above is "sum", we use
            # "sum" here to reduce microbatch losses into a single value too.
            pipe_loss = sum(losses)
            torch.testing.assert_close(pipe_loss, ref_loss)

        # Every rank checks gradients
        for stage_module, submod_name in zip(stage_modules, submod_names):
            # Get corresponding submodule from reference model
            ref_submod = ref_mod.get_submodule(submod_name)
            # Check gradients per parameter
            for name, p in stage_module.named_parameters():
                ref_p = ref_submod.get_parameter(name)
                torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)


instantiate_parametrized_tests(ScheduleTest)


if __name__ == "__main__":
    # Check if GPU and NCCL are available
    if not (
        dist.is_available()
        and dist.is_nccl_available()
        and torch.cuda.device_count() > 1
    ):
        print(
            "c10d NCCL not available or not enough GPUs, skipping tests",
            file=sys.stderr,
        )
        sys.exit(0)

    rank = int(os.getenv("RANK", -1))
    world_size = int(os.getenv("WORLD_SIZE", 2))

    if rank != -1:
        # Launched with torchrun or other multi-proc launchers. Directly run the test.
        ScheduleTest.run_rank(rank, world_size)
    else:
        # Launched as a single process. Spawn subprocesses to run the tests.
        # Also need a rendezvous file for `init_process_group` purposes.
        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
        torch.multiprocessing.spawn(
            ScheduleTest.run_rank,
            nprocs=world_size,
            args=(world_size, rdvz_file),
        )