# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import logging
from typing import List, Optional

import torch
from torch.distributed.pipelining import (
    ScheduleFlexibleInterleaved1F1B,
    ScheduleInterleaved1F1B,
    ScheduleLoopedBFS,
)
from torch.distributed.pipelining.schedules import (
    _Action,
    _add_send_recv,
    _add_unshard_reshard,
    _format_pipeline_order,
    _PipelineSchedule,
    _validate_pipeline_order,
    B,
    F,
    get_schedule_class,
    RECV_F,
    RESHARD,
    SEND_B,
    UNSHARD,
    W,
)
from torch.distributed.pipelining.stage import _PipelineStageBase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


logger = logging.getLogger(__name__)
torch.manual_seed(0)


class MockPipelineStage(_PipelineStageBase):
    """Lightweight stand-in for a pipeline stage used by the schedule-planning tests.

    ``__init__`` deliberately does not call ``_PipelineStageBase.__init__``; it only
    sets the attributes below from keyword arguments (with defaults), so no real
    module, device, or process group is needed to build one.
    """

    def __init__(self, *args, **kwargs):
        # Mock the necessary attributes
        self.num_stages = kwargs.get("num_stages", 1)
        self.group_size = kwargs.get("group_size", 1)
        self.group_rank = kwargs.get("group_rank", 0)
        self.group = kwargs.get("group", None)
        self.stage_index_to_group_rank = kwargs.get("stage_index_to_group_rank", None)

    def _create_grad_recv_info(self, *args, **kwargs):
        # Mock override: no gradient-receive metadata is built.
        return None

    def _prepare_forward_infra(self, n_microbatches):
        # Mock override: intentionally a no-op.
        pass

    def _prepare_backward_infra(self, n_microbatches):
        # Mock override: intentionally a no-op.
        pass


class ScheduleTest(TestCase):
    def test_get_schedule_class(self):
        """Every public schedule name resolves to a _PipelineSchedule subclass."""
        # List of all expected schedule names
        schedule_names = [
            "1F1B",
            "Interleaved1F1B",
            "GPipe",
            "FlexibleInterleaved1F1B",
            "LoopedBFS",
            "PipelineScheduleSingle",
            "PipelineScheduleMulti",
        ]

        # Test each schedule name
        for name in schedule_names:
            with self.subTest(name=name):
                schedule_class = get_schedule_class(name)
                self.assertIsNotNone(
                    schedule_class, f"Class for {name} should not be None"
                )
                self.assertTrue(
                    issubclass(schedule_class, _PipelineSchedule),
                    f"{name} should be a subclass of _PipelineSchedule",
                )


class TestSchedulePlan(TestCase):
    """Checks that multi-stage schedules produce a pipeline order that validates."""

    def setUp(self):
        # (num_local_stages, num_microbatches, group_size) triples.
        # Every group except the last satisfies num_microbatches % group_size == 0.
        # The last group deliberately does not: test_pipeline_order skips those
        # cases, while the flexible-schedule test exercises them.
        self.test_cases = [
            # small number of stages
            (2, 2, 2),
            (2, 4, 4),
            (2, 8, 2),
            (2, 8, 4),
            (2, 8, 8),
            (4, 4, 4),
            (4, 8, 4),
            (4, 8, 8),
            # large microbatches
            (4, 16, 4),
            (4, 32, 4),
            (4, 64, 4),
            # large groups
            (4, 16, 16),
            (4, 32, 32),
            (4, 128, 64),
            # odd num pipeline stages
            (3, 2, 2),
            (3, 8, 2),
            (3, 12, 4),
            # odd group_sizes
            (4, 6, 3),
            (4, 10, 5),
            # n_mb non divisible by group_size
            (2, 3, 4),
            (2, 10, 4),
            (2, 15, 4),
        ]

    @parametrize(
        "ScheduleClass",
        [ScheduleInterleaved1F1B, ScheduleLoopedBFS],
    )
    def test_pipeline_order(self, ScheduleClass):
        """Build the schedule for each divisible test case and validate its order."""
        for num_local_stages, num_microbatches, group_size in self.test_cases:
            with self.subTest(
                num_local_stages=num_local_stages,
                num_microbatches=num_microbatches,
                group_size=group_size,
            ):
                # Only cases whose microbatches divide evenly across the group
                # are exercised for these schedules.
                if num_microbatches % group_size != 0:
                    continue

                logger.info(
                    "num_local_stages=%d num_microbatches=%d group_size=%d",
                    num_local_stages,
                    num_microbatches,
                    group_size,
                )
                num_stages = num_local_stages * group_size
                stages = [
                    MockPipelineStage(group_size=group_size, num_stages=num_stages)
                    for _ in range(num_local_stages)
                ]

                schedule = ScheduleClass(stages, num_microbatches)
                # Formatting is kept so a broken order also fails to render;
                # the result is only visible with DEBUG logging enabled.
                logger.debug(
                    "pipeline order:\n%s",
                    _format_pipeline_order(schedule.pipeline_order),
                )
                _validate_pipeline_order(
                    schedule.pipeline_order, num_microbatches, num_stages
                )

    @parametrize(
        "ScheduleClass",
        [ScheduleFlexibleInterleaved1F1B],
    )
    def test_pipeline_order_flex_and_zero_bubble(self, ScheduleClass):
        """Validate the flexible schedule's order with and without zero-bubble.

        Unlike test_pipeline_order, every test case runs here, including those
        whose microbatch count is not divisible by group_size.
        """
        for num_local_stages, num_microbatches, group_size in self.test_cases:
            num_stages = num_local_stages * group_size
            # Zero-bubble variant first, then the plain variant. Each gets its own
            # subTest so a failure in one does not hide the other's result.
            for enable_zero_bubble in (True, False):
                with self.subTest(
                    num_local_stages=num_local_stages,
                    num_microbatches=num_microbatches,
                    group_size=group_size,
                    enable_zero_bubble=enable_zero_bubble,
                ):
                    # Fresh stages per variant, matching the original per-iteration setup.
                    stages = [
                        MockPipelineStage(group_size=group_size, num_stages=num_stages)
                        for _ in range(num_local_stages)
                    ]
                    schedule = ScheduleClass(
                        stages,
                        num_microbatches,
                        enable_zero_bubble=enable_zero_bubble,
                    )
                    logger.debug(
                        "pipeline order:\n%s",
                        _format_pipeline_order(schedule.pipeline_order),
                    )
                    _validate_pipeline_order(
                        schedule.pipeline_order,
                        num_microbatches,
                        num_stages,
                        enable_zero_bubble=enable_zero_bubble,
                    )


