# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
# This file is a Schedule zoo for testing torch.distributed.pipelining.
# It includes schedules designed purely for testing purposes
from typing import Callable, Dict, List, Optional

from torch.distributed.pipelining.schedules import (
    _Action,
    _ComputationType,
    PipelineScheduleMulti,
)
from torch.distributed.pipelining.stage import _PipelineStageBase


F = _ComputationType.FORWARD
B = _ComputationType.BACKWARD
W = _ComputationType.WEIGHT


class ScheduleVShaped(PipelineScheduleMulti):
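    """
    A 4-stage, 2-rank schedule arranged in a "V": rank 0 holds the first and
    last stages (0 and 3) while rank 1 holds the middle stages (1 and 2).
    A single microbatch flows forward through stages 0-1-2-3 and then backward
    through 3-2-1-0.
    """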
    n_stages = 4
    rank_stages = {
        0: [0, 3],
        1: [1, 2],
    }

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        stage_index_to_group_rank: Dict[int, int],
        loss_fn: Optional[Callable] = None,
    ):
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            stage_index_to_group_rank=stage_index_to_group_rank,
        )

        # Go through one microbatch
        # Note(whc) - it might be easier to work with these schedules by writing them as a list of
        # ["0F0", ...] and then parsing them in the test infra to turn them into actions.
        # (See the illustrative parsing sketch at the bottom of this file.)
        self.pipeline_order = {
            0: [
                _Action(0, F, 0),
                None,
                None,
                _Action(3, F, 0),
                _Action(3, B, 0),
                None,
                None,
                _Action(0, B, 0),
            ],
            1: [
                None,
                _Action(1, F, 0),
                _Action(2, F, 0),
                None,
                None,
                _Action(2, B, 0),
                _Action(1, B, 0),
                None,
            ],
        }


class ScheduleUnbalanced(PipelineScheduleMulti):
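    """
    A 5-stage, 2-rank schedule with an uneven split: rank 0 holds three stages
    (0, 1, and 4) while rank 1 holds two (2 and 3). A single microbatch runs
    forward through stages 0-4 and then backward in reverse order.
    """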
    n_stages = 5
    rank_stages = {
        0: [0, 1, 4],
        1: [2, 3],
    }

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        stage_index_to_group_rank: Dict[int, int],
        loss_fn: Optional[Callable] = None,
    ):
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            stage_index_to_group_rank=stage_index_to_group_rank,
        )

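        # Go through one microbatch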
        self.pipeline_order = {
            0: [
                _Action(0, F, 0),
                _Action(1, F, 0),
                None,
                None,
                _Action(4, F, 0),
                _Action(4, B, 0),
                None,
                None,
                _Action(1, B, 0),
                _Action(0, B, 0),
            ],
            1: [
                None,
                None,
                _Action(2, F, 0),
                _Action(3, F, 0),
                None,
                None,
                _Action(3, B, 0),
                _Action(2, B, 0),
                None,
                None,
            ],
        }


class ScheduleWithW(PipelineScheduleMulti):
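    """
    A 4-stage, 2-rank, 2-microbatch schedule that splits the backward pass in
    the zero-bubble style: B actions compute input gradients and separate W
    actions compute weight gradients, so W work can fill what would otherwise
    be pipeline bubbles.
    """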
    n_stages = 4
    num_microbatches = 2
    rank_stages = {
        0: [0, 2],
        1: [1, 3],
    }

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        enable_zero_bubble: bool = True,
    ):
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
        )

        # use_full_backward must be set to False in all schedules that use "W"
        # (weight-gradient) actions, so the backward pass is split between B and W.
        self.use_full_backward = False

        # Go through two microbatches
        self.pipeline_order = {
            0: [
                _Action(0, F, 0),
                _Action(0, F, 1),
                _Action(2, F, 0),
                _Action(2, F, 1),
                None,
                _Action(2, B, 0),
                _Action(2, W, 0),
                _Action(0, B, 0),
                _Action(2, B, 1),
                _Action(0, W, 0),
                _Action(0, B, 1),
                _Action(2, W, 1),
                _Action(0, W, 1),
            ],
            1: [
                None,
                _Action(1, F, 0),
                _Action(1, F, 1),
                _Action(3, F, 0),
                _Action(3, B, 0),
                _Action(3, F, 1),
                _Action(1, B, 0),
                _Action(3, B, 1),
                _Action(3, W, 0),
                _Action(1, B, 1),
                _Action(1, W, 0),
                _Action(3, W, 1),
                _Action(1, W, 1),
            ],
        }
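

# The Note(whc) in ScheduleVShaped above suggests writing schedules as compact
# strings like "0F0" (stage index, computation type, microbatch index) and
# parsing them into actions. The helper below is a minimal illustrative sketch
# of that idea, not part of the test infra; the name `_parse_action_str` is
# hypothetical, and it assumes single-digit stage and microbatch indices.
def _parse_action_str(action_str: str) -> _Action:
    # e.g. "0F0" -> _Action(0, F, 0): stage 0, FORWARD, microbatch 0
    stage_index = int(action_str[0])
    computation_type = {"F": F, "B": B, "W": W}[action_str[1]]
    microbatch_index = int(action_str[2])
    return _Action(stage_index, computation_type, microbatch_index)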