# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import copy
import logging
import os
import sys
import tempfile

from model_registry import ModelWithKwargs, MultiMLP, MultiMLPWithDw
from schedule_registry import ScheduleUnbalanced, ScheduleVShaped, ScheduleWithW

import torch
import torch.distributed as dist
from torch.distributed.pipelining import (
    _ScheduleForwardOnly,
    pipeline,
    PipelineStage,
    Schedule1F1B,
    ScheduleFlexibleInterleaved1F1B,
    ScheduleGPipe,
    ScheduleInterleaved1F1B,
    ScheduleInterleavedZeroBubble,
    ScheduleLoopedBFS,
)
from torch.distributed.pipelining.schedules import _PipelineScheduleRuntime
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
    MultiProcContinousTest,
    requires_nccl,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skip_but_pass_in_sandcastle_if,
)


logger = logging.getLogger(__name__)

d_hid = 512
batch_size = 256

torch.manual_seed(0)


class ScheduleTest(MultiProcContinousTest):
    @classmethod
    def backend_str(cls) -> str:
        # Testing with NCCL backend
        return "nccl"

    @classmethod
    def setUpClass(cls):
        """
        Class-scope test fixture. Run once for the entire test class, before any test starts.
        Set up the device.
        """
        super().setUpClass()
        dev_id = cls.rank % torch.cuda.device_count()
        cls.device = torch.device(f"cuda:{dev_id}")

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [_ScheduleForwardOnly])
    def test_forward_only(self, ScheduleClass):
        mod = MultiMLP(d_hid, n_layers=self.world_size)
        mod.to(self.device)

        mod_ref = copy.deepcopy(mod)

        x = torch.randn(batch_size, d_hid, device=self.device)
        x_clone = x.clone()

        num_microbatches = 4
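        # `pipeline()` only needs one example microbatch (the first chunk) to
        # trace the model and infer per-microbatch input shapes.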
        x_mb = x.chunk(num_microbatches)[0]

        # Create a pipeline
        split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
        pipe = pipeline(
            mod,
            mb_args=(x_mb,),
            split_spec=split_spec,
        )

        stage = pipe.build_stage(
            self.rank,
            self.device,
        )

        # Attach to a schedule
        schedule = ScheduleClass(stage, num_microbatches)

        # Run
        num_iters = 20
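        # The ranks form a ring: rank 0 feeds the batch into the pipeline, the
        # last rank sends the pipeline output back to rank 0, and that output
        # becomes the input for the next iteration.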
        for _ in range(num_iters):
            if self.rank == 0:
                schedule.step(x)
                dist.recv(x, src=self.world_size - 1)
            elif self.rank == self.world_size - 1:
                out = schedule.step()
                dist.send(out, dst=0)
            else:
                schedule.step()

        # Validate pipelined output is the same as reference model
        if self.rank == self.world_size - 1:
            for _ in range(num_iters):
                x_clone = mod_ref(x_clone)

            torch.testing.assert_close(x_clone, out)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
    def test_multi_iter(self, ScheduleClass):
        mod = MultiMLP(d_hid, n_layers=self.world_size)
        mod.to(self.device)

        x = torch.randn(batch_size, d_hid, device=self.device)
        target = torch.randn(batch_size, d_hid, device=self.device)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        chunks = 4
        x_mb = x.chunk(chunks)[0]

        # Create a pipeline
        split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
        pipe = pipeline(
            mod,
            mb_args=(x_mb,),
            split_spec=split_spec,
        )

        stage = pipe.build_stage(
            self.rank,
            self.device,
        )

        # Attach to a schedule
        schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn)

        # Run
        for _ in range(20):
            if self.rank == 0:
                schedule.step(x)
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
    def test_kwargs_with_tracer(self, ScheduleClass):
        mod = ModelWithKwargs(d_hid)
        mod.to(self.device)

        x = torch.randn(batch_size, d_hid, device=self.device)
        y = torch.randn(batch_size, d_hid, device=self.device)
        target = torch.randn(batch_size, d_hid, device=self.device)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        chunks = 4
        x_mb = x.chunk(chunks)[0]
        y_mb = y.chunk(chunks)[0]

        pipe = pipeline(
            mod,
            mb_args=(x_mb,),
            mb_kwargs={"y": y_mb},
        )

        stage = pipe.build_stage(
            self.rank,
            self.device,
        )

        # Attach to a schedule
        schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn)

        # Run
        if self.rank == 0:
            schedule.step(x, y=y)
        elif self.rank == self.world_size - 1:
            losses = []
            out = schedule.step(target=target, losses=losses)
        else:
            schedule.step()

        dist.barrier()

        # Last rank checks result
        if self.rank == self.world_size - 1:
            ref_out = mod(x, y=y)
            ref_loss = loss_fn(ref_out, target)
            pipe_loss = sum(losses)
            torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=5e-3)
            torch.testing.assert_close(pipe_loss, ref_loss)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
    @parametrize("ModelClass", [MultiMLP])
    def test_grad_with_tracer(self, ScheduleClass, ModelClass):
        mod = ModelClass(d_hid)
        mod.to(self.device)

        ref_mod = copy.deepcopy(mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Run reference
        for _ in range(2):
            ref_mod.zero_grad()
            ref_out = ref_mod(x)
            ref_loss = loss_fn(ref_out, target)
            ref_loss.backward()

        # Create a pipeline
        chunks = 4
        x_mb = x.chunk(chunks)[0]
        split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
        pipe = pipeline(
            mod,
            mb_args=(x_mb,),
            split_spec=split_spec,
        )

        stage = pipe.build_stage(
            self.rank,
            self.device,
        )

        # Attach to a schedule
        schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn)

        # Run
        stage_module = pipe.get_stage_module(self.rank)
        for _ in range(2):
            # Zero gradients
            stage_module.zero_grad()
            if self.rank == 0:
                schedule.step(x)
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

        dist.barrier()

        # Last rank checks result
        if self.rank == self.world_size - 1:
            # Check output
            torch.testing.assert_close(out, ref_out)
            # Check loss
            # Since the reduction used in the loss function above is "sum", we use
            # "sum" here to reduce microbatch losses into a single value too.
            pipe_loss = sum(losses)
            torch.testing.assert_close(pipe_loss, ref_loss)

        # Every rank checks gradients
        for name, p in stage_module.named_parameters():
            ref_p = ref_mod.get_parameter(name)
            try:
                torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
            except AssertionError:
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
    def test_grad_with_manual(self, ScheduleClass):
        full_mod = MultiMLP(d_hid, n_layers=self.world_size)
        full_mod.to(self.device)

        ref_mod = copy.deepcopy(full_mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Run reference
        for _ in range(2):
            ref_mod.zero_grad()
            ref_out = ref_mod(x)
            ref_loss = loss_fn(ref_out, target)
            ref_loss.backward()

        # Get a submodule, e.g. `layers.0` or `layers.1`
        submod_name = f"layers.{self.rank}"
        stage_module = full_mod.get_submodule(submod_name)
        chunks = 4
        # Create a pipeline stage to wrap that submodule
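        # `input_args` is a microbatch-sized example input; the manually built
        # stage uses it to infer the shapes/dtypes of its communication buffers.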
        stage = PipelineStage(
            stage_module,
            self.rank,
            self.world_size,
            self.device,
            input_args=x.chunk(chunks)[0],
        )

        # Attach to a schedule
        schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn)

        # Run
        for _ in range(2):
            # Zero gradients
            stage_module.zero_grad()
            if self.rank == 0:
                schedule.step(x)
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

        dist.barrier()

        # Last rank checks result
        if self.rank == self.world_size - 1:
            # Check output
            torch.testing.assert_close(out, ref_out)
            # Check loss
            # Since the reduction used in the loss function above is "sum", we use
            # "sum" here to reduce microbatch losses into a single value too.
            pipe_loss = sum(losses)
            torch.testing.assert_close(pipe_loss, ref_loss)

        # Every rank checks gradients
        ref_submod = ref_mod.get_submodule(submod_name)
        for name, p in stage_module.named_parameters():
            ref_p = ref_submod.get_parameter(name)
            try:
                torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
            except AssertionError:
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize(
        "ScheduleClass",
        [ScheduleInterleaved1F1B, ScheduleLoopedBFS, ScheduleInterleavedZeroBubble],
    )
    @parametrize("use_new_runtime", [False, True])
    def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime):
        stages_per_rank = 2
        n_stages = stages_per_rank * self.world_size
        full_mod = MultiMLP(d_hid, n_layers=n_stages)
        full_mod.to(self.device)

        ref_mod = copy.deepcopy(full_mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Run reference
        for _ in range(2):
            ref_mod.zero_grad()
            ref_out = ref_mod(x)
            ref_loss = loss_fn(ref_out, target)
            ref_loss.backward()

        # Get the submodules owned by this rank, e.g. `layers.0` and `layers.2`
        stage_indices = [
            self.rank + i * self.world_size for i in range(stages_per_rank)
        ]
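        # Interleaved placement: with `stages_per_rank` stages per rank,
        # rank r owns stages r, r + world_size, r + 2 * world_size, ...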
        print(f"Rank {self.rank} stages: {stage_indices}")
        submod_names = [f"layers.{i}" for i in stage_indices]
        stage_modules = [
            full_mod.get_submodule(submod_name) for submod_name in submod_names
        ]
        # Create a pipeline stage to wrap each submodule
        num_microbatches = (
            ScheduleClass.num_microbatches
            if hasattr(ScheduleClass, "num_microbatches")
            else 8
        )
        input_args = x.chunk(num_microbatches)[0]
        stages = [
            PipelineStage(
                stage_module,
                stage_idx,
                n_stages,
                self.device,
                input_args=input_args,
            )
            for stage_module, stage_idx in zip(stage_modules, stage_indices)
        ]

        # Attach to a schedule
        schedule = ScheduleClass(stages, num_microbatches, loss_fn=loss_fn)
        if use_new_runtime:
            old_schedule = schedule
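            # _PipelineScheduleRuntime executes an explicit per-rank action list;
            # load the original schedule's pipeline_order into it, then verify that
            # the lowered compute+comms actions survive a CSV dump/load round-trip.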
            tmp_schedule = _PipelineScheduleRuntime(
                stages,
                num_microbatches,
                loss_fn=loss_fn,
                stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
                use_full_backward=old_schedule.use_full_backward,
            )
            tmp_schedule._load_actions(old_schedule.pipeline_order)
            # test that csv round-trip works for compute_comms schedule
            schedule = _PipelineScheduleRuntime(
                stages,
                num_microbatches,
                loss_fn=loss_fn,
                stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
                use_full_backward=old_schedule.use_full_backward,
            )
            with tempfile.NamedTemporaryFile() as f:
                tmp_schedule._dump_csv(f.name)
                f.seek(0)
                schedule._load_csv(f.name, format="compute_comms")
            one_more_schedule = _PipelineScheduleRuntime(
                stages,
                num_microbatches,
                loss_fn=loss_fn,
                stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
                use_full_backward=old_schedule.use_full_backward,
            )
            one_more_schedule._load_actions(
                schedule.pipeline_order_with_comms, format="compute_comms"
            )
            self.assertEqual(
                len(schedule.pipeline_order_with_comms),
                len(
                    one_more_schedule.pipeline_order_with_comms,
                ),
            )
            for rank in schedule.pipeline_order_with_comms:
                self.assertEqual(
                    len(schedule.pipeline_order_with_comms[rank]),
                    len(
                        one_more_schedule.pipeline_order_with_comms[rank],
                    ),
                )
                for a, b in zip(
                    schedule.pipeline_order_with_comms[rank],
                    one_more_schedule.pipeline_order_with_comms[rank],
                ):
                    self.assertEqual(a, b)

        # Run
        for _ in range(2):
            # Zero gradients
            for stage_module in stage_modules:
                stage_module.zero_grad()
            if self.rank == 0:
                schedule.step(x)
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

        dist.barrier()

        # Last rank checks result
        if self.rank == self.world_size - 1:
            # Check output
            torch.testing.assert_close(out, ref_out)
            # Check loss
            # Since the reduction used in the loss function above is "sum", we use
            # "sum" here to reduce microbatch losses into a single value too.
            pipe_loss = sum(losses)
            torch.testing.assert_close(pipe_loss, ref_loss)

        # Every rank checks gradients
        for stage_module, submod_name in zip(stage_modules, submod_names):
            # Get corresponding submodule from reference model
            ref_submod = ref_mod.get_submodule(submod_name)
            # Check gradients per parameter
            for name, p in stage_module.named_parameters():
                ref_p = ref_submod.get_parameter(name)
                try:
                    torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
                except AssertionError:
                    print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                    raise

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleWithW, ScheduleFlexibleInterleaved1F1B])
    def test_schedule_with_native_zero_bubble(self, ScheduleClass):
        print(ScheduleClass)
        if ScheduleClass is ScheduleFlexibleInterleaved1F1B:
            n_stages = 4
            num_microbatches = 8
            rank_stages = {
                0: [0, 2],
                1: [1, 3],
            }
        else:
            n_stages = ScheduleClass.n_stages
            num_microbatches = ScheduleClass.num_microbatches
            rank_stages = ScheduleClass.rank_stages

        num_steps = 4
        full_mod = MultiMLP(d_hid, n_layers=n_stages)
        full_mod.to(self.device)

        ref_mod = copy.deepcopy(full_mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        # x = torch.randn(batch_size, d_hid, device=self.device, requires_grad=True)
        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Create a pipeline stage to wrap each submodule
        input_args = x.chunk(num_microbatches)[0]
        stage_indices = rank_stages[self.rank]
        print(f"Rank {self.rank} stages: {stage_indices}")
        submod_names = [f"layers.{i}" for i in stage_indices]
        stage_modules = [
            full_mod.get_submodule(submod_name) for submod_name in submod_names
        ]
        stages = [
            PipelineStage(
                stage_module,
                stage_idx,
                n_stages,
                self.device,
                input_args=input_args,
            )
            for stage_module, stage_idx in zip(stage_modules, rank_stages[self.rank])
        ]

        schedule = ScheduleClass(
            stages, num_microbatches, loss_fn=loss_fn, enable_zero_bubble=True
        )

        # Run reference
        ref_x = x.clone().detach().requires_grad_(x.requires_grad)
        torch.testing.assert_close(x, ref_x)
        for _ in range(num_steps):
            ref_out = ref_mod(ref_x)
            ref_loss = loss_fn(ref_out, target)
            ref_loss.backward()

        # Run pipelined stages
        for _ in range(num_steps):
            if self.rank == 0:
                schedule.step(x)
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

        # Every rank checks parameters compared with the reference model
        for stage_module, submod_name in zip(stage_modules, submod_names):
            # Get corresponding submodule from reference model
            ref_submod = ref_mod.get_submodule(submod_name)
            # Check gradients per parameter
            for name, p in stage_module.named_parameters():
                ref_p = ref_submod.get_parameter(name)
                try:
                    torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
                except AssertionError:
                    print(
                        f"Parameter test failed for {submod_name}.{name}: {p.grad} vs {ref_p.grad}"
                    )
                    raise

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleVShaped, ScheduleUnbalanced])
    def test_non_symmetric_stage_ids(self, ScheduleClass):
        n_stages = ScheduleClass.n_stages
        full_mod = MultiMLP(d_hid, n_layers=n_stages)
        full_mod.to(self.device)

        ref_mod = copy.deepcopy(full_mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Run reference
        for _ in range(2):
            ref_mod.zero_grad()
            ref_out = ref_mod(x)
            ref_loss = loss_fn(ref_out, target)
            ref_loss.backward()

        # Create a pipeline stage to wrap each submodule
        chunks = 1
        input_args = x.chunk(chunks)[0]
        rank_stages = ScheduleClass.rank_stages
        stage_indices = rank_stages[self.rank]
        print(f"Rank {self.rank} stages: {stage_indices}")
        submod_names = [f"layers.{i}" for i in stage_indices]
        stage_modules = [
            full_mod.get_submodule(submod_name) for submod_name in submod_names
        ]
        stages = [
            PipelineStage(
                stage_module,
                stage_idx,
                n_stages,
                self.device,
                input_args=input_args,
            )
            for stage_module, stage_idx in zip(stage_modules, rank_stages[self.rank])
        ]

        # Attach to a schedule
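        # Invert rank_stages: map each stage index to the rank that hosts it,
        # since these schedules place stages non-symmetrically across ranks.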
        stage_index_to_group_rank = {
            value: key for key, values in rank_stages.items() for value in values
        }
        schedule = ScheduleClass(
            stages, chunks, stage_index_to_group_rank, loss_fn=loss_fn
        )

        # Run
        # TODO how to better specify .step() when first and last stage are on rank 0...
        for _ in range(2):
            # Zero gradients
            for stage_module in stage_modules:
                stage_module.zero_grad()
            if self.rank == 0:
                losses = []
                out = schedule.step(x, target=target, losses=losses)
            else:
                schedule.step()

        dist.barrier()

        # Rank 0 hosts both the first and last stage here, so it checks the result
        if self.rank == 0:
            # Check output
            torch.testing.assert_close(out, ref_out)
            # Check loss
            # Since the reduction used in the loss function above is "sum", we use
            # "sum" here to reduce microbatch losses into a single value too.
            pipe_loss = sum(losses)
            torch.testing.assert_close(pipe_loss, ref_loss)

        # Every rank checks gradients
        for stage_module, submod_name in zip(stage_modules, submod_names):
            # Get corresponding submodule from reference model
            ref_submod = ref_mod.get_submodule(submod_name)
            # Check gradients per parameter
            for name, p in stage_module.named_parameters():
                ref_p = ref_submod.get_parameter(name)
                try:
                    torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
                except AssertionError:
                    print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                    raise

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleFlexibleInterleaved1F1B])
    def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass):
        stages_per_rank = 2
        n_stages = stages_per_rank * self.world_size
        full_mod = MultiMLPWithDw(d_hid, n_layers=n_stages)
        full_mod.to(self.device)

        ref_mod = copy.deepcopy(full_mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        ref_loss_fn = torch.nn.MSELoss(reduction="sum")
        full_loss_fn = torch.nn.MSELoss(reduction="sum")

        full_mod.toggle()

        # Get the submodules owned by this rank, e.g. `layers.0` and `layers.2`
        stage_indices = [
            self.rank + i * self.world_size for i in range(stages_per_rank)
        ]
        submod_names = [f"layers.{i}" for i in stage_indices]
        stage_modules = [
            full_mod.get_submodule(submod_name) for submod_name in submod_names
        ]

        # Run reference
        for _ in range(2):
            ref_stage_modules = [
                ref_mod.get_submodule(submod_name) for submod_name in submod_names
            ]
            for stage_module in ref_stage_modules:
                stage_module.zero_grad()

            ref_mod.zero_grad()
            ref_out = ref_mod(x)
            ref_loss = ref_loss_fn(ref_out, target)
            ref_loss.backward()

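        # Each stage gets a dw_builder callback from CustomState: with zero-bubble
        # scheduling the weight-gradient (dW) computation is deferred, and the
        # returned dw_runner is invoked later to call the module's compute_dW().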
        class CustomState:
            def __init__(self, stage_module, stage_idx, rank):
                self.i = 0
                self.stage_module = stage_module
                self.stage_idx = stage_idx
                self.rank = rank

            def dw_builder(self):
                def dw_runner():
                    # This inner function would be called by PipelineStage during `backward_weight_one_chunk`
                    self.i += 1
                    print(
                        f"[Rank {self.rank}] dw_count={self.i} stage={self.stage_idx}"
                    )
                    self.stage_module.compute_dW()

                return dw_runner

        cs = {}
        for stage_module, stage_idx in zip(stage_modules, stage_indices):
            cs[stage_idx] = CustomState(stage_module, stage_idx, self.rank)

        # Create a pipeline stage to wrap each submodule
        chunks = 2
        input_args = x.chunk(chunks)[0]
        stages = [
            PipelineStage(
                stage_module,
                stage_idx,
                n_stages,
                self.device,
                input_args=input_args,
                dw_builder=cs[stage_idx].dw_builder,
            )
            for stage_module, stage_idx in zip(stage_modules, stage_indices)
        ]

        # Attach to a schedule
        schedule = ScheduleClass(
            stages, chunks, loss_fn=full_loss_fn, enable_zero_bubble=True
        )

        for _ in range(2):
            # Zero gradients
            for stage_module in stage_modules:
                stage_module.zero_grad()
            if self.rank == 0:
                schedule.step(x)
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

        dist.barrier()
        # Last rank checks result
        if self.rank == self.world_size - 1:
            # Check output
            torch.testing.assert_close(out, ref_out)

            # Check loss
            # Since the reduction used in the loss function above is "sum", we use
            # "sum" here to reduce microbatch losses into a single value too.
            pipe_loss = sum(losses)
            torch.testing.assert_close(pipe_loss, ref_loss)

        # Every rank checks gradients
        for stage_module, submod_name in zip(stage_modules, submod_names):
            # Get corresponding submodule from reference model
            ref_submod = ref_mod.get_submodule(submod_name)
            # Check gradients per parameter
            for name, p in stage_module.named_parameters():
                ref_p = ref_submod.get_parameter(name)
                torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)


instantiate_parametrized_tests(ScheduleTest)


if __name__ == "__main__":
    # Check if GPU and NCCL are available
    if not (
        dist.is_available()
        and dist.is_nccl_available()
        and torch.cuda.device_count() > 1
    ):
        print(
            "c10d NCCL not available or not enough GPUs, skipping tests",
            file=sys.stderr,
        )
        sys.exit(0)

    rank = int(os.getenv("RANK", -1))
    world_size = int(os.getenv("WORLD_SIZE", 2))

    if rank != -1:
        # Launched with torchrun or other multi-proc launchers. Directly run the test.
        ScheduleTest.run_rank(rank, world_size)
    else:
        # Launched as a single process. Spawn subprocesses to run the tests.
        # Also need a rendezvous file for `init_process_group` purposes.
        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
        torch.multiprocessing.spawn(
            ScheduleTest.run_rank,
            nprocs=world_size,
            args=(world_size, rdvz_file),
        )