instantiate_parametrized_tests(TestSchedulePlan)


class TestScheduleLowering(TestCase):
    """Tests lowering passes that convert simple compute-only (FBW) schedules into compute+comms schedules"""

    def _parse_actions(self, actions: List[str]) -> List[Optional[_Action]]:
        """Parse each action string with ``_Action.from_str``.

        Blank entries such as ``" "`` are passed through as well; presumably
        ``from_str`` maps them to ``None`` (an idle step) -- confirm against
        ``_Action.from_str``.
        """
        return [_Action.from_str(s) for s in actions]

    @parametrize(
        "action_str_and_ref",
        [
            ("1F0", _Action(1, F, 0)),
            ("2B1", _Action(2, B, 1)),
            ("0W3", _Action(0, W, 3)),
            ("1UNSHARD", _Action(1, UNSHARD, None)),
            ("3RESHARD", _Action(3, RESHARD, None)),
            ("2SEND_B2", _Action(2, SEND_B, 2)),
            ("1RECV_F1", _Action(1, RECV_F, 1)),
        ],
    )
    def test_action_parse(self, action_str_and_ref):
        """Test that actions can be parsed from strings and round-tripped back to the same strings."""
        act_str, ref = action_str_and_ref
        act = _Action.from_str(act_str)
        self.assertEqual(act, ref)
        self.assertEqual(act_str, repr(act))

    @parametrize(
        "test_info",
        [
            {
                "compute": ["0F0", "0F1", "   ", "0B0", "0B1"],
                "comms": ["0UNSHARD", "0F0", "0F1", "0B0", "0B1", "0RESHARD"],
            },
        ],
    )
    def test_unshard_reshard(self, test_info):
        """Test the lowering pass that takes a 'compute only' schedule (with only F,B,W ops) and adds
        FSDP unshard/reshard operations to the schedule. This is just part of the process of adding communication
        ops and producing a complete schedule.
        """
        compute_sch = self._parse_actions(test_info["compute"])
        expected_comms_sch = self._parse_actions(test_info["comms"])

        comms_sch = _add_unshard_reshard(compute_sch)
        for expected, actual in zip(expected_comms_sch, comms_sch):
            self.assertEqual(
                expected,
                actual,
                (
                    f"Mismatch: expected action {expected} but found {actual}."
                    f"\nWhole Schedule: {comms_sch}"
                ),
            )
        # zip() stops at the shorter list, so also check lengths; otherwise missing
        # or extra actions (e.g. a dropped trailing RESHARD) would go unnoticed.
        self.assertEqual(len(comms_sch), len(expected_comms_sch))

    @parametrize(
        "test_info",
        [
            {
                "compute": {
                    0: ["0F0", "0F1", "   ", "0B0", "   ", "0B1"],
                    1: ["   ", "1F0", "1B0", "1F1", "1B1", "   "],
                },
                "comms": {
                    0: [
                        "0F0",
                        "0SEND_F0",
                        "0F1",
                        "0SEND_F1",
                        "0RECV_B0",
                        "0B0",
                        "0RECV_B1",
                        "0B1",
                    ],
                    1: [
                        "1RECV_F0",
                        "1RECV_F1",
                        "1F0",
                        "1B0",
                        "1SEND_B0",
                        "1F1",
                        "1B1",
                        "1SEND_B1",
                    ],
                },
                "stage_to_rank": lambda stage_idx: stage_idx,
                "num_stages": 2,
            },
        ],
    )
    def test_send_recv(self, test_info):
        """Tests the lowering pass that adds send/recv ops to a compute-only schedule."""
        compute_sch = {
            rank: self._parse_actions(actions)
            for rank, actions in test_info["compute"].items()
        }
        expected_comms_sch = {
            rank: self._parse_actions(actions)
            for rank, actions in test_info["comms"].items()
        }

        comms_sch = _add_send_recv(
            compute_sch, test_info["stage_to_rank"], test_info["num_stages"]
        )
        for rank, expected_actions in expected_comms_sch.items():
            actual_actions = comms_sch[rank]
            for i, (expected, actual) in enumerate(
                zip(expected_actions, actual_actions)
            ):
                self.assertEqual(
                    expected,
                    actual,
                    (
                        f"Mismatch on rank {rank} at position {i}."
                        f"\nExpected: {expected_actions}"
                        f"\nActual: {actual_actions}"
                    ),
                )
            # Per-action comparison runs first for a precise message; the length
            # check then catches any actions zip() would have skipped.
            self.assertEqual(len(actual_actions), len(expected_actions))


instantiate_parametrized_tests(TestScheduleLowering)

if __name__ == "__main__":
    run_tests()