# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import logging
from typing import List

import torch
from torch.distributed.pipelining import (
    ScheduleFlexibleInterleaved1F1B,
    ScheduleInterleaved1F1B,
    ScheduleLoopedBFS,
)
from torch.distributed.pipelining.schedules import (
    _Action,
    _add_send_recv,
    _add_unshard_reshard,
    _format_pipeline_order,
    _PipelineSchedule,
    _validate_pipeline_order,
    B,
    F,
    get_schedule_class,
    RECV_F,
    RESHARD,
    SEND_B,
    UNSHARD,
    W,
)
from torch.distributed.pipelining.stage import _PipelineStageBase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


logger = logging.getLogger(__name__)
torch.manual_seed(0)


class MockPipelineStage(_PipelineStageBase):
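    """
    Minimal stand-in for _PipelineStageBase used only for building schedules in
    these tests. It intentionally skips the real __init__ (no module, device, or
    process group) and stubs out the infra hooks, so schedule plans can be
    generated and validated without any distributed setup.
    """
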
    def __init__(self, *args, **kwargs):
        # Mock the necessary attributes
        self.num_stages = kwargs.get("num_stages", 1)
        self.group_size = kwargs.get("group_size", 1)
        self.group_rank = kwargs.get("group_rank", 0)
        self.group = kwargs.get("group", None)
        self.stage_index_to_group_rank = kwargs.get("stage_index_to_group_rank", None)

    def _create_grad_recv_info(self, *args, **kwargs):
        return None

    def _prepare_forward_infra(self, n_microbatches):
        pass

    def _prepare_backward_infra(self, n_microbatches):
        pass


class ScheduleTest(TestCase):
    def test_get_schedule_class(self):
        # List of all expected schedule names
        schedule_names = [
            "1F1B",
            "Interleaved1F1B",
            "GPipe",
            "FlexibleInterleaved1F1B",
            "LoopedBFS",
            "PipelineScheduleSingle",
            "PipelineScheduleMulti",
        ]

        # Test each schedule name
        for name in schedule_names:
            with self.subTest(name=name):
                schedule_class = get_schedule_class(name)
                self.assertIsNotNone(
                    schedule_class, f"Class for {name} should not be None"
                )
                self.assertTrue(
                    issubclass(schedule_class, _PipelineSchedule),
                    f"{name} should be a subclass of _PipelineSchedule",
                )


class TestSchedulePlan(TestCase):
    def setUp(self):
        # Test cases with varying num_local_stages, num_microbatches, and group_size.
        # Note that the list also includes cases where num_microbatches is not
        # divisible by group_size; test_pipeline_order skips those, while the
        # flexible/zero-bubble test below exercises every case.
        self.test_cases = [
            # small number of stages
            (2, 2, 2),
            (2, 4, 4),
            (2, 8, 2),
            (2, 8, 4),
            (2, 8, 8),
            (4, 4, 4),
            (4, 8, 4),
            (4, 8, 8),
            # large microbatches
            (4, 16, 4),
            (4, 32, 4),
            (4, 64, 4),
            # large groups
            (4, 16, 16),
            (4, 32, 32),
            (4, 128, 64),
            # odd num pipeline stages
            (3, 2, 2),
            (3, 8, 2),
            (3, 12, 4),
            # odd group_sizes
            (4, 6, 3),
            (4, 10, 5),
            # n_mb not divisible by group_size
            (2, 3, 4),
            (2, 4, 4),
            (2, 10, 4),
            (2, 15, 4),
        ]

    @parametrize(
        "ScheduleClass",
        [ScheduleInterleaved1F1B, ScheduleLoopedBFS],
    )
    def test_pipeline_order(self, ScheduleClass):
        for num_local_stages, num_microbatches, group_size in self.test_cases:
            with self.subTest(
                num_local_stages=num_local_stages,
                num_microbatches=num_microbatches,
                group_size=group_size,
            ):
                if num_microbatches % group_size != 0:
                    continue

                logger.info(
                    "num_local_stages=%d num_microbatches=%d group_size=%d",
                    num_local_stages,
                    num_microbatches,
                    group_size,
                )
                num_stages = num_local_stages * group_size
                stages = [
                    MockPipelineStage(group_size=group_size, num_stages=num_stages)
                    for _ in range(num_local_stages)
                ]

                schedule = ScheduleClass(stages, num_microbatches)
                formatted_pipeline_order = _format_pipeline_order(
                    schedule.pipeline_order
                )
                # print(formatted_pipeline_order)
                _validate_pipeline_order(
                    schedule.pipeline_order, num_microbatches, num_stages
                )

    @parametrize(
        "ScheduleClass",
        [ScheduleFlexibleInterleaved1F1B],
    )
    def test_pipeline_order_flex_and_zero_bubble(self, ScheduleClass):
        for num_local_stages, num_microbatches, group_size in self.test_cases:
            with self.subTest(
                num_local_stages=num_local_stages,
                num_microbatches=num_microbatches,
                group_size=group_size,
            ):
                warmup_ops_last_stage = (num_local_stages - 1) * (
                    num_microbatches // max(1, num_microbatches // group_size)
                )
                warmup_ops = warmup_ops_last_stage + 2 * (group_size - 1)
                warmup_ops = min(warmup_ops, num_microbatches * num_local_stages)
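                # Worked example of the bound computed above: for
                # (num_local_stages, num_microbatches, group_size) = (4, 8, 4),
                # warmup_ops_last_stage = 3 * (8 // max(1, 2)) = 12 and
                # warmup_ops = min(12 + 2 * 3, 8 * 4) = 18.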

                for i in range(2):
                    num_stages = num_local_stages * group_size
                    stages = [
                        MockPipelineStage(group_size=group_size, num_stages=num_stages)
                        for _ in range(num_local_stages)
                    ]
                    schedule = ScheduleClass(
                        stages, num_microbatches, enable_zero_bubble=(i == 0)
                    )
                    formatted_pipeline_order = _format_pipeline_order(
                        schedule.pipeline_order
                    )
                    # print(formatted_pipeline_order)
                    _validate_pipeline_order(
                        schedule.pipeline_order,
                        num_microbatches,
                        num_stages,
                        enable_zero_bubble=(i == 0),
                    )


instantiate_parametrized_tests(TestSchedulePlan)


class TestScheduleLowering(TestCase):
    """Tests lowering passes that convert simple compute-only (FBW) schedules into compute+comms schedules"""

    def _parse_actions(self, actions: List[str]) -> List[_Action]:
        return [_Action.from_str(s) for s in actions]

    @parametrize(
        "action_str_and_ref",
        [
            ("1F0", _Action(1, F, 0)),
            ("2B1", _Action(2, B, 1)),
            ("0W3", _Action(0, W, 3)),
            ("1UNSHARD", _Action(1, UNSHARD, None)),
            ("3RESHARD", _Action(3, RESHARD, None)),
            ("2SEND_B2", _Action(2, SEND_B, 2)),
            ("1RECV_F1", _Action(1, RECV_F, 1)),
        ],
    )
    def test_action_parse(self, action_str_and_ref):
        """Test that actions can be parsed from strings and round-tripped back to the same strings."""
        act_str, ref = action_str_and_ref
        act = _Action.from_str(act_str)
        self.assertEqual(act, ref)
        self.assertEqual(act_str, repr(act))

    @parametrize(
        "test_info",
        [
            {
                "compute": ["0F0", "0F1", "   ", "0B0", "0B1"],
                "comms": ["0UNSHARD", "0F0", "0F1", "0B0", "0B1", "0RESHARD"],
            },
        ],
    )
    def test_unshard_reshard(self, test_info):
        """Test the lowering pass that takes a 'compute only' schedule (with only F, B, W ops) and adds
        FSDP unshard/reshard operations to the schedule. This is just part of the process of adding
        communication ops and producing a complete schedule.
        """
        compute_sch = self._parse_actions(test_info["compute"])
        expected_comms_sch = self._parse_actions(test_info["comms"])

        comms_sch = _add_unshard_reshard(compute_sch)
        for expected, actual in zip(expected_comms_sch, comms_sch):
            self.assertEqual(
                expected,
                actual,
                (
                    f"Mismatch: expected action {expected} but found {actual}."
                    f"\nWhole Schedule: {comms_sch}"
                ),
            )

    @parametrize(
        "test_info",
        [
            {
                "compute": {
                    0: ["0F0", "0F1", "   ", "0B0", "   ", "0B1"],
                    1: ["   ", "1F0", "1B0", "1F1", "1B1", "   "],
                },
                "comms": {
                    0: [
                        "0F0",
                        "0SEND_F0",
                        "0F1",
                        "0SEND_F1",
                        "0RECV_B0",
                        "0B0",
                        "0RECV_B1",
                        "0B1",
                    ],
                    1: [
                        "1RECV_F0",
                        "1RECV_F1",
                        "1F0",
                        "1B0",
                        "1SEND_B0",
                        "1F1",
                        "1B1",
                        "1SEND_B1",
                    ],
                },
                "stage_to_rank": lambda stage_idx: stage_idx,
                "num_stages": 2,
            },
        ],
    )
    def test_send_recv(self, test_info):
        """Tests the lowering pass that adds send/recv ops to a compute-only schedule."""
        compute_sch = {
            rank: self._parse_actions(test_info["compute"][rank])
            for rank in test_info["compute"]
        }
        expected_comms_sch = {
            rank: self._parse_actions(test_info["comms"][rank])
            for rank in test_info["comms"]
        }

        comms_sch = _add_send_recv(
            compute_sch, test_info["stage_to_rank"], test_info["num_stages"]
        )
        for rank in expected_comms_sch:
            for i, (expected, actual) in enumerate(
                zip(expected_comms_sch[rank], comms_sch[rank])
            ):
                self.assertEqual(
                    expected,
                    actual,
                    (
                        f"Mismatch on rank {rank} at position {i}."
                        f"\nExpected: {expected_comms_sch[rank]}"
                        f"\nActual:   {comms_sch[rank]}"
                    ),
                )
            self.assertEqual(len(comms_sch[rank]), len(expected_comms_sch[rank]))


instantiate_parametrized_tests(TestScheduleLowering)

if __name__ == "__main__":
    run_tests()