# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
# This file is a schedule zoo for testing torch.distributed.pipelining.
# It includes schedules designed purely for testing purposes.
from typing import Callable, Dict, List, Optional

from torch.distributed.pipelining.schedules import (
    _Action,
    _ComputationType,
    PipelineScheduleMulti,
)
from torch.distributed.pipelining.stage import _PipelineStageBase


F = _ComputationType.FORWARD
B = _ComputationType.BACKWARD
W = _ComputationType.WEIGHT


class ScheduleVShaped(PipelineScheduleMulti):
    n_stages = 4
    # Rank 0 holds the first and last stages, rank 1 the middle two,
    # giving the V-shaped placement the class is named for
    rank_stages = {
        0: [0, 3],
        1: [1, 2],
    }

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        stage_index_to_group_rank: Dict[int, int],
        loss_fn: Optional[Callable] = None,
    ):
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            stage_index_to_group_rank=stage_index_to_group_rank,
        )

        # Go through one microbatch
        # Note(whc) - it might be easier to work with these schedules by writing them as a list of
        # ["0F0", ...] and then parsing them in the test infra to turn them into actions
        # (a sketch of that parsing idea appears at the bottom of this file).
        self.pipeline_order = {
            0: [
                _Action(0, F, 0),
                None,
                None,
                _Action(3, F, 0),
                _Action(3, B, 0),
                None,
                None,
                _Action(0, B, 0),
            ],
            1: [
                None,
                _Action(1, F, 0),
                _Action(2, F, 0),
                None,
                None,
                _Action(2, B, 0),
                _Action(1, B, 0),
                None,
            ],
        }


class ScheduleUnbalanced(PipelineScheduleMulti):
    n_stages = 5
    # The ranks deliberately own unequal numbers of stages
    rank_stages = {
        0: [0, 1, 4],
        1: [2, 3],
    }

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        stage_index_to_group_rank: Dict[int, int],
        loss_fn: Optional[Callable] = None,
    ):
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            stage_index_to_group_rank=stage_index_to_group_rank,
        )

        self.pipeline_order = {
            0: [
                _Action(0, F, 0),
                _Action(1, F, 0),
                None,
                None,
                _Action(4, F, 0),
                _Action(4, B, 0),
                None,
                None,
                _Action(1, B, 0),
                _Action(0, B, 0),
            ],
            1: [
                None,
                None,
                _Action(2, F, 0),
                _Action(3, F, 0),
                None,
                None,
                _Action(3, B, 0),
                _Action(2, B, 0),
                None,
                None,
            ],
        }


class ScheduleWithW(PipelineScheduleMulti):
    n_stages = 4
    num_microbatches = 2
    rank_stages = {
        0: [0, 2],
        1: [1, 3],
    }

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        enable_zero_bubble: bool = True,
    ):
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
        )

        # Every schedule that uses "W" actions must set this, so that the backward
        # pass is split into input-gradient (B) and weight-gradient (W) phases
        self.use_full_backward = False

        # Go through two microbatches
        self.pipeline_order = {
            0: [
                _Action(0, F, 0),
                _Action(0, F, 1),
                _Action(2, F, 0),
                _Action(2, F, 1),
                None,
                _Action(2, B, 0),
                _Action(2, W, 0),
                _Action(0, B, 0),
                _Action(2, B, 1),
                _Action(0, W, 0),
                _Action(0, B, 1),
                _Action(2, W, 1),
                _Action(0, W, 1),
            ],
            1: [
                None,
                _Action(1, F, 0),
                _Action(1, F, 1),
                _Action(3, F, 0),
                _Action(3, B, 0),
                _Action(3, F, 1),
                _Action(1, B, 0),
                _Action(3, B, 1),
                _Action(3, W, 0),
                _Action(1, B, 1),
                _Action(1, W, 0),
                _Action(3, W, 1),
                _Action(1, W, 1),
            ],
        }
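

# A minimal sketch of the idea in the Note(whc) comment in ScheduleVShaped:
# write each rank's schedule as compact strings like "0F0" (stage 0, FORWARD,
# microbatch 0) and expand them into _Action tuples. The helper name
# _parse_compact_actions and the exact string format are illustrative
# assumptions for this file's test schedules, not part of the
# torch.distributed.pipelining API.
_LETTER_TO_COMPUTATION = {"F": F, "B": B, "W": W}


def _parse_compact_actions(
    compact: List[Optional[str]],
) -> List[Optional[_Action]]:
    actions: List[Optional[_Action]] = []
    for entry in compact:
        if entry is None:
            # None marks a step where this rank sits idle
            actions.append(None)
            continue
        # "0F0" -> stage 0, FORWARD, microbatch 0. Assumes a single-digit
        # stage index, which holds for every schedule in this file.
        stage, comp, microbatch = entry[0], entry[1], entry[2:]
        actions.append(
            _Action(int(stage), _LETTER_TO_COMPUTATION[comp], int(microbatch))
        )
    return actions


# Example: the rank-0 row of ScheduleVShaped.pipeline_order written compactly:
# _parse_compact_actions(["0F0", None, None, "3F0", "3B0", None, None, "0B0"])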