1# mypy: allow-untyped-defs 2# Copyright (c) Meta Platforms, Inc. and affiliates 3 4import csv 5import itertools 6import logging 7import re 8from abc import ABC, abstractmethod 9from collections import defaultdict 10from enum import Enum 11from typing import ( 12 Any, 13 Callable, 14 Dict, 15 List, 16 NamedTuple, 17 Optional, 18 Set, 19 Tuple, 20 TYPE_CHECKING, 21 Union, 22) 23 24import torch 25import torch.distributed as dist 26from torch.distributed._composable.fsdp.fully_shard import FSDPModule, UnshardHandle 27from torch.profiler import record_function 28 29from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec 30from .stage import _PipelineStageBase 31 32 33if TYPE_CHECKING: 34 from torch.distributed import Work 35 36__all__ = [ 37 "get_schedule_class", 38 "PipelineScheduleSingle", 39 "PipelineScheduleMulti", 40 "Schedule1F1B", 41 "ScheduleFlexibleInterleaved1F1B", 42 "ScheduleGPipe", 43 "ScheduleInterleaved1F1B", 44 "ScheduleLoopedBFS", 45 "ScheduleInterleavedZeroBubble", 46] 47 48logger = logging.getLogger(__name__) 49 50 51class _ComputationType(Enum): 52 # TODO(whc) rename to _ActType? 53 FORWARD = 1 54 BACKWARD = 2 55 WEIGHT = 3 56 UNSHARD = 4 57 RESHARD = 5 58 SEND_F = 6 59 RECV_F = 7 60 SEND_B = 8 61 RECV_B = 9 62 63 def __str__(self): 64 str_map = { 65 _ComputationType.FORWARD: "F", 66 _ComputationType.BACKWARD: "B", 67 _ComputationType.WEIGHT: "W", 68 _ComputationType.UNSHARD: "UNSHARD", 69 _ComputationType.RESHARD: "RESHARD", 70 _ComputationType.SEND_F: "SEND_F", 71 _ComputationType.RECV_F: "RECV_F", 72 _ComputationType.SEND_B: "SEND_B", 73 _ComputationType.RECV_B: "RECV_B", 74 } 75 return str_map[self] 76 77 @staticmethod 78 def from_str(action): 79 if action == "F": 80 return _ComputationType.FORWARD 81 elif action == "B": 82 return _ComputationType.BACKWARD 83 elif action == "W": 84 return _ComputationType.WEIGHT 85 elif action == "UNSHARD": 86 return _ComputationType.UNSHARD 87 elif action == "RESHARD": 88 return _ComputationType.RESHARD 89 elif action == "SEND_F": 90 return _ComputationType.SEND_F 91 elif action == "RECV_F": 92 return _ComputationType.RECV_F 93 elif action == "SEND_B": 94 return _ComputationType.SEND_B 95 elif action == "RECV_B": 96 return _ComputationType.RECV_B 97 else: 98 raise RuntimeError(f"Invalid computation type {action}") 99 100 101FORWARD = _ComputationType.FORWARD 102BACKWARD = _ComputationType.BACKWARD 103WEIGHT = _ComputationType.WEIGHT 104UNSHARD = _ComputationType.UNSHARD 105RESHARD = _ComputationType.RESHARD 106SEND_F = _ComputationType.SEND_F 107RECV_F = _ComputationType.RECV_F 108SEND_B = _ComputationType.SEND_B 109RECV_B = _ComputationType.RECV_B 110 111# Convenience shorthand for compute actions only since they are used in 'simple schedule format' 112F = FORWARD 113B = BACKWARD 114W = WEIGHT 115 116# Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index) 117_action_regex = re.compile( 118 r"(\d+)([F,B,W]|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B{0,1})(\d*)" 119) 120 121 122class _Action(NamedTuple): 123 stage_index: int 124 computation_type: _ComputationType 125 microbatch_index: Optional[int] = None 126 127 def __repr__(self): 128 repr = str(self.stage_index) 129 repr += str(self.computation_type) 130 if self.microbatch_index is not None: 131 repr += str(self.microbatch_index) 132 return repr 133 134 @staticmethod 135 def from_str(str): 136 """ 137 Reverse of __repr__ 138 139 String should be formatted as [stage][action type][(microbatch)] 140 e.g. 
`2F0`, `1UNSHARD`, `3SEND_F1` 141 """ 142 if match := _action_regex.match(str): 143 stage_index, computation_type, microbatch_index = match.groups() 144 return _Action( 145 int(stage_index), 146 _ComputationType.from_str(computation_type), 147 int(microbatch_index) if len(microbatch_index) else None, 148 ) 149 elif str == "" or str.isspace(): 150 return None 151 raise RuntimeError( 152 f"Invalid action string: {str}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0" 153 ) 154 155 156def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) -> str: 157 """ 158 Formats the pipeline order in a timestep (row) x rank (column) grid of actions 159 and returns the formatted string 160 """ 161 # Calculate the maximum number of steps across all ranks 162 num_steps = max(len(actions) for actions in pipeline_order.values()) 163 step_labels = [ 164 "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps) 165 ] 166 # Sorting the dictionary by keys and retrieving values in that order 167 rank_actions = [ 168 pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order) 169 ] 170 # Transpose the list of lists (rows to columns) 171 transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue="")) 172 # Generate column labels for ranks 173 num_ranks = len(pipeline_order) 174 rank_labels = ["Rank " + str(i) for i in range(num_ranks)] 175 # Calculate the maximum length of each column, considering labels 176 max_lengths = [ 177 max(len(str(item)) if item is not None else 0 for item in col) 178 for col in zip(step_labels, *transposed_actions) 179 ] 180 # Format the header row with rank labels 181 header_row = " " * (len(step_labels[0]) + 2) + " ".join( 182 f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels) 183 ) 184 # Format each row with its corresponding label 185 formatted_rows = [ 186 f"{label}: " 187 + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row)) 188 for label, row in zip(step_labels, transposed_actions) 189 ] 190 # Join the rows into a single string 191 formatted_table = header_row + "\n" + "\n".join(formatted_rows) + "\n" 192 return formatted_table 193 194 195def _validate_pipeline_order( 196 pipeline_order: Dict[int, List[Optional[_Action]]], 197 num_microbatches: int, 198 num_stages: int, 199 enable_zero_bubble: bool = False, 200): 201 """ 202 pipeline_order[rank] = [(computation_type, microbatch_index, stage_index), ...] 203 Validating that the pipeline order follows the rules: 204 1. Forward action for a microbatch must be before the Backward action for that microbatch 205 2. Recv for a microbatch must be before the send for that microbatch 206 3. Microbatch index is handled in sequential order for each stage 207 4. A later stage cannot operate on a microbatch before any of the previous stages have operated on it 208 5. 
Same microbatch cannot be handled in the same time step across ranks 209 """ 210 # microbatch_index: (current computation type, current stage) 211 microbatch_process_info: Dict[int, Tuple[_ComputationType, int]] = {} 212 max_timestep = max(len(rank_list) for rank_list in pipeline_order.values()) 213 for timestep in range(max_timestep): 214 error_msg: List[str] = [] 215 current_timestep_actions = [] 216 for rank in range(len(pipeline_order)): 217 action = ( 218 pipeline_order[rank][timestep] 219 if timestep < len(pipeline_order[rank]) 220 else None 221 ) 222 223 if action is not None: 224 computation_type = action.computation_type 225 if computation_type != _ComputationType.WEIGHT: 226 current_timestep_actions.append(action) 227 228 # TODO: enable this 229 # if len(current_timestep_actions) == 0: 230 # error_msg.append( 231 # "All actions were None, there is an unnecessary gap in the schedule" 232 # ) 233 234 # Ensure that no microbatch is operated on twice in current_timestep_actions 235 unique_microbatch_indices = { 236 action.microbatch_index for action in current_timestep_actions 237 } 238 if len(unique_microbatch_indices) != len(current_timestep_actions): 239 error_msg.append( 240 "Duplicate microbatch index found in current_timestep_actions" 241 ) 242 243 for action in current_timestep_actions: 244 stage_index = action.stage_index 245 computation_type = action.computation_type 246 mb_index = action.microbatch_index 247 assert ( 248 mb_index is not None 249 ), "All currently supported action types require valid microbatch_index" 250 if mb_index >= num_microbatches: 251 error_msg.append(f"Microbatch index {mb_index} out of range") 252 253 # first microbatch 254 if mb_index not in microbatch_process_info: 255 if computation_type != _ComputationType.FORWARD or stage_index != 0: 256 error_msg.append(f"Incorrect start for microbatch {mb_index}") 257 microbatch_process_info[mb_index] = (computation_type, stage_index) 258 else: 259 # if the microbatch is included, check that the current stage is right after prev 260 prev_computation, prev_stage = microbatch_process_info[mb_index] 261 262 if prev_computation == _ComputationType.FORWARD: 263 if prev_stage == num_stages - 1: 264 expected_stage = num_stages - 1 265 expected_computation = _ComputationType.BACKWARD 266 else: 267 expected_stage = prev_stage + 1 268 expected_computation = _ComputationType.FORWARD 269 elif prev_computation == _ComputationType.BACKWARD: 270 if prev_stage == 0: 271 error_msg.append( 272 f"[{mb_index=}] already finished backward computation" 273 ) 274 break 275 else: 276 expected_stage = prev_stage - 1 277 expected_computation = _ComputationType.BACKWARD 278 else: 279 raise ValueError( 280 f"Computation type {prev_computation} not supported" 281 ) 282 283 if expected_computation is not None: 284 if expected_computation != computation_type: 285 error_msg.append( 286 f"[{mb_index=}] {expected_computation=} VS. actual {computation_type=}" 287 ) 288 289 if expected_stage != stage_index: 290 error_msg.append( 291 f"[{mb_index=}] {expected_stage=} VS. 
actual {stage_index=}" 292 ) 293 294 microbatch_process_info[mb_index] = ( 295 expected_computation, 296 expected_stage, 297 ) 298 299 if not enable_zero_bubble: 300 if len(error_msg) != 0: 301 raise RuntimeError( 302 f"Error at timestep {timestep}: " + ",".join(error_msg) 303 ) 304 return 305 306 for rank in range(len(pipeline_order)): 307 backward_steps: Set[Tuple[int, int]] = set() 308 weight_steps: Set[Tuple[int, int]] = set() 309 310 for action in pipeline_order[rank]: 311 if action is None: 312 continue 313 314 stage_index = action.stage_index 315 computation_type = action.computation_type 316 mb_index = action.microbatch_index 317 if computation_type == _ComputationType.BACKWARD: 318 if mb_index is not None: 319 backward_steps.add((mb_index, stage_index)) 320 elif computation_type == _ComputationType.WEIGHT: 321 if (mb_index, stage_index) not in backward_steps: 322 error_msg.append( 323 f"{mb_index=}, {stage_index=} Weight happened before bwd" 324 ) 325 if (mb_index, stage_index) in weight_steps: 326 error_msg.append( 327 f"{mb_index=}, {stage_index=} Duplicated weight step" 328 ) 329 if mb_index is not None: 330 weight_steps.add((mb_index, stage_index)) 331 332 if len(backward_steps) != len(weight_steps): 333 error_msg.append("Length weight steps != Length bwd steps") 334 335 if len(error_msg) != 0: 336 raise RuntimeError(f"Error at timestep {timestep}: " + ",".join(error_msg)) 337 338 339class _PipelineSchedule(ABC): 340 def __init__( 341 self, 342 n_microbatches: int, 343 loss_fn: Optional[Callable[..., torch.Tensor]] = None, 344 args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, 345 kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, 346 output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, 347 ): 348 # From arguments 349 self._n_microbatches = n_microbatches 350 self._loss_fn = loss_fn 351 # Chunking specification for positional inputs. (default: `None`) 352 self._args_chunk_spec = args_chunk_spec 353 # Chunking specification for keyword inputs. (default: `None`) 354 self._kwargs_chunk_spec = kwargs_chunk_spec 355 self._output_merge_spec = output_merge_spec 356 """ 357 # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs. 358 # They are used to convert batch to microbatches in `step(x)`. See 359 # `TensorChunkSpec` for helper methods for creating them. 360 """ 361 362 # Derived 363 self._has_backward = self._loss_fn is not None 364 365 # Holds the losses for each microbatch. 366 self._internal_losses: List[torch.Tensor] = [] 367 logger.info("Using %s", self.__class__.__name__) 368 369 def _maybe_compute_loss(self, stage, output, target_mbs, mb_index): 370 if stage.is_last and self._has_backward: 371 loss = self._compute_loss(output, target_mbs[mb_index]) # type: ignore[index] 372 self._internal_losses.append(loss) 373 374 def _maybe_get_loss(self, stage, mb_index): 375 valid_index = 0 <= mb_index < len(self._internal_losses) 376 if stage.is_last and self._has_backward and valid_index: 377 return self._internal_losses[mb_index] 378 elif len(self._internal_losses) != 0 and not valid_index: 379 raise RuntimeError( 380 f"Loss for microbatch {mb_index} is not available. 
" 381 f"Available losses for microbatches: {self._internal_losses}" 382 ) 383 else: 384 return None 385 386 def _update_losses(self, stages, losses): 387 """ 388 Update the losses to those in the internal state 389 """ 390 # if stages not a list turn into a list 391 if not isinstance(stages, list): 392 stages = [stages] 393 contains_last_stage = any(stage.is_last for stage in stages) 394 395 # Return losses if there is a container passed in 396 if contains_last_stage and losses is not None: 397 if len(self._internal_losses) != self._n_microbatches: 398 raise RuntimeError( 399 f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}" 400 ) 401 402 # Clean external container first 403 losses.clear() 404 # Copy internal losses to external container 405 losses.extend(self._internal_losses) 406 407 self._internal_losses.clear() 408 409 @abstractmethod 410 def _step_microbatches( 411 self, 412 arg_mbs: Optional[List] = None, 413 kwarg_mbs: Optional[List] = None, 414 target_mbs: Optional[List] = None, 415 losses: Optional[List] = None, 416 ): 417 """ 418 Run one iteration of the pipeline schedule with list of microbatches. 419 Will go through all the microbatches according to the schedule 420 implementation. 421 422 Args: 423 microbatches: list of microbatch args. 424 """ 425 raise NotImplementedError 426 427 @abstractmethod 428 def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): 429 """ 430 Run one iteration of the pipeline schedule with *whole-batch* input. 431 Will chunk the input into microbatches automatically, and go through the 432 microbatches according to the schedule implementation. 433 434 args: positional arguments to the model (as in non-pipeline case). 435 kwargs: keyword arguments to the model (as in non-pipeline case). 436 target: target for the loss function. 437 losses: a list to store the losses for each microbatch. 438 """ 439 raise NotImplementedError 440 441 def _check_inputs( 442 self, 443 arg_mbs: Optional[List] = None, 444 kwarg_mbs: Optional[List] = None, 445 target_mbs: Optional[List] = None, 446 losses: Optional[List] = None, 447 ): 448 """ 449 Pre-process/check inputs 450 """ 451 452 def check_type_and_len(mbs, name: str): 453 if not isinstance(mbs, list): 454 raise TypeError(f"{name} must be a list but got a {type(mbs)}") 455 if len(mbs) != self._n_microbatches: 456 raise ValueError( 457 f"Expecting {self._n_microbatches} {name} but got {len(mbs)}" 458 ) 459 460 if arg_mbs is not None: 461 check_type_and_len(arg_mbs, "arg_mbs") 462 else: 463 arg_mbs = [()] * self._n_microbatches 464 465 if kwarg_mbs is not None: 466 check_type_and_len(kwarg_mbs, "kwarg_mbs") 467 else: 468 kwarg_mbs = [{}] * self._n_microbatches 469 470 if target_mbs is not None: 471 check_type_and_len(target_mbs, "target_mbs") 472 473 if losses is not None: 474 if not isinstance(losses, list): 475 raise TypeError(f"losses must be a list but got a {type(losses)}") 476 477 return arg_mbs, kwarg_mbs 478 479 def _compute_loss(self, output, target): 480 return self._loss_fn(output, target) # type: ignore[misc] 481 482 def _split_inputs( 483 self, 484 args: Tuple[Any, ...], 485 kwargs: Optional[Dict[str, Any]] = None, 486 ): 487 """ 488 Splits a full-batch input into chunks (i.e. 
microbatches) and returns 489 the chunks 490 """ 491 if args or kwargs: 492 args_split, kwargs_split = split_args_kwargs_into_chunks( 493 args, 494 kwargs, 495 self._n_microbatches, 496 self._args_chunk_spec, 497 self._kwargs_chunk_spec, 498 ) 499 return args_split, kwargs_split 500 else: 501 # Empty inputs (e.g. when called on middle stages) 502 # Return a list of empty tuples/dicts with matching length as chunks 503 return [()] * self._n_microbatches, [{}] * self._n_microbatches 504 505 def _merge_outputs(self, output_chunks: List[Any]) -> Any: 506 """ 507 Merge output chunks back to a batch state. 508 If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim). 509 """ 510 return merge_chunks( 511 output_chunks, 512 self._output_merge_spec, 513 ) 514 515 516def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None): 517 """ 518 Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top. 519 """ 520 if len(p2p_ops) == 0: 521 return None 522 desc_str = f"{desc}, " if desc else "" 523 logger.debug("batch_p2p %s%s", desc_str, p2p_ops) 524 return dist.batch_isend_irecv(p2p_ops).pop() 525 526 527def _sorted_batch_p2p( 528 p2p_ops: List[dist.P2POp], desc: Optional[str] = None 529) -> Dict[int, dist.Work]: 530 """ 531 Sorts the list of P2P ops by the peer rank, and then calls 532 batch_isend_irecv. Return a dictionary of works by peer rank. This function 533 helps us avoid hangs in case of skip connections. 534 """ 535 # Arrange p2p_ops by peer rank: 536 # int is the peer rank; 537 # List is the list of ops towards the peer 538 ops_by_peer: Dict[int, List[dist.P2POp]] = defaultdict(list) 539 work_by_peer: Dict[int, dist.Work] = {} 540 if len(p2p_ops) == 0: 541 return work_by_peer 542 543 # Classify the ops by peer rank 544 for op in p2p_ops: 545 ops_by_peer[op.peer].append(op) 546 547 # Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs) 548 for peer, ops in sorted(ops_by_peer.items()): 549 work_by_peer[peer] = _batch_p2p(ops, desc=desc) 550 551 return work_by_peer 552 553 554class PipelineScheduleSingle(_PipelineSchedule): 555 """ 556 Base class for single-stage schedules. 557 Implements the `step` method. 558 Derived classes should implement `_step_microbatches`. 559 """ 560 561 def __init__( 562 self, 563 stage: _PipelineStageBase, 564 n_microbatches: int, 565 loss_fn: Optional[Callable] = None, 566 args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, 567 kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, 568 output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, 569 ): 570 # Init parent 571 super().__init__( 572 n_microbatches=n_microbatches, 573 loss_fn=loss_fn, 574 args_chunk_spec=args_chunk_spec, 575 kwargs_chunk_spec=kwargs_chunk_spec, 576 output_merge_spec=output_merge_spec, 577 ) 578 # Self attributes 579 self._stage = stage 580 self._num_stages = stage.num_stages 581 # Set the same has_backward flag for stage object 582 self._stage.has_backward = self._has_backward 583 584 # TODO: later replace this with lazy shape inference during forward 585 # Prepare forward send/recv infrastructure for stage 586 stage._prepare_forward_infra(n_microbatches) 587 if self._has_backward: 588 stage._prepare_backward_infra(n_microbatches) 589 590 def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): 591 """ 592 Run one iteration of the pipeline schedule with *whole-batch* input. 
593 Will chunk the input into microbatches automatically, and go through the 594 microbatches according to the schedule implementation. 595 596 args: positional arguments to the model (as in non-pipeline case). 597 kwargs: keyword arguments to the model (as in non-pipeline case). 598 target: target for the loss function. 599 losses: a list to store the losses for each microbatch. 600 """ 601 602 # Clean per iteration 603 self._stage.clear_runtime_states() 604 605 # Split inputs into microbatches 606 args_split, kwargs_split = self._split_inputs(args, kwargs) 607 608 # Split target into microbatches 609 if target is not None: 610 targets_split = list(torch.tensor_split(target, self._n_microbatches)) 611 else: 612 targets_split = None 613 614 # Run microbatches 615 self._step_microbatches(args_split, kwargs_split, targets_split, losses) 616 617 # Return merged results per original format 618 if self._stage.is_last: 619 return self._merge_outputs(self._stage.output_chunks) 620 else: 621 return None 622 623 624class _ScheduleForwardOnly(PipelineScheduleSingle): 625 """ 626 The forward-only schedule. 627 Will go through all the microbatches and perform only the forward pass 628 """ 629 630 def _step_microbatches( 631 self, 632 arg_mbs: Optional[List] = None, 633 kwarg_mbs: Optional[List] = None, 634 target_mbs: Optional[List] = None, 635 losses: Optional[List] = None, 636 ): 637 """ 638 Run one iteration of the pipeline schedule 639 """ 640 if target_mbs is not None or losses is not None: 641 raise RuntimeError( 642 "Forward-only schedule does not support loss computation" 643 ) 644 645 arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) 646 647 # Delay send waits 648 fwd_sends_to_wait: List[dist.Work] = [] 649 650 # Run microbatches 651 for i in range(self._n_microbatches): 652 with record_function(f"Forward {i}"): 653 ops = self._stage.get_fwd_recv_ops(i) 654 works = _sorted_batch_p2p(ops, desc="fwd_recv") 655 for work in works.values(): 656 work.wait() 657 658 self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] 659 660 ops = self._stage.get_fwd_send_ops(i) 661 works = _sorted_batch_p2p(ops, desc="fwd_send") 662 fwd_sends_to_wait.extend(works.values()) 663 664 logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i) 665 666 # Wait for all forward sends to finish 667 # This should not have performance impact because by the time the first 668 # backward arrives all the forward sends should have been finished. 669 for work in fwd_sends_to_wait: 670 work.wait() 671 672 673class ScheduleGPipe(PipelineScheduleSingle): 674 """ 675 The GPipe schedule. 676 Will go through all the microbatches in a fill-drain manner. 677 """ 678 679 def _step_microbatches( 680 self, 681 arg_mbs: Optional[List] = None, 682 kwarg_mbs: Optional[List] = None, 683 target_mbs: Optional[List] = None, 684 losses: Optional[List] = None, 685 ): 686 """ 687 Run one iteration of the pipeline schedule with list of microbatches. 688 Will go through all the microbatches according to the GPipe schedule. 689 690 Args: 691 microbatches: list of microbatch args. 
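
        For example (fill-drain order): with 4 microbatches, each rank runs all
        forwards first (F0 F1 F2 F3) and then, if a loss_fn was provided, all
        backwards (B0 B1 B2 B3).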
692 """ 693 arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) 694 695 # Delay send waits 696 fwd_sends_to_wait: List[dist.Work] = [] 697 698 # Run microbatches 699 for i in range(self._n_microbatches): 700 with record_function(f"Forward {i}"): 701 ops = self._stage.get_fwd_recv_ops(i) 702 works = _sorted_batch_p2p(ops, desc="fwd_recv") 703 for work in works.values(): 704 work.wait() 705 706 output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] 707 708 ops = self._stage.get_fwd_send_ops(i) 709 works = _sorted_batch_p2p(ops, desc="fwd_send") 710 fwd_sends_to_wait.extend(works.values()) 711 712 logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i) 713 714 self._maybe_compute_loss(self._stage, output, target_mbs, i) 715 716 # Wait for all forward sends to finish 717 # This should not have performance impact because by the time the first 718 # backward arrives all the forward sends should have been finished. 719 for work in fwd_sends_to_wait: 720 work.wait() 721 722 # No loss function, no need to run backward 723 if not self._has_backward: 724 return 725 726 # Run backward 727 # Delay send waits 728 bwd_sends_to_wait: List[dist.Work] = [] 729 for i in range(self._n_microbatches): 730 with record_function(f"Backward {i}"): 731 ops = self._stage.get_bwd_recv_ops(i) 732 works = _sorted_batch_p2p(ops, desc="bwd_recv") 733 for work in works.values(): 734 work.wait() 735 736 loss = self._maybe_get_loss(self._stage, i) 737 self._stage.backward_one_chunk(i, loss=loss) 738 739 ops = self._stage.get_bwd_send_ops(i) 740 works = _sorted_batch_p2p(ops, desc="bwd_send") 741 bwd_sends_to_wait.extend(works.values()) 742 743 logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i) 744 745 # Return losses if there is a container passed in 746 self._update_losses(self._stage, losses) 747 748 # Wait for all backward sends to finish 749 for work in bwd_sends_to_wait: 750 work.wait() 751 752 753class Schedule1F1B(PipelineScheduleSingle): 754 """ 755 The 1F1B schedule. 756 Will perform one forward and one backward on the microbatches in steady state. 757 """ 758 759 def _step_microbatches( 760 self, 761 arg_mbs: Optional[List] = None, 762 kwarg_mbs: Optional[List] = None, 763 target_mbs: Optional[List] = None, 764 losses: Optional[List] = None, 765 ): 766 """ 767 Run one iteration of the pipeline schedule with list of microbatches. 768 Will go through all the microbatches according to the 1F1B schedule. 769 770 Args: 771 microbatches: list of microbatch args. 772 """ 773 arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) 774 775 # Last stage has 1 warmup, second-to-last 2 warmups, ... 
776 # first stage `num_stages` warmups 777 warmup_chunks = min( 778 self._n_microbatches, 779 self._num_stages - self._stage.stage_index, 780 ) 781 782 # Chunk counters 783 fwd_mb_index = 0 784 bwd_mb_index = 0 785 weight_stage_mb_index = 0 786 787 # Warmup phase 788 send_work = None 789 fwd_sends = [] 790 for _ in range(warmup_chunks): 791 # Receive activations 792 fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index) 793 if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"): 794 recv_work.wait() 795 796 # Compute 797 output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] 798 799 # Clear previous chunk's forward sends (hopefully they have well 800 # finished, otherwise, we are heavily communication bound, in which 801 # case it doesn't create a lot of benefit to compute next chunk 802 # eagerly either) 803 if send_work: 804 send_work.wait() 805 806 # Send activations 807 fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index) 808 if fwd_mb_index != warmup_chunks - 1: 809 # Safe to fire 810 send_work = _batch_p2p(fwd_sends, desc="fwd_send") 811 # otherwise: 812 # The last foward send is left for fuse with first 1B in 1B1F below 813 814 # Compute loss 815 self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index) 816 fwd_mb_index += 1 817 818 # Now we should have send ops left over, to be fused with first 1B of 1B1F phase below. 819 820 # 1B1F phase 821 while True: # Don't worry, we have a break inside 822 # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops 823 bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index) 824 825 # Now, we need to fire the fwd_sends and bwd_recvs together 826 if fuse_work := _batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"): 827 fuse_work.wait() 828 829 # Backward one chunk 830 loss = self._maybe_get_loss(self._stage, bwd_mb_index) 831 self._stage.backward_one_chunk(bwd_mb_index, loss=loss) 832 833 # Get the bwd send ops, but don't fire, to be fused with the 1F below 834 bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index) 835 bwd_mb_index += 1 836 837 if fwd_mb_index == self._n_microbatches: 838 # We are done with 1B1F, so break with some left-over bwd_sends 839 break 840 841 # We prepare 1F of the `1B1F` 842 fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index) 843 844 # Fuse it with bwd_sends above 845 if fuse_work := _batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"): 846 fuse_work.wait() 847 848 # Now do the fwd 849 output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] 850 851 # Compute loss 852 self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index) 853 854 # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around) 855 fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index) 856 fwd_mb_index += 1 857 858 # Remember we still have some bwd_sends left over after the break? 
Now it is time to fire it 859 send_work = _batch_p2p(bwd_sends, desc="bwd_send") 860 861 # Cooldown 862 while bwd_mb_index < self._n_microbatches: 863 # prepare bwd recv ops 864 bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index) 865 if recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"): 866 recv_work.wait() 867 868 # Backward one chunk 869 loss = self._maybe_get_loss(self._stage, bwd_mb_index) 870 self._stage.backward_one_chunk(bwd_mb_index, loss=loss) 871 872 # Clear previous chunk's backward sends (hopefully they have well finished) 873 if send_work: 874 send_work.wait() 875 876 # Get the bwd send ops, fire it 877 bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index) 878 send_work = _batch_p2p(bwd_sends, desc="bwd_send") 879 bwd_mb_index += 1 880 881 # Wait for the last backward send to finish 882 if send_work: 883 send_work.wait() 884 885 # Return losses if there is a container passed in 886 self._update_losses(self._stage, losses) 887 888 889def _add_unshard_reshard( 890 compute_actions: List[Optional[_Action]], 891 max_active_stages: int = 3, 892) -> List[_Action]: 893 """Given a basic schedule involving only compute actions (F,B,W), add UNSHARD/RESHARD actions for FSDP. 894 895 UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation. 896 RESHARD does the opposite, releasing memory (but doing no commmunication) 897 898 We abandon the "timestep lock" during lowering 899 900 max_active_stages controls how many prefetches we allow. It should be measured in mb and tuneable but in practice 901 3 stages is probably the thing we want? 902 (to account for having one f and one b active, and something else prefetching?) 903 """ 904 905 def next_stage_indices( 906 count: int, next_actions: List[Optional[_Action]] 907 ) -> List[int]: 908 """Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute.""" 909 seen: Set[int] = set() 910 ret: List[int] = [] 911 912 for a in next_actions: 913 if a is not None and a.stage_index not in seen: 914 seen.add(a.stage_index) 915 ret.append(a.stage_index) 916 if len(ret) == count: 917 break 918 return ret 919 920 active_stages: Set[int] = set() 921 fsdp_aware_actions: List[_Action] = [] 922 923 def _unshard(stage_index: int): 924 active_stages.add(stage_index) 925 fsdp_aware_actions.append(_Action(stage_index, UNSHARD, None)) 926 927 def _reshard(stage_index: int): 928 active_stages.remove(stage_index) 929 fsdp_aware_actions.append(_Action(stage_index, RESHARD, None)) 930 931 for i, action in enumerate(compute_actions): 932 if action is None: 933 continue 934 935 # We prefetch the next N stages we'll see, dropping existing stages to make room 936 next_n = next_stage_indices(max_active_stages, compute_actions[i:]) 937 # Fetch needs to be ordered correctly, so don't use a set 938 fetch = list(filter(lambda s: s not in active_stages, next_n)) 939 # Unclear what the best policy is for eviction, but we can maintain order so we do 940 evict = list(filter(lambda s: s not in next_n, active_stages)) 941 942 # logger.debug( 943 # "_add_unshard_reshard Step %d active: %s fetch %s, evict %s", 944 # i, 945 # active_stages, 946 # fetch, 947 # evict, 948 # ) 949 950 for stage in evict: 951 _reshard(stage) 952 for stage in fetch: 953 _unshard(stage) 954 fsdp_aware_actions.append(action) 955 956 return fsdp_aware_actions 957 958 959def _add_send_recv( 960 compute_actions: Dict[int, List[_Action]], 961 stage_to_rank: Callable[[int], int], 962 num_stages: int, 963) -> Dict[int, 
List[_Action]]: 964 comm_actions: Dict[int, List[_Action]] = {rank: [] for rank in compute_actions} 965 966 def _has_comms(action: _Action) -> bool: 967 if action.computation_type == F: 968 return action.stage_index != num_stages - 1 969 elif action.computation_type == B: 970 return action.stage_index != 0 971 return False 972 973 def _get_comms(action: _Action) -> Tuple[_Action, _Action]: 974 assert _has_comms(action), f"{action} is not a valid comm action" 975 stage_idx = action.stage_index 976 ctype = action.computation_type 977 mb_idx = action.microbatch_index 978 send = _Action(stage_idx, SEND_F if ctype == F else SEND_B, mb_idx) 979 recv_stage_idx = stage_idx + 1 if ctype == F else stage_idx - 1 980 recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx) 981 return send, recv 982 983 def _ready_to_schedule( 984 action: Optional[_Action], prev_actions: List[_Action] 985 ) -> bool: 986 """We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place. 987 This helps ensure a sane (non-hanging) ordering of sends and recvs. 988 But it also means we might not be able to schedule our next compute action yet. 989 """ 990 if action is None: 991 return True 992 elif action.computation_type == F and not action.stage_index == 0: 993 expected_recv = _Action( 994 action.stage_index, 995 RECV_F if action.computation_type == F else RECV_B, 996 action.microbatch_index, 997 ) 998 return expected_recv in prev_actions 999 elif action.computation_type == B and not action.stage_index == num_stages - 1: 1000 expected_recv = _Action( 1001 action.stage_index, 1002 RECV_F if action.computation_type == F else RECV_B, 1003 action.microbatch_index, 1004 ) 1005 return expected_recv in prev_actions 1006 else: 1007 return True 1008 1009 while compute_actions: 1010 progress = False 1011 # go in order of ranks even if dict keys aren't ordered 1012 for rank in range(len(compute_actions)): 1013 assert len(compute_actions[rank]) > 0 1014 action = compute_actions[rank][0] 1015 1016 if not _ready_to_schedule(action, comm_actions[rank]): 1017 continue 1018 1019 if action is not None: 1020 comm_actions[rank].append(action) 1021 if _has_comms(action): 1022 send, recv = _get_comms(action) 1023 # TODO we can avoid send/recv if the 2 stages are on the same rank. 1024 # should we avoid that in the runtime or here? 1025 comm_actions[rank].append(send) 1026 comm_actions[stage_to_rank(recv.stage_index)].append(recv) 1027 1028 compute_actions[rank].pop(0) 1029 if len(compute_actions[rank]) == 0: 1030 del compute_actions[rank] 1031 progress = True 1032 assert progress, "Malformed compute schedule, can't schedule sends/recvs" 1033 return comm_actions 1034 1035 1036class PipelineScheduleMulti(_PipelineSchedule): 1037 """ 1038 Base class for multi-stage schedules. 1039 Implements the `step` method. 
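    Derived schedules populate `self.pipeline_order` in `__init__` with one list
    of `_Action`s per rank (`None` marks an idle step); `_step_microbatches`
    then replays that order, issuing the matching send/recv ops between
    neighboring ranks along the way.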
1040 """ 1041 1042 def __init__( 1043 self, 1044 stages: List[_PipelineStageBase], 1045 n_microbatches: int, 1046 loss_fn: Optional[Callable] = None, 1047 args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, 1048 kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, 1049 output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, 1050 stage_index_to_group_rank: Optional[Dict[int, int]] = None, 1051 use_full_backward: bool = True, 1052 ): 1053 if len(stages) <= 1: 1054 raise ValueError( 1055 f"Multi-stage schedule expects at least two stages but got {len(stages)}" 1056 ) 1057 # Init parent 1058 super().__init__( 1059 n_microbatches=n_microbatches, 1060 loss_fn=loss_fn, 1061 args_chunk_spec=args_chunk_spec, 1062 kwargs_chunk_spec=kwargs_chunk_spec, 1063 output_merge_spec=output_merge_spec, 1064 ) 1065 # Self attributes 1066 self._stages = stages 1067 self._num_stages = stages[0].num_stages 1068 self.pp_group_size = stages[0].group_size 1069 self.rank = stages[0].group_rank 1070 # Set the pipeline stage states 1071 if stage_index_to_group_rank is not None: 1072 for stage in self._stages: 1073 stage.stage_index_to_group_rank = stage_index_to_group_rank 1074 self.stage_index_to_group_rank = stages[0].stage_index_to_group_rank 1075 1076 # Set the same has_backward flag for stage object 1077 for stage in self._stages: 1078 stage.has_backward = self._has_backward 1079 1080 self._should_compute_loss = ( 1081 lambda stage: stage.is_last and self._loss_fn is not None 1082 ) 1083 1084 # This will be set during init of derived schedules 1085 self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} 1086 self.use_full_backward = use_full_backward 1087 1088 # TODO: later replace this with lazy shape inference during forward 1089 # Prepare forward send/recv infrastructure for stage 1090 for stage in self._stages: 1091 stage._prepare_forward_infra(n_microbatches) 1092 if self._has_backward: 1093 stage._prepare_backward_infra(n_microbatches) 1094 1095 def _dump_csv(self, filename): 1096 """Dump a CSV representation of the schedule into a file with the provided filename.""" 1097 with open(filename, "w", newline="") as csvfile: 1098 writer = csv.writer(csvfile) 1099 for rank in self.pipeline_order: 1100 writer.writerow(self.pipeline_order[rank]) 1101 1102 def _validate_schedule(self): 1103 # TODO(whc) this should be merged with the logic in test_schedule.py#L453-L554 1104 def _validate_rank_actions( 1105 actions: Dict[int, List[_Action | None]], 1106 num_stages: int, 1107 num_microbatches: int, 1108 ): 1109 # We will count all the actions per stage and ensure they happen in a valid order 1110 # (e.g. 
F before B before W for a given microbatch) 1111 stage_actions: Dict[int, Dict[_ComputationType, Set]] = { 1112 stage_id: { 1113 F: set(), 1114 B: set(), 1115 W: set(), 1116 } 1117 for stage_id in range(num_stages) 1118 } 1119 for rank in actions: 1120 for action in actions[rank]: 1121 if action is None: 1122 continue 1123 assert isinstance( 1124 action, _Action 1125 ), f"Got an invalid action: {action}, expected instance of _Action" 1126 s_id = action.stage_index 1127 ctype = action.computation_type 1128 mb_id = action.microbatch_index 1129 if ctype == F: 1130 stage_actions[s_id][F].add(mb_id) 1131 elif ctype == B: 1132 assert ( 1133 mb_id in stage_actions[s_id][F] 1134 ), f"Running Backward for stage {s_id}, microbatch {mb_id} without first running Forward" 1135 stage_actions[s_id][B].add(mb_id) 1136 elif ctype == W: 1137 assert ( 1138 not self.use_full_backward 1139 ), "Schedule contains 'W' actions, but is configured to use full backward" 1140 assert ( 1141 mb_id in stage_actions[s_id][B] 1142 ), f"Running Weight for stage {s_id}, microbatch {mb_id} without first running Backward" 1143 stage_actions[s_id][W].add(mb_id) 1144 1145 for s_id in stage_actions: 1146 for ctype in (F, B, W): 1147 stage_mb = len(stage_actions[s_id][ctype]) 1148 assert ( 1149 stage_mb == num_microbatches 1150 ), f"Got {stage_mb} {ctype} microbatches for stage {s_id}, expected {num_microbatches}" 1151 1152 assert ( 1153 len(self.pipeline_order) == self.pp_group_size 1154 ), f"Schedule has incorrect number of ranks - expected {self.pp_group_size}, actual {len(self.pipeline_order)}" 1155 for rank in range(self.pp_group_size): 1156 assert ( 1157 rank in self.pipeline_order 1158 ), f"Schedule is missing actions for rank {rank}" 1159 _validate_rank_actions( 1160 self.pipeline_order, 1161 self._num_stages, 1162 self._n_microbatches, 1163 ) 1164 1165 def _load_csv(self, filename, format="compute_only"): 1166 """Load a CSV representation of the schedule from a file with the provided filename. 1167 This API will most likely get renamed/refactored so is marked as internal for now. 1168 1169 format must be "compute_only" for PipelineScheduleMulti 1170 """ 1171 assert format == "compute_only" 1172 with open(filename, newline="") as csvfile: 1173 reader = csv.reader(csvfile) 1174 for rank, row in enumerate(reader): 1175 self.pipeline_order[rank] = [_Action.from_str(s) for s in row] 1176 self._validate_schedule() 1177 1178 def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): 1179 """ 1180 Run one iteration of the pipeline schedule with *whole-batch* input. 1181 Will chunk the input into microbatches automatically, and go through the 1182 microbatches according to the schedule implementation. 1183 1184 args: positional arguments to the model (as in non-pipeline case). 1185 kwargs: keyword arguments to the model (as in non-pipeline case). 1186 target: target for the loss function. 1187 losses: a list to store the losses for each microbatch. 
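
        A minimal usage sketch (illustrative; assumes `stages`, `loss_fn`, `x` and `y`
        are created elsewhere):

            schedule = ScheduleInterleaved1F1B(stages, n_microbatches=8, loss_fn=loss_fn)
            losses: List[torch.Tensor] = []
            out = schedule.step(x, target=y, losses=losses)
            # `out` is the merged output on the rank holding the last stage, else None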
1188 """ 1189 1190 # Clean per iteration 1191 for stage in self._stages: 1192 stage.clear_runtime_states() 1193 1194 # Split inputs into microbatches 1195 args_split, kwargs_split = self._split_inputs(args, kwargs) 1196 1197 # Split target into microbatches 1198 if target is not None: 1199 targets_split = list(torch.tensor_split(target, self._n_microbatches)) 1200 else: 1201 targets_split = None 1202 1203 # Run microbatches 1204 self._step_microbatches(args_split, kwargs_split, targets_split, losses) 1205 1206 # Return merged results per original format 1207 for stage in self._stages: 1208 if stage.is_last: 1209 return self._merge_outputs(stage.output_chunks) 1210 # Does not contain the last stage 1211 return None 1212 1213 def _step_microbatches( 1214 self, 1215 arg_mbs: Optional[List] = None, 1216 kwarg_mbs: Optional[List] = None, 1217 target_mbs: Optional[List] = None, 1218 losses: Optional[List] = None, 1219 ): 1220 """ 1221 Operate on the microbatches for looped schedules (multiple stages on each rank). 1222 1223 TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does 1224 not support models with skip connections. 1225 """ 1226 arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) 1227 1228 # Based on the plan in Step 1 created in __init__: 1229 # 2. Perform communication based on the pipeline_order 1230 stage_index_to_stage: Dict[int, _PipelineStageBase] = { 1231 stage.stage_index: stage for stage in self._stages 1232 } 1233 1234 # determine prev_rank and next_rank based on which ranks are next to 1235 # the stages in the pipeline_order 1236 all_prev_ranks: Set[int] = set() 1237 all_next_ranks: Set[int] = set() 1238 for stage_index in stage_index_to_stage.keys(): 1239 # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections) 1240 if stage_index > 0: 1241 all_prev_ranks.add(self.stage_index_to_group_rank[stage_index - 1]) 1242 if stage_index < self._num_stages - 1: 1243 all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1]) 1244 1245 for time_step, action in enumerate(self.pipeline_order[self.rank]): 1246 try: 1247 ops: List[dist.P2POp] = [] 1248 if action is not None: 1249 computation_type = action.computation_type 1250 mb_index = action.microbatch_index 1251 stage_index = action.stage_index 1252 assert ( 1253 mb_index is not None 1254 ), "All currently supported action types require valid microbatch_index" 1255 if computation_type == _ComputationType.FORWARD: 1256 # perform forward computation 1257 stage = stage_index_to_stage[stage_index] 1258 output = stage.forward_one_chunk( 1259 mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index] 1260 ) 1261 self._maybe_compute_loss(stage, output, target_mbs, mb_index) 1262 ops.extend(stage.get_fwd_send_ops(mb_index)) 1263 elif computation_type == _ComputationType.BACKWARD: 1264 # perform backward computation 1265 stage = stage_index_to_stage[stage_index] 1266 loss = self._maybe_get_loss(stage, mb_index) 1267 stage.backward_one_chunk( 1268 mb_index, loss=loss, full_backward=self.use_full_backward 1269 ) 1270 ops.extend(stage.get_bwd_send_ops(mb_index)) 1271 elif computation_type == _ComputationType.WEIGHT: 1272 # perform weight update 1273 if self.use_full_backward: 1274 raise ValueError( 1275 f"We detected a weight update in the pipeline schedule, but \ 1276 {self.use_full_backward=}" 1277 ) 1278 stage = stage_index_to_stage[stage_index] 1279 stage.backward_weight_one_chunk(mb_index) 1280 else: 1281 raise ValueError(f"Unknown computation type 
{computation_type}") 1282 1283 # Look at the neighboring ranks for this current timestep and determine whether 1284 # this current rank needs to do any recv communication 1285 for prev_rank in all_prev_ranks: 1286 prev_rank_ops = self.pipeline_order[prev_rank] 1287 prev_rank_action = None 1288 if time_step < len(prev_rank_ops): 1289 prev_rank_action = prev_rank_ops[time_step] 1290 if prev_rank_action is not None: 1291 computation_type = prev_rank_action.computation_type 1292 mb_index = prev_rank_action.microbatch_index 1293 stage_index = prev_rank_action.stage_index 1294 assert ( 1295 mb_index is not None 1296 ), "All currently supported action types require valid microbatch_index" 1297 # Only handle sends for the forward from a previous rank 1298 if computation_type == _ComputationType.FORWARD: 1299 # If not the last stage, then receive fwd activations 1300 if stage_index + 1 in stage_index_to_stage: 1301 # TODO: We are assuming that stage will always receive from stage-1 1302 # however that is not necessarily true of get_fwd_recv_ops 1303 stage = stage_index_to_stage[stage_index + 1] 1304 ops.extend(stage.get_fwd_recv_ops(mb_index)) 1305 elif ( 1306 computation_type == _ComputationType.BACKWARD 1307 or computation_type == _ComputationType.WEIGHT 1308 ): 1309 # Previous rank doing backward or weight update has no influence for the current rank forward recv 1310 pass 1311 else: 1312 raise ValueError( 1313 f"Unknown computation type {computation_type}" 1314 ) 1315 for next_rank in all_next_ranks: 1316 next_rank_ops = self.pipeline_order[next_rank] 1317 next_rank_action = None 1318 if time_step < len(next_rank_ops): 1319 next_rank_action = next_rank_ops[time_step] 1320 if next_rank_action is not None: 1321 computation_type = next_rank_action.computation_type 1322 mb_index = next_rank_action.microbatch_index 1323 stage_index = next_rank_action.stage_index 1324 assert ( 1325 mb_index is not None 1326 ), "All currently supported action types require valid microbatch_index" 1327 # Only handle receives for the backwards from a next rank 1328 if ( 1329 computation_type == _ComputationType.FORWARD 1330 or computation_type == _ComputationType.WEIGHT 1331 ): 1332 # Next rank doing forward or weight update has no influence for the current rank backward recv 1333 pass 1334 elif computation_type == _ComputationType.BACKWARD: 1335 # If not the first stage, then receive bwd gradients 1336 if stage_index - 1 in stage_index_to_stage: 1337 # TODO: We are assuming that stage will always receive from stage+1 1338 # however that is not necessarily true of get_bwd_recv_ops 1339 stage = stage_index_to_stage[stage_index - 1] 1340 ops.extend(stage.get_bwd_recv_ops(mb_index)) 1341 else: 1342 raise ValueError( 1343 f"Unknown computation type {computation_type}" 1344 ) 1345 1346 # do the communication 1347 if ops: 1348 _batch_p2p(ops).wait() 1349 except Exception as e: 1350 logger.error( 1351 "[Rank %s] pipeline schedule %s caught the following exception \ 1352 at time_step %s when running action %s", 1353 self.rank, 1354 self.__class__.__name__, 1355 time_step, 1356 action, 1357 ) 1358 logger.error("%s", _format_pipeline_order(self.pipeline_order)) 1359 raise e 1360 # Return losses if there is a container passed in 1361 self._update_losses(self._stages, losses) 1362 1363 1364class _PipelineScheduleRuntime(PipelineScheduleMulti): 1365 """ 1366 Provides a simple runtime that requires a 'schedule IR' including specified communication operations. 
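    The IR uses the same one-row-per-rank action format as the compute-only schedules,
    extended with SEND_F/RECV_F, SEND_B/RECV_B and UNSHARD/RESHARD actions (see
    `_add_send_recv` and `_add_unshard_reshard`).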
1367 1368 Can be instantiated directly by creating _PipelineScheduleRuntime and calling load_csv, or can be 1369 subclassed and the subclass can be responsible for creating a schedule IR. 1370 """ 1371 1372 def _load_actions( 1373 self, 1374 actions: Dict[int, List[Optional[_Action]]], 1375 format: str = "compute_only", 1376 ): 1377 """ 1378 Given an in-memory representation for a simple compute-only schedule, lower it to a complex schedule including 1379 communication actions. Stores the schedule in self, and must be called before running step_mo() 1380 """ 1381 assert ( 1382 self.stage_index_to_group_rank is not None 1383 ), "stage_index_to_group_rank is required for PipelineScheduleRuntime" 1384 self.pipeline_order_with_comms: Dict[int, List[_Action]] = {} 1385 if format == "compute_comms": 1386 for rank in actions: 1387 self.pipeline_order_with_comms[rank] = [] 1388 for action in actions[rank]: 1389 assert action is not None 1390 self.pipeline_order_with_comms[rank].append(action) 1391 # TODO what level of validation should we offer for compute+comms schedule? 1392 elif format == "compute_only": 1393 # Perform schedule lowering 1394 for rank in actions: 1395 self.pipeline_order_with_comms[rank] = _add_unshard_reshard( 1396 actions[rank] 1397 ) 1398 1399 self.pipeline_order_with_comms = _add_send_recv( 1400 self.pipeline_order_with_comms, 1401 stage_to_rank=lambda s: self.stage_index_to_group_rank[s], 1402 num_stages=self._num_stages, 1403 ) 1404 else: 1405 raise NotImplementedError(f"{format=} is not implemented") 1406 1407 def _load_csv(self, filename: str, format: str = "compute_only"): 1408 """Loads a csv in simple format and then lowers it to include comunication actions 1409 1410 format must be either "compute_only" or "compute_comms". If compute_only, the lowering passes 1411 will automatically be run to generate a compute_comms schedule. 1412 """ 1413 if format == "compute_only": 1414 # this will populate self.pipeline_order 1415 super()._load_csv(filename) 1416 # this will populate self.pipeline_order_with_comms 1417 self._load_actions(self.pipeline_order) 1418 elif format == "compute_comms": 1419 actions = {} 1420 with open(filename, newline="") as csvfile: 1421 reader = csv.reader(csvfile) 1422 for rank, row in enumerate(reader): 1423 actions[rank] = [_Action.from_str(s) for s in row] 1424 self._load_actions(actions, format=format) 1425 else: 1426 raise NotImplementedError(f"{format=} is not implemented") 1427 1428 def _dump_csv(self, filename: str): 1429 """Dump a CSV representation of the compute + comms schedule into a file with the provided filename.""" 1430 # TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible 1431 # that it does not exist if it was created from a compute_comms schedule. 1432 assert ( 1433 self.pipeline_order_with_comms is not None 1434 ), "Must initialize compute_comms schedule before dump_csv" 1435 with open(filename, "w", newline="") as csvfile: 1436 writer = csv.writer(csvfile) 1437 for rank in self.pipeline_order_with_comms: 1438 writer.writerow(self.pipeline_order_with_comms[rank]) 1439 1440 def _step_microbatches( 1441 self, 1442 arg_mbs: Optional[List] = None, 1443 kwarg_mbs: Optional[List] = None, 1444 target_mbs: Optional[List] = None, 1445 losses: Optional[List] = None, 1446 ): 1447 """ 1448 Operate on the microbatches for looped schedules (multiple stages on each rank). 1449 1450 TODO: Does not use sorted_batch_isend_irecv(). 
As a result, this schedule does
        not support models with skip connections.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Based on the plan in Step 1 created in __init__:
        # 2. Perform communication based on the pipeline_order
        stage_index_to_stage: Dict[int, _PipelineStageBase] = {
            stage.stage_index: stage for stage in self._stages
        }

        assert (
            self.pipeline_order_with_comms is not None
        ), "Must call _load_actions() before calling _step_microbatches()"

        # recv ops indexed by (stage_idx, mb_idx) need to be waited on before use
        bwd_recv_ops: Dict[Tuple[int, int], Work] = {}
        fwd_recv_ops: Dict[Tuple[int, int], Work] = {}

        # send ops should be waited on before step() exits, mainly for hygiene
        send_ops: List[Work] = []

        # we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stages
        unshard_ops: Dict[int, UnshardHandle] = {}
        unsharded_stages = set()

        def _assert_unsharded(stage_idx: int):
            """If an unshard is active for `stage_idx`, wait() it and mark `stage_idx` unsharded."""
            if stage_idx in unshard_ops:
                unshard_ops[stage_idx].wait()
                del unshard_ops[stage_idx]
                unsharded_stages.add(stage_idx)
            assert (
                stage_idx in unsharded_stages
            ), f"Attempted to compute on sharded {stage_idx=}"

        for time_step, action in enumerate(self.pipeline_order_with_comms[self.rank]):
            try:
                comp_type = action.computation_type
                mb_index: int = (
                    action.microbatch_index
                    if action.microbatch_index is not None
                    else -1
                )
                assert mb_index >= 0 or comp_type in (
                    UNSHARD,
                    RESHARD,
                ), f"{action=} missing mb_index"
                stage_idx = action.stage_index
                stage = stage_index_to_stage[stage_idx]
                stage_uses_fsdp = isinstance(stage.submod, FSDPModule)

                logger.debug(
                    "_PipelineScheduleRuntime running time_step %d, action %s",
                    time_step,
                    action,
                )

                # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections,
                # since we do not want to batch up ops between more than a pair of ranks.  _sorted_batch_p2p would be
                # safe to use instead.
                # However, I was wondering if I should avoid calling batched operators at all in the case that there is
                # only one operator per batch.  I could iterate through the 'fwd_send_ops' one by one and run them.
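                # Illustrative (hypothetical) lowered order for a 2-stage, 1-microbatch run,
                # one row per rank (UNSHARD/RESHARD actions omitted):
                #   rank 0 (stage 0): 0F0, 0SEND_F0, 0RECV_B0, 0B0
                #   rank 1 (stage 1): 1RECV_F0, 1F0, 1B0, 1SEND_B0
                # The branches below post the comms for SEND_*/RECV_* actions; compute
                # actions first wait on the matching recv (and FSDP unshard), then run the stage.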
1513 if comp_type == SEND_F: 1514 send_ops.append(_batch_p2p(stage.get_fwd_send_ops(mb_index))) 1515 elif comp_type == SEND_B: 1516 send_ops.append(_batch_p2p(stage.get_bwd_send_ops(mb_index))) 1517 elif comp_type == RECV_F: 1518 assert ( 1519 stage_idx, 1520 mb_index, 1521 ) not in fwd_recv_ops, "Recv twice for {stage_idx=} {mb_index=} without executing forward" 1522 fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p( 1523 stage.get_fwd_recv_ops(mb_index) 1524 ) 1525 elif comp_type == RECV_B: 1526 assert ( 1527 stage_idx, 1528 mb_index, 1529 ) not in bwd_recv_ops, "Recv twice for {stage_idx=} {mb_index=} without executing backward" 1530 bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p( 1531 stage.get_bwd_recv_ops(mb_index) 1532 ) 1533 elif comp_type == UNSHARD: 1534 if stage_uses_fsdp: 1535 assert ( 1536 stage_idx not in unsharded_stages 1537 and stage_idx not in unshard_ops 1538 ), f"Unsharding the same {stage_idx=} twice" 1539 unshard_ops[stage_idx] = stage.submod.unshard(async_op=True) 1540 elif comp_type == RESHARD: 1541 if stage_uses_fsdp: 1542 assert ( 1543 stage_idx in unsharded_stages 1544 ), f"Resharding {stage_idx=} without unsharding" 1545 assert ( 1546 stage_idx not in unshard_ops 1547 ), f"Resharding {stage_idx=} before finishing unshard" 1548 stage.submod.reshard() 1549 elif comp_type == FORWARD: 1550 if stage_uses_fsdp: 1551 _assert_unsharded(stage_idx) 1552 1553 if not stage.is_first: 1554 assert ( 1555 stage_idx, 1556 mb_index, 1557 ) in fwd_recv_ops, f"Computing {action=} before receiving input" 1558 fwd_recv_ops.pop((stage_idx, mb_index)).wait() 1559 output = stage.forward_one_chunk( 1560 mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index] 1561 ) 1562 self._maybe_compute_loss(stage, output, target_mbs, mb_index) 1563 elif comp_type == BACKWARD: 1564 if stage_uses_fsdp: 1565 _assert_unsharded(stage_idx) 1566 1567 if not stage.is_last: 1568 assert ( 1569 stage_idx, 1570 mb_index, 1571 ) in bwd_recv_ops, ( 1572 f"Attempted to run compute {action=} before receiving input" 1573 ) 1574 bwd_recv_ops.pop((stage_idx, mb_index)).wait() 1575 loss = self._maybe_get_loss(stage, mb_index) 1576 stage.backward_one_chunk( 1577 mb_index, loss=loss, full_backward=self.use_full_backward 1578 ) 1579 elif comp_type == WEIGHT: 1580 if stage_uses_fsdp: 1581 _assert_unsharded(stage_idx) 1582 1583 if self.use_full_backward: 1584 raise ValueError( 1585 f"We detected a weight update in the pipeline schedule, but \ 1586 {self.use_full_backward=}" 1587 ) 1588 stage.backward_weight_one_chunk(mb_index) 1589 else: 1590 raise ValueError(f"{action=} is unknown or unsupported") 1591 except Exception as e: 1592 logger.error( 1593 "_PipelineScheduleRuntime caught exception at step %s when running action %s. Full Schedule:", 1594 time_step, 1595 action, 1596 ) 1597 # TODO(whc) what is the best practice for printing a multiline log? 1598 # logger will split it into multiple log lines, but this makes it hard to read (too wide) 1599 print(_format_pipeline_order(self.pipeline_order_with_comms)) # type: ignore[arg-type] 1600 raise e 1601 1602 # Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them 1603 while len(send_ops): 1604 send_ops.pop().wait() 1605 1606 assert len(unshard_ops) == 0, "Unused unshard operations" 1607 1608 # Return losses if there is a container passed in 1609 self._update_losses(self._stages, losses) 1610 1611 1612class ScheduleLoopedBFS(PipelineScheduleMulti): 1613 """ 1614 Breadth-First Pipeline Parallelism. 
    See https://arxiv.org/abs/2211.05953 for details.
    Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank.
    What is different is that, when microbatches are ready for multiple local
    stages, Looped BFS prioritizes the earlier stage, running all available
    microbatches at once.
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            output_merge_spec=output_merge_spec,
        )

        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(stage_index, computation_type, microbatch_index), ...]
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

    def _calculate_single_rank_operations(self, rank):
        n_local_stages = len(self._stages)
        stage_indices = range(
            rank, self.pp_group_size * n_local_stages, self.pp_group_size
        )

        # Store the list of operations used for that rank
        rank_ops: List[Optional[_Action]] = []
        # Pre-padding, rank starts with no-ops based on the warmup.
        for _ in range(rank):
            rank_ops.append(None)

        for stage_index in stage_indices:
            for mb_index in range(self._n_microbatches):
                rank_ops.append(
                    _Action(stage_index, _ComputationType.FORWARD, mb_index)
                )

        # wait for the first backward to trickle up
        # which is 2 for every hop away
        post_warmup_ops = 2 * (self.pp_group_size - 1 - rank)
        rank_ops.extend([None] * post_warmup_ops)

        for stage_index in reversed(stage_indices):
            for mb_index in reversed(range(self._n_microbatches)):
                rank_ops.append(
                    _Action(stage_index, _ComputationType.BACKWARD, mb_index)
                )
        return rank_ops


def _get_1f1b_rank_ops(
    n_local_stages,
    pp_group_size,
    warmup_ops,
    fwd_bwd_ops,
    cooldown_ops,
    rank,
    forward_stage_index,
    backward_stage_index,
    num_1f1b_microbatches=0,
    enable_zero_bubble=False,
):
    # All stages start with handling microbatch 0
    fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
    bwd_stage_mb_index: Dict[int, int] = defaultdict(int)
    weight_stage_mb_index: Dict[int, int] = defaultdict(int)

    # Store the list of operations used for that rank
    rank_ops: List[Optional[_Action]] = []
    # Pre-padding, rank starts with no-ops based on the warmup.
    for _ in range(rank):
        rank_ops.append(None)
    # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
    # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
    # Formula:
    # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward
    # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)
    # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
    # warmup_ops = calculated above
    post_warmup_ops = (
        n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
    ) - (warmup_ops + rank)

    if enable_zero_bubble:
        post_warmup_ops = pp_group_size - rank - 1

    total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops

    backward_op_ids = []
    weight_op_count = 0

    for op in range(total_ops):
        # Warmup phase
        if op < warmup_ops:
            fwd_stage_index = forward_stage_index(op)
            # This will assign the current microbatch index and update it as well
            fwd_stage_mb_index[fwd_stage_index] = (
                mb_index := fwd_stage_mb_index[fwd_stage_index]
            ) + 1
            rank_ops.append(
                _Action(fwd_stage_index, _ComputationType.FORWARD, mb_index)
            )
            if op == warmup_ops - 1:
                # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up
                rank_ops.extend([None] * post_warmup_ops)
        # 1F1B Phase (forward and backward)
        elif warmup_ops <= op < warmup_ops + fwd_bwd_ops:
            fwd_stage_index = forward_stage_index(op)
            fwd_stage_mb_index[fwd_stage_index] = (
                fwd_mb_index := fwd_stage_mb_index[fwd_stage_index]
            ) + 1
            rank_ops.append(
                _Action(fwd_stage_index, _ComputationType.FORWARD, fwd_mb_index)
            )
            bwd_stage_index = backward_stage_index(op)
            bwd_stage_mb_index[bwd_stage_index] = (
                bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
            ) + 1
            rank_ops.append(
                _Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
            )
            backward_op_ids.append(op)

            if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
                weight_stage_index = backward_stage_index(
                    backward_op_ids[weight_op_count]
                )
                weight_stage_mb_index[weight_stage_index] = (
                    weight_mb_index := weight_stage_mb_index[weight_stage_index]
                ) + 1
                rank_ops.append(
                    _Action(
                        weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
                    )
                )
                weight_op_count += 1
        # Cooldown phase
        else:
            # During the cooldown phase, we need steps to align with 1f1b happening in other ranks
            # TODO: we don't need to always append; once all 1f1b ops are finished we can stop appending None
            if not enable_zero_bubble:
                rank_ops.append(None)

            bwd_stage_index = backward_stage_index(op)
            bwd_stage_mb_index[bwd_stage_index] = (
                bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
            ) + 1
            rank_ops.append(
                _Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
            )
            backward_op_ids.append(op)

            if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
                weight_stage_index = backward_stage_index(
                    backward_op_ids[weight_op_count]
                )
                weight_stage_mb_index[weight_stage_index] = (
                    weight_mb_index := weight_stage_mb_index[weight_stage_index]
                ) + 1
                rank_ops.append(
                    _Action(
                        weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
                    )
                )
                weight_op_count += 1

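    # Zero-bubble only: drain any remaining weight-update (W) actions whose matching
    # backward has already been emitted but whose W slot was not reached in the loop above.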
    while enable_zero_bubble and weight_op_count < len(backward_op_ids):
        weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count])
        weight_stage_mb_index[weight_stage_index] = (
            weight_mb_index := weight_stage_mb_index[weight_stage_index]
        ) + 1
        rank_ops.append(
            _Action(weight_stage_index, _ComputationType.WEIGHT, weight_mb_index)
        )
        weight_op_count += 1

    return rank_ops


class ScheduleInterleaved1F1B(PipelineScheduleMulti):
    """
    The Interleaved 1F1B schedule.
    See https://arxiv.org/pdf/2104.04473 for details.
    Will perform one forward and one backward on the microbatches in steady
    state and supports multiple stages per rank. When microbatches are ready for
    multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch
    (also called "depth first").
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        self.pp_group_size = stages[0].group_size
        # TODO: is this limitation a must?
        if n_microbatches % self.pp_group_size != 0:
            raise ValueError(
                "Interleaved 1F1B schedule requires the number of microbatches "
                f"({n_microbatches}) to be a multiple of the number of pipeline "
                f"ranks ({self.pp_group_size})."
            )

        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )

        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
        self.group = stages[0].group

        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(stage_index, computation_type, microbatch_index), ...]
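        # For a hypothetical pp_group_size=4 with n_local_stages=2, rank 0 owns stages 0 and 4,
        # rank 1 owns stages 1 and 5, etc. (stage index = local_index * pp_group_size + rank).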
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}

        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

    def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
        def get_rank_warmup_ops(rank):
            # Warmup ops needed for the last stage
            warmups_ops_last_stage = (self.n_local_stages - 1) * self.pp_group_size
            # Increment warmup operations by 2 for each hop away from the last stage
            warmup_ops = warmups_ops_last_stage + 2 * ((self.pp_group_size - 1) - rank)
            # We cannot have more warmup operations than there are forward ops for this
            # rank (n_microbatches * n_local_stages), so cap it there
            return min(warmup_ops, self._n_microbatches * self.n_local_stages)

        warmup_ops = get_rank_warmup_ops(rank)
        microbatch_ops = self.n_local_stages * self._n_microbatches
        # fwd_bwd_ops should encompass the remaining forwards
        fwd_bwd_ops = microbatch_ops - warmup_ops
        # cooldown_ops should encompass the remaining backwards
        cooldown_ops = microbatch_ops - fwd_bwd_ops
        # total ops encompass both forward and backward ops
        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
        # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2

        logger.debug(
            "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
            rank,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            total_ops,
        )

        # Calculates the stage index based on step and pp_group_size
        def forward_stage_index(step):
            # Get the local index from 0 to n_local_stages-1
            local_index = (step // self.pp_group_size) % self.n_local_stages
            return (local_index * self.pp_group_size) + rank

        def backward_stage_index(step):
            local_index = (
                self.n_local_stages
                - 1
                - ((step - warmup_ops) // self.pp_group_size) % self.n_local_stages
            )
            return (local_index * self.pp_group_size) + rank
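        # Illustrative mapping (hypothetical config: pp_group_size=4, n_local_stages=2, rank=1,
        # so this rank owns stages 1 and 5): forward_stage_index(step) yields stage 1 for
        # steps 0-3, stage 5 for steps 4-7, stage 1 again for steps 8-11, and so on, while
        # backward_stage_index starts at the last local stage (5) once `step` reaches warmup_ops.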

        return _get_1f1b_rank_ops(
            self.n_local_stages,
            self.pp_group_size,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            rank,
            forward_stage_index,
            backward_stage_index,
        )


class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
    """
    The Flexible Interleaved 1F1B schedule.

    This schedule is mostly similar to the interleaved 1F1B schedule.
    It differs by relaxing the requirement that n_microbatches % pp_group_size == 0.
    Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and
    it works as long as n_microbatches % num_rounds is 0. As a few examples, it supports:

    1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0.
    2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0.

    When enable_zero_bubble is True, we will use the ZB1P schedule from https://openreview.net/pdf?id=tuzTN0eIO5.
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
        enable_zero_bubble: bool = False,
    ):
        self.pp_group_size = stages[0].group_size
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            use_full_backward=not enable_zero_bubble,
        )
        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
        self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
        self.microbatches_per_round = n_microbatches // self.number_of_rounds
        self.enable_zero_bubble = enable_zero_bubble
        if n_microbatches % self.number_of_rounds != 0:
            raise ValueError(
                "Flexible Interleaved 1F1B requires the number of microbatches to be a "
                f"multiple of the number of rounds ({self.number_of_rounds}), "
                f"but got {n_microbatches}."
            )
        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(stage_index, computation_type, microbatch_index), ...]
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

        # This step adds bubbles to the generated schedule based on dependencies between actions.
        # Note that the ZB1P schedule does not require bubbles to be manually added, so this is
        # only useful when n_microbatches <= microbatches_per_round.
        self.pipeline_order = self._add_bubbles_to_actions(
            self.n_local_stages * self.pp_group_size,
        )

    def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
        def get_rank_warmup_ops(rank):
            # Warmup ops needed for the last stage
            warmups_ops_last_stage = (
                self.n_local_stages - 1
            ) * self.microbatches_per_round
            # Increment warmup operations by 2 (or 1 with zero bubble) for each hop away from the last stage
            multiply_factor = 1 if self.enable_zero_bubble else 2
            warmup_ops = warmups_ops_last_stage + multiply_factor * (
                (self.pp_group_size - 1) - rank
            )

            # We cannot have more warmup operations than there are forward ops for this
            # rank (n_microbatches * n_local_stages), so cap it there
            return min(warmup_ops, self._n_microbatches * self.n_local_stages)

        warmup_ops = get_rank_warmup_ops(rank)
        microbatch_ops = self.n_local_stages * self._n_microbatches
        # fwd_bwd_ops should encompass the remaining forwards
        fwd_bwd_ops = microbatch_ops - warmup_ops
        # cooldown_ops should encompass the remaining backwards
        cooldown_ops = microbatch_ops - fwd_bwd_ops
        # total ops encompass both forward and backward ops
        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
        # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2
        logger.debug(
            "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
            rank,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            total_ops,
        )

        # Calculates the stage index based on step and microbatches_per_round

        def forward_stage_index(step):
            # Get the local index from 0 to n_local_stages-1
            local_index = (step // self.microbatches_per_round) % self.n_local_stages
            return (local_index * self.pp_group_size) + rank

        def backward_stage_index(step):
            local_index = (
                self.n_local_stages
                - 1
                - ((step - warmup_ops) // self.microbatches_per_round)
                % self.n_local_stages
            )
            return (local_index * self.pp_group_size) + rank

        if self.enable_zero_bubble:
            num_1f1b_microbatches = rank

            return _get_1f1b_rank_ops(
                self.n_local_stages,
                self.pp_group_size,
                warmup_ops,
                fwd_bwd_ops,
                cooldown_ops,
                rank,
                forward_stage_index,
                backward_stage_index,
                num_1f1b_microbatches,
                enable_zero_bubble=True,
            )

        return _get_1f1b_rank_ops(
            self.n_local_stages,
            self.pp_group_size,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            rank,
            forward_stage_index,
            backward_stage_index,
        )

    def _add_bubbles_to_actions(self, num_stages_global):
        actions = self.pipeline_order
        if not self.enable_zero_bubble:
            return actions

        def need_bubble(stage, op, microbatch, num_stages_global, seen_ops):
            if op == _ComputationType.FORWARD:
                if stage != 0 and (stage - 1, op, microbatch) not in seen_ops:
                    return True
            elif op == _ComputationType.BACKWARD:
                if stage == num_stages_global - 1:
                    return (stage, _ComputationType.FORWARD, microbatch) not in seen_ops
                return (stage + 1, op, microbatch) not in seen_ops
            return False

        seen_ops: Set[Tuple[int, _ComputationType, int]] = set()
        result: Dict[int, List[Optional[_Action]]] = {}
        next_pointer: Dict[int, int] = {}
        bubbles_added: Dict[int, int] = {}
        total_bubbles_added = 0

        for rank in range(self.pp_group_size):
            result[rank] = []
            next_pointer[rank] = 0
            bubbles_added[rank] = 0

        while True:
            should_stop = True

            temp_seen_ops: Set[Tuple[int, _ComputationType, int]] = set()

            for rank in range(self.pp_group_size):
                timestamp = next_pointer[rank]
                if timestamp >= len(actions[rank]):
                    continue

                should_stop = False

                if actions[rank][timestamp] is not None:
                    temp_action = actions[rank][timestamp]
                    assert temp_action is not None
                    stage_index, op, microbatch = temp_action
                    if not need_bubble(
                        stage_index, op, microbatch, num_stages_global, seen_ops
                    ):
                        result[rank].append(actions[rank][timestamp])
                        if microbatch is not None:
                            temp_seen_ops.add((stage_index, op, microbatch))
                        next_pointer[rank] += 1
                    else:
                        result[rank].append(None)
                        bubbles_added[rank] += 1
                        total_bubbles_added += 1
                else:
                    next_pointer[rank] += 1
                    result[rank].append(None)

            seen_ops.update(temp_seen_ops)
            if should_stop:
                break

        if total_bubbles_added > 0:
            logger.warning(
                "Non-zero bubbles added: total_bubbles_added=%s bubbles_added=%s",
                total_bubbles_added,
                bubbles_added,
            )
        return result


class ScheduleInterleavedZeroBubble(ScheduleFlexibleInterleaved1F1B):
    """
    The Interleaved Zero Bubble schedule.
    See https://arxiv.org/pdf/2401.10241 for details.
    Will perform one forward and one backward (for the input gradients) on the
    microbatches in steady state and supports multiple stages per rank. Uses the
    backward-for-weights (W) passes to fill in the pipeline bubble.
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            enable_zero_bubble=True,
        )


def get_schedule_class(schedule_name: str):
    """
    Maps a schedule name to its corresponding class object.

    Args:
        schedule_name (str): The name of the schedule.
    """
    schedule_map = {
        "1F1B": Schedule1F1B,
        "Interleaved1F1B": ScheduleInterleaved1F1B,
        "GPipe": ScheduleGPipe,
        "FlexibleInterleaved1F1B": ScheduleFlexibleInterleaved1F1B,
        "LoopedBFS": ScheduleLoopedBFS,
        "InterleavedZeroBubble": ScheduleInterleavedZeroBubble,
        "PipelineScheduleSingle": PipelineScheduleSingle,
        "PipelineScheduleMulti": PipelineScheduleMulti,
    }
    if schedule_name not in schedule_map:
        raise ValueError(f"Unknown schedule name: {schedule_name}")
    return schedule_map[schedule_name]
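

# Illustrative usage sketch (comment only; the stage list, loss_fn and the full
# input batch are assumed to be constructed elsewhere with the pipelining stage APIs):
#
#     schedule_cls = get_schedule_class("Interleaved1F1B")
#     schedule = schedule_cls(stages, n_microbatches=8, loss_fn=loss_fn)
#     # step() takes the whole batch and splits it into microbatches internally
#     schedule.step(*full_batch_args)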