# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates

import csv
import itertools
import logging
import re
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from typing import (
    Any,
    Callable,
    Dict,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp.fully_shard import FSDPModule, UnshardHandle
from torch.profiler import record_function

from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
from .stage import _PipelineStageBase


if TYPE_CHECKING:
    from torch.distributed import Work

__all__ = [
    "get_schedule_class",
    "PipelineScheduleSingle",
    "PipelineScheduleMulti",
    "Schedule1F1B",
    "ScheduleFlexibleInterleaved1F1B",
    "ScheduleGPipe",
    "ScheduleInterleaved1F1B",
    "ScheduleLoopedBFS",
    "ScheduleInterleavedZeroBubble",
]

logger = logging.getLogger(__name__)


class _ComputationType(Enum):
    # TODO(whc) rename to _ActType?
    FORWARD = 1
    BACKWARD = 2
    WEIGHT = 3
    UNSHARD = 4
    RESHARD = 5
    SEND_F = 6
    RECV_F = 7
    SEND_B = 8
    RECV_B = 9

    def __str__(self):
        str_map = {
            _ComputationType.FORWARD: "F",
            _ComputationType.BACKWARD: "B",
            _ComputationType.WEIGHT: "W",
            _ComputationType.UNSHARD: "UNSHARD",
            _ComputationType.RESHARD: "RESHARD",
            _ComputationType.SEND_F: "SEND_F",
            _ComputationType.RECV_F: "RECV_F",
            _ComputationType.SEND_B: "SEND_B",
            _ComputationType.RECV_B: "RECV_B",
        }
        return str_map[self]

    @staticmethod
    def from_str(action):
        if action == "F":
            return _ComputationType.FORWARD
        elif action == "B":
            return _ComputationType.BACKWARD
        elif action == "W":
            return _ComputationType.WEIGHT
        elif action == "UNSHARD":
            return _ComputationType.UNSHARD
        elif action == "RESHARD":
            return _ComputationType.RESHARD
        elif action == "SEND_F":
            return _ComputationType.SEND_F
        elif action == "RECV_F":
            return _ComputationType.RECV_F
        elif action == "SEND_B":
            return _ComputationType.SEND_B
        elif action == "RECV_B":
            return _ComputationType.RECV_B
        else:
            raise RuntimeError(f"Invalid computation type {action}")


FORWARD = _ComputationType.FORWARD
BACKWARD = _ComputationType.BACKWARD
WEIGHT = _ComputationType.WEIGHT
UNSHARD = _ComputationType.UNSHARD
RESHARD = _ComputationType.RESHARD
SEND_F = _ComputationType.SEND_F
RECV_F = _ComputationType.RECV_F
SEND_B = _ComputationType.SEND_B
RECV_B = _ComputationType.RECV_B

# Convenience shorthand for compute actions only since they are used in 'simple schedule format'
F = FORWARD
B = BACKWARD
W = WEIGHT

# Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index)
_action_regex = re.compile(
    r"(\d+)(F|B|W|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)"
)


class _Action(NamedTuple):
    stage_index: int
    computation_type: _ComputationType
    microbatch_index: Optional[int] = None

    def __repr__(self):
        repr_str = str(self.stage_index)
        repr_str += str(self.computation_type)
        if self.microbatch_index is not None:
            repr_str += str(self.microbatch_index)
        return repr_str

    @staticmethod
    def from_str(action_string):
        """
        Reverse of __repr__

        String should be formatted as [stage][action type][(microbatch)]
            e.g. `2F0`, `1UNSHARD`, `3SEND_F1`
        """
        if match := _action_regex.match(action_string):
            stage_index, computation_type, microbatch_index = match.groups()
            return _Action(
                int(stage_index),
                _ComputationType.from_str(computation_type),
                int(microbatch_index) if len(microbatch_index) else None,
            )
        elif action_string == "" or action_string.isspace():
            return None
        raise RuntimeError(
            f"Invalid action string: {action_string}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0"
        )

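# A quick sketch of the round-trip between an _Action and its compact string
# form (illustrative comments only, not executed):
#
#     _Action.from_str("2F0")          -> _Action(2, FORWARD, 0); repr() is "2F0"
#     repr(_Action(3, SEND_F, 1))      -> "3SEND_F1"
#     repr(_Action(1, UNSHARD, None))  -> "1UNSHARD"   (no microbatch index)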

def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) -> str:
    """
    Formats the pipeline order in a timestep (row) x rank (column) grid of actions
    and returns the formatted string
    """
    # Calculate the maximum number of steps across all ranks
    num_steps = max(len(actions) for actions in pipeline_order.values())
    step_labels = [
        "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps)
    ]
    # Sort the dictionary by rank and retrieve each rank's actions in that order
    rank_actions = [
        pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
    ]
    # Transpose the list of lists (rows to columns)
    transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
    # Generate column labels for ranks
    num_ranks = len(pipeline_order)
    rank_labels = ["Rank " + str(i) for i in range(num_ranks)]
    # Calculate the maximum length of each column, considering labels
    max_lengths = [
        max(len(str(item)) if item is not None else 0 for item in col)
        for col in zip(step_labels, *transposed_actions)
    ]
    # Format the header row with rank labels
    header_row = " " * (len(step_labels[0]) + 2) + " ".join(
        f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels)
    )
    # Format each row with its corresponding label
    formatted_rows = [
        f"{label}: "
        + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row))
        for label, row in zip(step_labels, transposed_actions)
    ]
    # Join the rows into a single string
    formatted_table = header_row + "\n" + "\n".join(formatted_rows) + "\n"
    return formatted_table

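# For intuition, a valid 2-stage, 2-microbatch 1F1B order such as
#     {0: [0F0, 0F1, None, 0B0, None, 0B1], 1: [None, 1F0, 1B0, 1F1, 1B1]}
# formats roughly as (spacing approximate):
#
#             Rank 0 Rank 1
#     Step 0: 0F0    None
#     Step 1: 0F1    1F0
#     Step 2: None   1B0
#     Step 3: 0B0    1F1
#     Step 4: None   1B1
#     Step 5: 0B1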

def _validate_pipeline_order(
    pipeline_order: Dict[int, List[Optional[_Action]]],
    num_microbatches: int,
    num_stages: int,
    enable_zero_bubble: bool = False,
):
    """
    pipeline_order[rank] = [(stage_index, computation_type, microbatch_index), ...]
    Validates that the pipeline order follows the rules:
    1. Forward action for a microbatch must be before the Backward action for that microbatch
    2. Recv for a microbatch must be before the send for that microbatch
    3. Microbatch index is handled in sequential order for each stage
    4. A later stage cannot operate on a microbatch before any of the previous stages have operated on it
    5. Same microbatch cannot be handled in the same time step across ranks
    """
    # microbatch_index: (current computation type, current stage)
    microbatch_process_info: Dict[int, Tuple[_ComputationType, int]] = {}
    max_timestep = max(len(rank_list) for rank_list in pipeline_order.values())
    for timestep in range(max_timestep):
        error_msg: List[str] = []
        current_timestep_actions = []
        for rank in range(len(pipeline_order)):
            action = (
                pipeline_order[rank][timestep]
                if timestep < len(pipeline_order[rank])
                else None
            )

            if action is not None:
                computation_type = action.computation_type
                if computation_type != _ComputationType.WEIGHT:
                    current_timestep_actions.append(action)

        # TODO: enable this
        # if len(current_timestep_actions) == 0:
        #     error_msg.append(
        #         "All actions were None, there is an unnecessary gap in the schedule"
        #     )

        # Ensure that no microbatch is operated on twice in current_timestep_actions
        unique_microbatch_indices = {
            action.microbatch_index for action in current_timestep_actions
        }
        if len(unique_microbatch_indices) != len(current_timestep_actions):
            error_msg.append(
                "Duplicate microbatch index found in current_timestep_actions"
            )

        for action in current_timestep_actions:
            stage_index = action.stage_index
            computation_type = action.computation_type
            mb_index = action.microbatch_index
            assert (
                mb_index is not None
            ), "All currently supported action types require valid microbatch_index"
            if mb_index >= num_microbatches:
                error_msg.append(f"Microbatch index {mb_index} out of range")

            # first microbatch
            if mb_index not in microbatch_process_info:
                if computation_type != _ComputationType.FORWARD or stage_index != 0:
                    error_msg.append(f"Incorrect start for microbatch {mb_index}")
                microbatch_process_info[mb_index] = (computation_type, stage_index)
            else:
                # if the microbatch is included, check that the current stage is right after prev
                prev_computation, prev_stage = microbatch_process_info[mb_index]

                if prev_computation == _ComputationType.FORWARD:
                    if prev_stage == num_stages - 1:
                        expected_stage = num_stages - 1
                        expected_computation = _ComputationType.BACKWARD
                    else:
                        expected_stage = prev_stage + 1
                        expected_computation = _ComputationType.FORWARD
                elif prev_computation == _ComputationType.BACKWARD:
                    if prev_stage == 0:
                        error_msg.append(
                            f"[{mb_index=}] already finished backward computation"
                        )
                        break
                    else:
                        expected_stage = prev_stage - 1
                        expected_computation = _ComputationType.BACKWARD
                else:
                    raise ValueError(
                        f"Computation type {prev_computation} not supported"
                    )

                if expected_computation is not None:
                    if expected_computation != computation_type:
                        error_msg.append(
                            f"[{mb_index=}] {expected_computation=} VS. actual {computation_type=}"
                        )

                if expected_stage != stage_index:
                    error_msg.append(
                        f"[{mb_index=}] {expected_stage=} VS. actual {stage_index=}"
                    )

                microbatch_process_info[mb_index] = (
                    expected_computation,
                    expected_stage,
                )

        if not enable_zero_bubble:
            if len(error_msg) != 0:
                raise RuntimeError(
                    f"Error at timestep {timestep}: " + ",".join(error_msg)
                )
            # Skip the zero-bubble-specific checks below, but keep validating
            # the remaining timesteps
            continue

        for rank in range(len(pipeline_order)):
            backward_steps: Set[Tuple[int, int]] = set()
            weight_steps: Set[Tuple[int, int]] = set()

            for action in pipeline_order[rank]:
                if action is None:
                    continue

                stage_index = action.stage_index
                computation_type = action.computation_type
                mb_index = action.microbatch_index
                if computation_type == _ComputationType.BACKWARD:
                    if mb_index is not None:
                        backward_steps.add((mb_index, stage_index))
                elif computation_type == _ComputationType.WEIGHT:
                    if (mb_index, stage_index) not in backward_steps:
                        error_msg.append(
                            f"{mb_index=}, {stage_index=} Weight happened before bwd"
                        )
                    if (mb_index, stage_index) in weight_steps:
                        error_msg.append(
                            f"{mb_index=}, {stage_index=} Duplicated weight step"
                        )
                    if mb_index is not None:
                        weight_steps.add((mb_index, stage_index))

            if len(backward_steps) != len(weight_steps):
                error_msg.append("Length weight steps != Length bwd steps")

        if len(error_msg) != 0:
            raise RuntimeError(f"Error at timestep {timestep}: " + ",".join(error_msg))

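# A minimal usage sketch, reusing the valid 2-stage, 2-microbatch 1F1B order
# shown above _format_pipeline_order (empty strings parse to None, i.e. idle):
#
#     order = {
#         0: [_Action.from_str(s) for s in ["0F0", "0F1", "", "0B0", "", "0B1"]],
#         1: [_Action.from_str(s) for s in ["", "1F0", "1B0", "1F1", "1B1"]],
#     }
#     _validate_pipeline_order(order, num_microbatches=2, num_stages=2)  # passes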

class _PipelineSchedule(ABC):
    def __init__(
        self,
        n_microbatches: int,
        loss_fn: Optional[Callable[..., torch.Tensor]] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        # From arguments
        self._n_microbatches = n_microbatches
        self._loss_fn = loss_fn
        # args_chunk_spec and kwargs_chunk_spec specify how to chunk positional
        # and keyword inputs, respectively (default: `None`). They are used to
        # convert a batch into microbatches in `step(x)`. See `TensorChunkSpec`
        # for helper methods for creating them.
        self._args_chunk_spec = args_chunk_spec
        self._kwargs_chunk_spec = kwargs_chunk_spec
        self._output_merge_spec = output_merge_spec

        # Derived
        self._has_backward = self._loss_fn is not None

        # Holds the losses for each microbatch.
        self._internal_losses: List[torch.Tensor] = []
        logger.info("Using %s", self.__class__.__name__)

    def _maybe_compute_loss(self, stage, output, target_mbs, mb_index):
        if stage.is_last and self._has_backward:
            loss = self._compute_loss(output, target_mbs[mb_index])  # type: ignore[index]
            self._internal_losses.append(loss)

    def _maybe_get_loss(self, stage, mb_index):
        valid_index = 0 <= mb_index < len(self._internal_losses)
        if stage.is_last and self._has_backward and valid_index:
            return self._internal_losses[mb_index]
        elif len(self._internal_losses) != 0 and not valid_index:
            raise RuntimeError(
                f"Loss for microbatch {mb_index} is not available. "
                f"Available losses for microbatches: {self._internal_losses}"
            )
        else:
            return None

    def _update_losses(self, stages, losses):
        """
        Update the losses to those in the internal state
        """
        # If `stages` is not a list, turn it into a list
        if not isinstance(stages, list):
            stages = [stages]
        contains_last_stage = any(stage.is_last for stage in stages)

        # Return losses if there is a container passed in
        if contains_last_stage and losses is not None:
            if len(self._internal_losses) != self._n_microbatches:
                raise RuntimeError(
                    f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}"
                )

            # Clean external container first
            losses.clear()
            # Copy internal losses to external container
            losses.extend(self._internal_losses)

        self._internal_losses.clear()

    @abstractmethod
    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the schedule
        implementation.

        Args:
            microbatches: list of microbatch args.
        """
        raise NotImplementedError

    @abstractmethod
    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """
        raise NotImplementedError

    def _check_inputs(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Pre-process/check inputs
        """

        def check_type_and_len(mbs, name: str):
            if not isinstance(mbs, list):
                raise TypeError(f"{name} must be a list but got a {type(mbs)}")
            if len(mbs) != self._n_microbatches:
                raise ValueError(
                    f"Expecting {self._n_microbatches} {name} but got {len(mbs)}"
                )

        if arg_mbs is not None:
            check_type_and_len(arg_mbs, "arg_mbs")
        else:
            arg_mbs = [()] * self._n_microbatches

        if kwarg_mbs is not None:
            check_type_and_len(kwarg_mbs, "kwarg_mbs")
        else:
            kwarg_mbs = [{}] * self._n_microbatches

        if target_mbs is not None:
            check_type_and_len(target_mbs, "target_mbs")

        if losses is not None:
            if not isinstance(losses, list):
                raise TypeError(f"losses must be a list but got a {type(losses)}")

        return arg_mbs, kwarg_mbs

    def _compute_loss(self, output, target):
        return self._loss_fn(output, target)  # type: ignore[misc]

    def _split_inputs(
        self,
        args: Tuple[Any, ...],
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Splits a full-batch input into chunks (i.e. microbatches) and returns
        the chunks
        """
        if args or kwargs:
            args_split, kwargs_split = split_args_kwargs_into_chunks(
                args,
                kwargs,
                self._n_microbatches,
                self._args_chunk_spec,
                self._kwargs_chunk_spec,
            )
            return args_split, kwargs_split
        else:
            # Empty inputs (e.g. when called on middle stages)
            # Return a list of empty tuples/dicts with matching length as chunks
            return [()] * self._n_microbatches, [{}] * self._n_microbatches

    def _merge_outputs(self, output_chunks: List[Any]) -> Any:
        """
        Merge output chunks back to a batch state.
        If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim).
        """
        return merge_chunks(
            output_chunks,
            self._output_merge_spec,
        )


def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None):
    """
    Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top.
    """
    if len(p2p_ops) == 0:
        return None
    desc_str = f"{desc}, " if desc else ""
    logger.debug("batch_p2p %s%s", desc_str, p2p_ops)
    return dist.batch_isend_irecv(p2p_ops).pop()


def _sorted_batch_p2p(
    p2p_ops: List[dist.P2POp], desc: Optional[str] = None
) -> Dict[int, dist.Work]:
    """
    Sorts the list of P2P ops by the peer rank, and then calls
    batch_isend_irecv. Return a dictionary of works by peer rank. This function
    helps us avoid hangs in case of skip connections.
    """
    # Arrange p2p_ops by peer rank:
    #   int is the peer rank;
    #   List is the list of ops towards the peer
    ops_by_peer: Dict[int, List[dist.P2POp]] = defaultdict(list)
    work_by_peer: Dict[int, dist.Work] = {}
    if len(p2p_ops) == 0:
        return work_by_peer

    # Classify the ops by peer rank
    for op in p2p_ops:
        ops_by_peer[op.peer].append(op)

    # Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs)
    for peer, ops in sorted(ops_by_peer.items()):
        work_by_peer[peer] = _batch_p2p(ops, desc=desc)

    return work_by_peer

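# Why sort by peer? If one rank issued a single unordered batch toward peers
# [2, 1] while another issued one toward [1, 2], each could block waiting on
# the other's later group. Issuing one batch_isend_irecv per peer, in ascending
# peer order, gives all ranks a consistent issue order. A sketch of the
# grouping (hypothetical ops; peers and tensors are illustrative):
#
#     ops = [dist.P2POp(dist.isend, t1, peer=2), dist.P2POp(dist.irecv, t2, peer=1)]
#     works = _sorted_batch_p2p(ops, desc="demo")  # peer 1's batch fires first
#     # works == {1: <Work>, 2: <Work>}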

class PipelineScheduleSingle(_PipelineSchedule):
    """
    Base class for single-stage schedules.
    Implements the `step` method.
    Derived classes should implement `_step_microbatches`.
    """

    def __init__(
        self,
        stage: _PipelineStageBase,
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        # Init parent
        super().__init__(
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )
        # Self attributes
        self._stage = stage
        self._num_stages = stage.num_stages
        # Set the same has_backward flag for stage object
        self._stage.has_backward = self._has_backward

        # TODO: later replace this with lazy shape inference during forward
        # Prepare forward send/recv infrastructure for stage
        stage._prepare_forward_infra(n_microbatches)
        if self._has_backward:
            stage._prepare_backward_infra(n_microbatches)

    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """

        # Clean per iteration
        self._stage.clear_runtime_states()

        # Split inputs into microbatches
        args_split, kwargs_split = self._split_inputs(args, kwargs)

        # Split target into microbatches
        if target is not None:
            targets_split = list(torch.tensor_split(target, self._n_microbatches))
        else:
            targets_split = None

        # Run microbatches
        self._step_microbatches(args_split, kwargs_split, targets_split, losses)

        # Return merged results per original format
        if self._stage.is_last:
            return self._merge_outputs(self._stage.output_chunks)
        else:
            return None


class _ScheduleForwardOnly(PipelineScheduleSingle):
    """
    The forward-only schedule.
    Will go through all the microbatches and perform only the forward pass
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule
        """
        if target_mbs is not None or losses is not None:
            raise RuntimeError(
                "Forward-only schedule does not support loss computation"
            )

        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Delay send waits
        fwd_sends_to_wait: List[dist.Work] = []

        # Run microbatches
        for i in range(self._n_microbatches):
            with record_function(f"Forward {i}"):
                ops = self._stage.get_fwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_recv")
                for work in works.values():
                    work.wait()

                self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i])  # type: ignore[index]

                ops = self._stage.get_fwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_send")
                fwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i)

        # Wait for all forward sends to finish before returning
        for work in fwd_sends_to_wait:
            work.wait()


class ScheduleGPipe(PipelineScheduleSingle):
    """
    The GPipe schedule.
    Will go through all the microbatches in a fill-drain manner.
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the GPipe schedule.

        Args:
            microbatches: list of microbatch args.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Delay send waits
        fwd_sends_to_wait: List[dist.Work] = []

        # Run microbatches
        for i in range(self._n_microbatches):
            with record_function(f"Forward {i}"):
                ops = self._stage.get_fwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_recv")
                for work in works.values():
                    work.wait()

                output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i])  # type: ignore[index]

                ops = self._stage.get_fwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_send")
                fwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i)

            self._maybe_compute_loss(self._stage, output, target_mbs, i)

        # Wait for all forward sends to finish
        # This should not have performance impact because by the time the first
        # backward arrives all the forward sends should have been finished.
        for work in fwd_sends_to_wait:
            work.wait()

        # No loss function, no need to run backward
        if not self._has_backward:
            return

        # Run backward
        # Delay send waits
        bwd_sends_to_wait: List[dist.Work] = []
        for i in range(self._n_microbatches):
            with record_function(f"Backward {i}"):
                ops = self._stage.get_bwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="bwd_recv")
                for work in works.values():
                    work.wait()

                loss = self._maybe_get_loss(self._stage, i)
                self._stage.backward_one_chunk(i, loss=loss)

                ops = self._stage.get_bwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="bwd_send")
                bwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i)

        # Return losses if there is a container passed in
        self._update_losses(self._stage, losses)

        # Wait for all backward sends to finish
        for work in bwd_sends_to_wait:
            work.wait()

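# A minimal usage sketch for a single-stage schedule (assumes the process group
# is already initialized; `model_chunk`, `stage_index`, `num_stages`, `device`,
# `x`, and `y` are hypothetical):
#
#     from torch.distributed.pipelining import PipelineStage
#
#     stage = PipelineStage(model_chunk, stage_index, num_stages, device)
#     schedule = ScheduleGPipe(stage, n_microbatches=4, loss_fn=torch.nn.MSELoss())
#     losses: List[torch.Tensor] = []
#     out = schedule.step(x, target=y, losses=losses)
#     # `out` is the merged output on the last stage, None elsewhere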

class Schedule1F1B(PipelineScheduleSingle):
    """
    The 1F1B schedule.
    Will perform one forward and one backward on the microbatches in steady state.
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the 1F1B schedule.

        Args:
            microbatches: list of microbatch args.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Last stage has 1 warmup, second-to-last 2 warmups, ...
        # first stage `num_stages` warmups
        warmup_chunks = min(
            self._n_microbatches,
            self._num_stages - self._stage.stage_index,
        )
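        # For example (illustrative numbers): with num_stages=4 and 8
        # microbatches, stage 0 runs 4 warmup forwards while stage 3 runs 1,
        # after which each rank alternates one backward and one forward
        # (the 1B1F phase below).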

        # Chunk counters
        fwd_mb_index = 0
        bwd_mb_index = 0
        weight_stage_mb_index = 0

        # Warmup phase
        send_work = None
        fwd_sends = []
        for _ in range(warmup_chunks):
            # Receive activations
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
            if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"):
                recv_work.wait()

            # Compute
            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]

            # Clear previous chunk's forward sends (hopefully they have already
            # finished; otherwise we are heavily communication-bound, and
            # eagerly computing the next chunk would not help much anyway)
            if send_work:
                send_work.wait()

            # Send activations
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            if fwd_mb_index != warmup_chunks - 1:
                # Safe to fire
                send_work = _batch_p2p(fwd_sends, desc="fwd_send")
            # otherwise:
            #   The last forward send is left to be fused with the first 1B in the 1B1F phase below

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
            fwd_mb_index += 1

        # Now we should have send ops left over, to be fused with first 1B of 1B1F phase below.

        # 1B1F phase
        while True:  # Don't worry, we have a break inside
            # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)

            # Now, we need to fire the fwd_sends and bwd_recvs together
            if fuse_work := _batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"):
                fuse_work.wait()

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(bwd_mb_index, loss=loss)

            # Get the bwd send ops, but don't fire, to be fused with the 1F below
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            bwd_mb_index += 1

            if fwd_mb_index == self._n_microbatches:
                # We are done with 1B1F, so break with some left-over bwd_sends
                break

            # We prepare 1F of the `1B1F`
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)

            # Fuse it with bwd_sends above
            if fuse_work := _batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"):
                fuse_work.wait()

            # Now do the fwd
            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)

            # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around)
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            fwd_mb_index += 1

        # Remember we still have some bwd_sends left over after the break? Now it is time to fire them
        send_work = _batch_p2p(bwd_sends, desc="bwd_send")

        # Cooldown
        while bwd_mb_index < self._n_microbatches:
            # prepare bwd recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
            if recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"):
                recv_work.wait()

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(bwd_mb_index, loss=loss)

            # Clear previous chunk's backward sends (hopefully they have already finished)
            if send_work:
                send_work.wait()

            # Get the bwd send ops, fire them
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            send_work = _batch_p2p(bwd_sends, desc="bwd_send")
            bwd_mb_index += 1

        # Wait for the last backward send to finish
        if send_work:
            send_work.wait()

        # Return losses if there is a container passed in
        self._update_losses(self._stage, losses)

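# The resulting per-rank pattern for 2 stages and 4 microbatches looks roughly
# like this (time flows left to right; idle gaps omitted):
#
#     Rank 0 (stage 0): F0 F1 B0 F2 B1 F3 B2 B3
#     Rank 1 (stage 1): F0 B0 F1 B1 F2 B2 F3 B3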

def _add_unshard_reshard(
    compute_actions: List[Optional[_Action]],
    max_active_stages: int = 3,
) -> List[_Action]:
    """Given a basic schedule involving only compute actions (F,B,W), add UNSHARD/RESHARD actions for FSDP.

    UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation.
    RESHARD does the opposite, releasing memory (but doing no communication)

    We abandon the "timestep lock" during lowering

    max_active_stages controls how many prefetches we allow. It should ideally be tunable
    (and measured in microbatches), but in practice 3 stages is probably what we want
    (to account for having one F and one B active, and something else prefetching).
    """

    def next_stage_indices(
        count: int, next_actions: List[Optional[_Action]]
    ) -> List[int]:
        """Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute."""
        seen: Set[int] = set()
        ret: List[int] = []

        for a in next_actions:
            if a is not None and a.stage_index not in seen:
                seen.add(a.stage_index)
                ret.append(a.stage_index)
                if len(ret) == count:
                    break
        return ret

    active_stages: Set[int] = set()
    fsdp_aware_actions: List[_Action] = []

    def _unshard(stage_index: int):
        active_stages.add(stage_index)
        fsdp_aware_actions.append(_Action(stage_index, UNSHARD, None))

    def _reshard(stage_index: int):
        active_stages.remove(stage_index)
        fsdp_aware_actions.append(_Action(stage_index, RESHARD, None))

    for i, action in enumerate(compute_actions):
        if action is None:
            continue

        # We prefetch the next N stages we'll see, dropping existing stages to make room
        next_n = next_stage_indices(max_active_stages, compute_actions[i:])
        # Fetch needs to be ordered correctly, so don't use a set
        fetch = list(filter(lambda s: s not in active_stages, next_n))
        # Unclear what the best policy is for eviction, but we can maintain order so we do
        evict = list(filter(lambda s: s not in next_n, active_stages))

        # logger.debug(
        #     "_add_unshard_reshard Step %d active: %s fetch %s, evict %s",
        #     i,
        #     active_stages,
        #     fetch,
        #     evict,
        # )

        for stage in evict:
            _reshard(stage)
        for stage in fetch:
            _unshard(stage)
        fsdp_aware_actions.append(action)

    return fsdp_aware_actions

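# A small worked example (hypothetical compute list, max_active_stages=2):
#
#     _add_unshard_reshard([0F0, 0F1, 1F0], max_active_stages=2)
#
# yields, in order: 0UNSHARD, 1UNSHARD, 0F0, 0F1, 0RESHARD, 1F0 -- stages 0
# and 1 are unsharded up front, and stage 0 is resharded as soon as it no
# longer appears in the lookahead window.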

def _add_send_recv(
    compute_actions: Dict[int, List[_Action]],
    stage_to_rank: Callable[[int], int],
    num_stages: int,
) -> Dict[int, List[_Action]]:
    comm_actions: Dict[int, List[_Action]] = {rank: [] for rank in compute_actions}

    def _has_comms(action: _Action) -> bool:
        if action.computation_type == F:
            return action.stage_index != num_stages - 1
        elif action.computation_type == B:
            return action.stage_index != 0
        return False

    def _get_comms(action: _Action) -> Tuple[_Action, _Action]:
        assert _has_comms(action), f"{action} is not a valid comm action"
        stage_idx = action.stage_index
        ctype = action.computation_type
        mb_idx = action.microbatch_index
        send = _Action(stage_idx, SEND_F if ctype == F else SEND_B, mb_idx)
        recv_stage_idx = stage_idx + 1 if ctype == F else stage_idx - 1
        recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx)
        return send, recv

    def _ready_to_schedule(
        action: Optional[_Action], prev_actions: List[_Action]
    ) -> bool:
        """We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place.
        This helps ensure a sane (non-hanging) ordering of sends and recvs.
        But it also means we might not be able to schedule our next compute action yet.
        """
        if action is None:
            return True
        elif action.computation_type == F and action.stage_index != 0:
            # A forward needs its input activations to have been received
            expected_recv = _Action(
                action.stage_index,
                RECV_F,
                action.microbatch_index,
            )
            return expected_recv in prev_actions
        elif action.computation_type == B and action.stage_index != num_stages - 1:
            # A backward needs its incoming gradients to have been received
            expected_recv = _Action(
                action.stage_index,
                RECV_B,
                action.microbatch_index,
            )
            return expected_recv in prev_actions
        else:
            return True

    while compute_actions:
        progress = False
        # Go in order of ranks even if dict keys aren't ordered
        # (sorted() materializes the keys, so deleting finished ranks below is safe)
        for rank in sorted(compute_actions):
            assert len(compute_actions[rank]) > 0
            action = compute_actions[rank][0]

            if not _ready_to_schedule(action, comm_actions[rank]):
                continue

            if action is not None:
                comm_actions[rank].append(action)
                if _has_comms(action):
                    send, recv = _get_comms(action)
                    # TODO we can avoid send/recv if the 2 stages are on the same rank.
                    # should we avoid that in the runtime or here?
                    comm_actions[rank].append(send)
                    comm_actions[stage_to_rank(recv.stage_index)].append(recv)

            compute_actions[rank].pop(0)
            if len(compute_actions[rank]) == 0:
                del compute_actions[rank]
            progress = True
        assert progress, "Malformed compute schedule, can't schedule sends/recvs"
    return comm_actions

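# A small worked example (2 stages on 2 ranks, 1 microbatch; assumes
# stage_to_rank is the identity function):
#
#     _add_send_recv({0: [0F0], 1: [1F0]}, stage_to_rank=lambda s: s, num_stages=2)
#
# yields {0: [0F0, 0SEND_F0], 1: [1RECV_F0, 1F0]}: rank 1 may not run 1F0
# until the matching 1RECV_F0 has been placed into its schedule by rank 0's send.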

class PipelineScheduleMulti(_PipelineSchedule):
    """
    Base class for multi-stage schedules.
    Implements the `step` method.
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
        stage_index_to_group_rank: Optional[Dict[int, int]] = None,
        use_full_backward: bool = True,
    ):
        if len(stages) <= 1:
            raise ValueError(
                f"Multi-stage schedule expects at least two stages but got {len(stages)}"
            )
        # Init parent
        super().__init__(
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )
        # Self attributes
        self._stages = stages
        self._num_stages = stages[0].num_stages
        self.pp_group_size = stages[0].group_size
        self.rank = stages[0].group_rank
        # Set the pipeline stage states
        if stage_index_to_group_rank is not None:
            for stage in self._stages:
                stage.stage_index_to_group_rank = stage_index_to_group_rank
        self.stage_index_to_group_rank = stages[0].stage_index_to_group_rank

        # Set the same has_backward flag for stage object
        for stage in self._stages:
            stage.has_backward = self._has_backward

        self._should_compute_loss = (
            lambda stage: stage.is_last and self._loss_fn is not None
        )

        # This will be set during init of derived schedules
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
        self.use_full_backward = use_full_backward

        # TODO: later replace this with lazy shape inference during forward
        # Prepare forward send/recv infrastructure for stage
        for stage in self._stages:
            stage._prepare_forward_infra(n_microbatches)
            if self._has_backward:
                stage._prepare_backward_infra(n_microbatches)

    def _dump_csv(self, filename):
        """Dump a CSV representation of the schedule into a file with the provided filename."""
        with open(filename, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            for rank in self.pipeline_order:
                writer.writerow(self.pipeline_order[rank])

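    # The CSV produced above has one row per rank and one cell per timestep;
    # each cell is an action's string form, and empty cells are idle steps.
    # For the 2-rank 1F1B example above _validate_pipeline_order, the file
    # would look roughly like:
    #
    #     0F0,0F1,,0B0,,0B1
    #     ,1F0,1B0,1F1,1B1
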
    def _validate_schedule(self):
        # TODO(whc) this should be merged with the logic in test_schedule.py#L453-L554
        def _validate_rank_actions(
            actions: Dict[int, List[Optional[_Action]]],
            num_stages: int,
            num_microbatches: int,
        ):
            # We will count all the actions per stage and ensure they happen in a valid order
            # (e.g. F before B before W for a given microbatch)
            stage_actions: Dict[int, Dict[_ComputationType, Set]] = {
                stage_id: {
                    F: set(),
                    B: set(),
                    W: set(),
                }
                for stage_id in range(num_stages)
            }
            for rank in actions:
                for action in actions[rank]:
                    if action is None:
                        continue
                    assert isinstance(
                        action, _Action
                    ), f"Got an invalid action: {action}, expected instance of _Action"
                    s_id = action.stage_index
                    ctype = action.computation_type
                    mb_id = action.microbatch_index
                    if ctype == F:
                        stage_actions[s_id][F].add(mb_id)
                    elif ctype == B:
                        assert (
                            mb_id in stage_actions[s_id][F]
                        ), f"Running Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
                        stage_actions[s_id][B].add(mb_id)
                    elif ctype == W:
                        assert (
                            not self.use_full_backward
                        ), "Schedule contains 'W' actions, but is configured to use full backward"
                        assert (
                            mb_id in stage_actions[s_id][B]
                        ), f"Running Weight for stage {s_id}, microbatch {mb_id} without first running Backward"
                        stage_actions[s_id][W].add(mb_id)

            for s_id in stage_actions:
                for ctype in (F, B, W):
                    if ctype == W and self.use_full_backward:
                        # Full-backward schedules contain no separate weight (W)
                        # actions, so there is nothing to count for W
                        continue
                    stage_mb = len(stage_actions[s_id][ctype])
                    assert (
                        stage_mb == num_microbatches
                    ), f"Got {stage_mb} {ctype} microbatches for stage {s_id}, expected {num_microbatches}"

        assert (
            len(self.pipeline_order) == self.pp_group_size
        ), f"Schedule has incorrect number of ranks - expected {self.pp_group_size}, actual {len(self.pipeline_order)}"
        for rank in range(self.pp_group_size):
            assert (
                rank in self.pipeline_order
            ), f"Schedule is missing actions for rank {rank}"
        _validate_rank_actions(
            self.pipeline_order,
            self._num_stages,
            self._n_microbatches,
        )

    def _load_csv(self, filename, format="compute_only"):
        """Load a CSV representation of the schedule from a file with the provided filename.
        This API will most likely get renamed/refactored so is marked as internal for now.

        format must be "compute_only" for PipelineScheduleMulti
        """
        assert format == "compute_only"
        with open(filename, newline="") as csvfile:
            reader = csv.reader(csvfile)
            for rank, row in enumerate(reader):
                self.pipeline_order[rank] = [_Action.from_str(s) for s in row]
        self._validate_schedule()

1178    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
1179        """
1180        Run one iteration of the pipeline schedule with *whole-batch* input.
1181        Will chunk the input into microbatches automatically, and go through the
1182        microbatches according to the schedule implementation.
1183
1184        args: positional arguments to the model (as in non-pipeline case).
1185        kwargs: keyword arguments to the model (as in non-pipeline case).
1186        target: target for the loss function.
1187        losses: a list to store the losses for each microbatch.
1188        """
1189
1190        # Clean per iteration
1191        for stage in self._stages:
1192            stage.clear_runtime_states()
1193
1194        # Split inputs into microbatches
1195        args_split, kwargs_split = self._split_inputs(args, kwargs)
1196
1197        # Split target into microbatches
1198        if target is not None:
1199            targets_split = list(torch.tensor_split(target, self._n_microbatches))
1200        else:
1201            targets_split = None
1202
1203        # Run microbatches
1204        self._step_microbatches(args_split, kwargs_split, targets_split, losses)
1205
1206        # Return merged results per original format
1207        for stage in self._stages:
1208            if stage.is_last:
1209                return self._merge_outputs(stage.output_chunks)
1210        # Does not contain the last stage
1211        return None
1212
1213    def _step_microbatches(
1214        self,
1215        arg_mbs: Optional[List] = None,
1216        kwarg_mbs: Optional[List] = None,
1217        target_mbs: Optional[List] = None,
1218        losses: Optional[List] = None,
1219    ):
1220        """
1221        Operate on the microbatches for looped schedules (multiple stages on each rank).
1222
1223        TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
1224        not support models with skip connections.
1225        """
1226        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
1227
1228        # Based on the plan in Step 1 created in __init__:
1229        # 2. Perform communication based on the pipeline_order
1230        stage_index_to_stage: Dict[int, _PipelineStageBase] = {
1231            stage.stage_index: stage for stage in self._stages
1232        }
1233
1234        # determine prev_rank and next_rank based on which ranks are next to
1235        # the stages in the pipeline_order
1236        all_prev_ranks: Set[int] = set()
1237        all_next_ranks: Set[int] = set()
1238        for stage_index in stage_index_to_stage.keys():
1239            # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections)
1240            if stage_index > 0:
1241                all_prev_ranks.add(self.stage_index_to_group_rank[stage_index - 1])
1242            if stage_index < self._num_stages - 1:
1243                all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1])
1244
1245        for time_step, action in enumerate(self.pipeline_order[self.rank]):
1246            try:
1247                ops: List[dist.P2POp] = []
1248                if action is not None:
1249                    computation_type = action.computation_type
1250                    mb_index = action.microbatch_index
1251                    stage_index = action.stage_index
1252                    assert (
1253                        mb_index is not None
1254                    ), "All currently supported action types require valid microbatch_index"
1255                    if computation_type == _ComputationType.FORWARD:
1256                        # perform forward computation
1257                        stage = stage_index_to_stage[stage_index]
1258                        output = stage.forward_one_chunk(
1259                            mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index]
1260                        )
1261                        self._maybe_compute_loss(stage, output, target_mbs, mb_index)
1262                        ops.extend(stage.get_fwd_send_ops(mb_index))
1263                    elif computation_type == _ComputationType.BACKWARD:
1264                        # perform backward computation
1265                        stage = stage_index_to_stage[stage_index]
1266                        loss = self._maybe_get_loss(stage, mb_index)
1267                        stage.backward_one_chunk(
1268                            mb_index, loss=loss, full_backward=self.use_full_backward
1269                        )
1270                        ops.extend(stage.get_bwd_send_ops(mb_index))
1271                    elif computation_type == _ComputationType.WEIGHT:
1272                        # perform weight update
1273                        if self.use_full_backward:
1274                            raise ValueError(
1275                                f"We detected a weight update in the pipeline schedule, but \
1276                                {self.use_full_backward=}"
1277                            )
1278                        stage = stage_index_to_stage[stage_index]
1279                        stage.backward_weight_one_chunk(mb_index)
1280                    else:
1281                        raise ValueError(f"Unknown computation type {computation_type}")

                # Look at the neighboring ranks for this timestep and determine whether
                # the current rank needs to do any recv communication
                for prev_rank in all_prev_ranks:
                    prev_rank_ops = self.pipeline_order[prev_rank]
                    prev_rank_action = None
                    if time_step < len(prev_rank_ops):
                        prev_rank_action = prev_rank_ops[time_step]
                    if prev_rank_action is not None:
                        computation_type = prev_rank_action.computation_type
                        mb_index = prev_rank_action.microbatch_index
                        stage_index = prev_rank_action.stage_index
                        assert (
                            mb_index is not None
                        ), "All currently supported action types require valid microbatch_index"
                        # Only handle receives for a forward sent from a previous rank
                        if computation_type == _ComputationType.FORWARD:
                            # If not the last stage, then receive fwd activations
                            if stage_index + 1 in stage_index_to_stage:
                                # TODO: We are assuming that stage will always receive from stage-1
                                # however that is not necessarily true of get_fwd_recv_ops
                                stage = stage_index_to_stage[stage_index + 1]
                                ops.extend(stage.get_fwd_recv_ops(mb_index))
                        elif (
                            computation_type == _ComputationType.BACKWARD
                            or computation_type == _ComputationType.WEIGHT
                        ):
                            # A previous rank doing backward or weight update has no
                            # influence on the current rank's forward recv
                            pass
                        else:
                            raise ValueError(
                                f"Unknown computation type {computation_type}"
                            )
                for next_rank in all_next_ranks:
                    next_rank_ops = self.pipeline_order[next_rank]
                    next_rank_action = None
                    if time_step < len(next_rank_ops):
                        next_rank_action = next_rank_ops[time_step]
                    if next_rank_action is not None:
                        computation_type = next_rank_action.computation_type
                        mb_index = next_rank_action.microbatch_index
                        stage_index = next_rank_action.stage_index
                        assert (
                            mb_index is not None
                        ), "All currently supported action types require valid microbatch_index"
                        # Only handle receives for a backward sent from a next rank
                        if (
                            computation_type == _ComputationType.FORWARD
                            or computation_type == _ComputationType.WEIGHT
                        ):
                            # A next rank doing forward or weight update has no
                            # influence on the current rank's backward recv
                            pass
                        elif computation_type == _ComputationType.BACKWARD:
                            # If not the first stage, then receive bwd gradients
                            if stage_index - 1 in stage_index_to_stage:
                                # TODO: We are assuming that stage will always receive from stage+1
                                # however that is not necessarily true of get_bwd_recv_ops
                                stage = stage_index_to_stage[stage_index - 1]
                                ops.extend(stage.get_bwd_recv_ops(mb_index))
                        else:
                            raise ValueError(
                                f"Unknown computation type {computation_type}"
                            )

                # do the communication
                if ops:
                    _batch_p2p(ops).wait()
            except Exception as e:
                logger.error(
                    "[Rank %s] pipeline schedule %s caught the following exception"
                    " at time_step %s when running action %s",
                    self.rank,
                    self.__class__.__name__,
                    time_step,
                    action,
                )
                logger.error("%s", _format_pipeline_order(self.pipeline_order))
                raise e
        # Return losses if there is a container passed in
        self._update_losses(self._stages, losses)


class _PipelineScheduleRuntime(PipelineScheduleMulti):
    """
    Provides a simple runtime that executes a 'schedule IR' including explicitly specified communication operations.

    Can be instantiated directly by creating a _PipelineScheduleRuntime and calling load_csv, or it can be
    subclassed, with the subclass responsible for creating the schedule IR.
    """
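    # A minimal usage sketch (hypothetical names; `stages`, `loss_fn`, and the CSV
    # file are assumed to be provided by the caller):
    #
    #   runtime = _PipelineScheduleRuntime(stages, n_microbatches=8, loss_fn=loss_fn)
    #   runtime._load_csv("schedule.csv", format="compute_only")
    #   runtime.step(*microbatch_args)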

    def _load_actions(
        self,
        actions: Dict[int, List[Optional[_Action]]],
        format: str = "compute_only",
    ):
        """
        Given an in-memory representation for a simple compute-only schedule, lower it to a complex schedule including
        communication actions.  Stores the schedule in self, and must be called before running _step_microbatches().
        """
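        # Illustrative `actions` input (a sketch for two ranks in compute-only
        # format; _Action fields are (stage_index, computation_type, microbatch_index)):
        #
        #   actions = {
        #       0: [_Action(0, F, 0), _Action(0, F, 1), ...],
        #       1: [None, _Action(1, F, 0), ...],
        #   }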
        assert (
            self.stage_index_to_group_rank is not None
        ), "stage_index_to_group_rank is required for PipelineScheduleRuntime"
        self.pipeline_order_with_comms: Dict[int, List[_Action]] = {}
        if format == "compute_comms":
            for rank in actions:
                self.pipeline_order_with_comms[rank] = []
                for action in actions[rank]:
                    assert action is not None
                    self.pipeline_order_with_comms[rank].append(action)
            # TODO what level of validation should we offer for compute+comms schedule?
        elif format == "compute_only":
            # Perform schedule lowering
            for rank in actions:
                self.pipeline_order_with_comms[rank] = _add_unshard_reshard(
                    actions[rank]
                )

            self.pipeline_order_with_comms = _add_send_recv(
                self.pipeline_order_with_comms,
                stage_to_rank=lambda s: self.stage_index_to_group_rank[s],
                num_stages=self._num_stages,
            )
        else:
            raise NotImplementedError(f"{format=} is not implemented")

    def _load_csv(self, filename: str, format: str = "compute_only"):
        """Loads a csv in simple format and then lowers it to include communication actions

        format must be either "compute_only" or "compute_comms".  If compute_only, the lowering passes
        will automatically be run to generate a compute_comms schedule.
        """
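        # Example compute_only CSV contents (format illustration only, not a
        # validated schedule): one row per rank, where an entry like "0F0" parses
        # via _Action.from_str into (stage 0, FORWARD, microbatch 0), e.g.:
        #
        #   0F0,0F1,0B0,0B1
        #   1F0,1F1,1B0,1B1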
        if format == "compute_only":
            # this will populate self.pipeline_order
            super()._load_csv(filename)
            # this will populate self.pipeline_order_with_comms
            self._load_actions(self.pipeline_order)
        elif format == "compute_comms":
            actions = {}
            with open(filename, newline="") as csvfile:
                reader = csv.reader(csvfile)
                for rank, row in enumerate(reader):
                    actions[rank] = [_Action.from_str(s) for s in row]
                self._load_actions(actions, format=format)
        else:
            raise NotImplementedError(f"{format=} is not implemented")

    def _dump_csv(self, filename: str):
        """Dump a CSV representation of the compute + comms schedule into a file with the provided filename."""
        # TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible
        # that it does not exist if it was created from a compute_comms schedule.
        assert (
            self.pipeline_order_with_comms is not None
        ), "Must initialize compute_comms schedule before dump_csv"
        with open(filename, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            for rank in self.pipeline_order_with_comms:
                writer.writerow(self.pipeline_order_with_comms[rank])

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Operate on the microbatches for looped schedules (multiple stages on each rank).

        TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
        not support models with skip connections.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Execute the compute and communication actions for this rank, following
        # the schedule IR loaded into pipeline_order_with_comms
        stage_index_to_stage: Dict[int, _PipelineStageBase] = {
            stage.stage_index: stage for stage in self._stages
        }

        assert (
            self.pipeline_order_with_comms is not None
        ), "Must call _load_actions() before calling _step_microbatches()"

        # recv ops indexed by (stage_idx, mb_idx) need to be waited on before use
        bwd_recv_ops: Dict[Tuple[int, int], Work] = {}
        fwd_recv_ops: Dict[Tuple[int, int], Work] = {}

        # send ops should be waited on before step() exits, mainly for hygiene
        send_ops: List[Work] = []

        # we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stages
        unshard_ops: Dict[int, UnshardHandle] = {}
        unsharded_stages = set()

        def _assert_unsharded(stage_idx: int):
            """If an unshard is active for `stage_idx`, wait() it and mark `stage_idx` as unsharded."""
            if stage_idx in unshard_ops:
                unshard_ops[stage_idx].wait()
                del unshard_ops[stage_idx]
                unsharded_stages.add(stage_idx)
            assert (
                stage_idx in unsharded_stages
            ), f"Attempted to compute on sharded {stage_idx=}"

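        # Expected FSDP lifecycle per stage (a sketch of the schedule IR contract):
        # an UNSHARD action starts an async all-gather, the first compute action
        # (FORWARD/BACKWARD/WEIGHT) on that stage waits on it via _assert_unsharded,
        # and a later RESHARD frees the unsharded parameters again.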
        for time_step, action in enumerate(self.pipeline_order_with_comms[self.rank]):
            try:
                comp_type = action.computation_type
                mb_index: int = (
                    action.microbatch_index
                    if action.microbatch_index is not None
                    else -1
                )
                assert mb_index >= 0 or comp_type in (
                    UNSHARD,
                    RESHARD,
                ), f"{action=} missing mb_index"
                stage_idx = action.stage_index
                stage = stage_index_to_stage[stage_idx]
                stage_uses_fsdp = isinstance(stage.submod, FSDPModule)

                logger.debug(
                    "_PipelineScheduleRuntime running time_step %d, action %s",
                    time_step,
                    action,
                )

                # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections,
                # since we do not want to batch up ops between more than a pair of ranks.  _sorted_batch_p2p would be
                # safe to use instead.
                # However, I was wondering if I should avoid calling batched operators at all in the case that there is
                # only one operator per batch.  I could iterate through the 'fwd_send_ops' one by one and run them.
                if comp_type == SEND_F:
                    send_ops.append(_batch_p2p(stage.get_fwd_send_ops(mb_index)))
                elif comp_type == SEND_B:
                    send_ops.append(_batch_p2p(stage.get_bwd_send_ops(mb_index)))
                elif comp_type == RECV_F:
                    assert (
                        stage_idx,
                        mb_index,
                    ) not in fwd_recv_ops, f"Recv twice for {stage_idx=} {mb_index=} without executing forward"
                    fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
                        stage.get_fwd_recv_ops(mb_index)
                    )
                elif comp_type == RECV_B:
                    assert (
                        stage_idx,
                        mb_index,
                    ) not in bwd_recv_ops, f"Recv twice for {stage_idx=} {mb_index=} without executing backward"
                    bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
                        stage.get_bwd_recv_ops(mb_index)
                    )
                elif comp_type == UNSHARD:
                    if stage_uses_fsdp:
                        assert (
                            stage_idx not in unsharded_stages
                            and stage_idx not in unshard_ops
                        ), f"Unsharding the same {stage_idx=} twice"
                        unshard_ops[stage_idx] = stage.submod.unshard(async_op=True)
                elif comp_type == RESHARD:
                    if stage_uses_fsdp:
                        assert (
                            stage_idx in unsharded_stages
                        ), f"Resharding {stage_idx=} without unsharding"
                        assert (
                            stage_idx not in unshard_ops
                        ), f"Resharding {stage_idx=} before finishing unshard"
                        stage.submod.reshard()
                        # mark the stage as sharded again so a later UNSHARD is valid
                        unsharded_stages.discard(stage_idx)
                elif comp_type == FORWARD:
                    if stage_uses_fsdp:
                        _assert_unsharded(stage_idx)

                    if not stage.is_first:
                        assert (
                            stage_idx,
                            mb_index,
                        ) in fwd_recv_ops, f"Computing {action=} before receiving input"
                        fwd_recv_ops.pop((stage_idx, mb_index)).wait()
                    output = stage.forward_one_chunk(
                        mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index]
                    )
                    self._maybe_compute_loss(stage, output, target_mbs, mb_index)
                elif comp_type == BACKWARD:
                    if stage_uses_fsdp:
                        _assert_unsharded(stage_idx)

                    if not stage.is_last:
                        assert (
                            stage_idx,
                            mb_index,
                        ) in bwd_recv_ops, (
                            f"Attempted to run compute {action=} before receiving input"
                        )
                        bwd_recv_ops.pop((stage_idx, mb_index)).wait()
                    loss = self._maybe_get_loss(stage, mb_index)
                    stage.backward_one_chunk(
                        mb_index, loss=loss, full_backward=self.use_full_backward
                    )
                elif comp_type == WEIGHT:
                    if stage_uses_fsdp:
                        _assert_unsharded(stage_idx)

                    if self.use_full_backward:
                        raise ValueError(
                            "We detected a weight update in the pipeline "
                            f"schedule, but {self.use_full_backward=}"
                        )
                    stage.backward_weight_one_chunk(mb_index)
                else:
                    raise ValueError(f"{action=} is unknown or unsupported")
            except Exception as e:
                logger.error(
                    "_PipelineScheduleRuntime caught exception at step %s when running action %s.  Full Schedule:",
                    time_step,
                    action,
                )
                # TODO(whc) what is the best practice for printing a multiline log?
                # logger will split it into multiple log lines, but this makes it hard to read (too wide)
                print(_format_pipeline_order(self.pipeline_order_with_comms))  # type: ignore[arg-type]
                raise e

        # Mostly these operations should have finished long ago, but there isn't an obvious time to wait for them
        while send_ops:
            send_ops.pop().wait()

        assert len(unshard_ops) == 0, "Unused unshard operations"

        # Return losses if there is a container passed in
        self._update_losses(self._stages, losses)


class ScheduleLoopedBFS(PipelineScheduleMulti):
    """
    Breadth-First Pipeline Parallelism.
    See https://arxiv.org/abs/2211.05953 for details.
    Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank.
    What is different is that when microbatches are ready for multiple local
    stages, Looped BFS prioritizes the earlier stage, running all available
    microbatches at once.
    """
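    # Rough order sketch (a hypothetical run, assuming pp_group_size=2, 2 local
    # stages per rank, and 2 microbatches; rank 0 owns stages 0 and 2, rank 1
    # owns stages 1 and 3; "_" is an idle step, "0F0" = stage 0 FORWARD mb 0):
    #   rank 0: 0F0 0F1 2F0 2F1 _   _   2B1 2B0 0B1 0B0
    #   rank 1: _   1F0 1F1 3F0 3F1 3B1 3B0 1B1 1B0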

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            output_merge_spec=output_merge_spec,
        )

        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(stage_index, computation_type, microbatch_index), ...]
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

    def _calculate_single_rank_operations(self, rank):
        n_local_stages = len(self._stages)
        stage_indices = range(
            rank, self.pp_group_size * n_local_stages, self.pp_group_size
        )

        # Store the list of operations used for that rank
        rank_ops: List[Optional[_Action]] = []
        # Pre-padding, rank starts with no-ops based on the warmup.
        for _ in range(rank):
            rank_ops.append(None)

        for stage_index in stage_indices:
            for mb_index in range(self._n_microbatches):
                rank_ops.append(
                    _Action(stage_index, _ComputationType.FORWARD, mb_index)
                )

        # wait for the first backward to trickle back up,
        # which takes 2 steps for every hop away
        post_warmup_ops = 2 * (self.pp_group_size - 1 - rank)
        rank_ops.extend([None] * post_warmup_ops)

        for stage_index in reversed(stage_indices):
            for mb_index in reversed(range(self._n_microbatches)):
                rank_ops.append(
                    _Action(stage_index, _ComputationType.BACKWARD, mb_index)
                )
        return rank_ops


def _get_1f1b_rank_ops(
    n_local_stages,
    pp_group_size,
    warmup_ops,
    fwd_bwd_ops,
    cooldown_ops,
    rank,
    forward_stage_index,
    backward_stage_index,
    num_1f1b_microbatches=0,
    enable_zero_bubble=False,
):
    # All stages start with handling microbatch 0
    fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
    bwd_stage_mb_index: Dict[int, int] = defaultdict(int)
    weight_stage_mb_index: Dict[int, int] = defaultdict(int)

    # Store the list of operations used for that rank
    rank_ops: List[Optional[_Action]] = []
    # Pre-padding, rank starts with no-ops based on the warmup.
    for _ in range(rank):
        rank_ops.append(None)
    # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
    # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
    # Formula:
    # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward
    # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)
    # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
    # warmup_ops = calculated above
    post_warmup_ops = (
        n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
    ) - (warmup_ops + rank)

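    # Worked example for the non-zero-bubble case: with n_local_stages=2,
    # pp_group_size=4, rank=0, and warmup_ops=10 (see ScheduleInterleaved1F1B),
    # post_warmup_ops = (2 * 4 + 2 * (4 - 1 - 0)) - (10 + 0) = 14 - 10 = 4.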
    if enable_zero_bubble:
        post_warmup_ops = pp_group_size - rank - 1

    total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops

    backward_op_ids = []
    weight_op_count = 0

    for op in range(total_ops):
        # Warmup phase
        if op < warmup_ops:
            fwd_stage_index = forward_stage_index(op)
            # This will assign the current microbatch index and update it as well
            fwd_stage_mb_index[fwd_stage_index] = (
                mb_index := fwd_stage_mb_index[fwd_stage_index]
            ) + 1
            rank_ops.append(
                _Action(fwd_stage_index, _ComputationType.FORWARD, mb_index)
            )
            if op == warmup_ops - 1:
                # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up
                rank_ops.extend([None] * post_warmup_ops)
        # 1F1B Phase (forward and backward)
        elif warmup_ops <= op < warmup_ops + fwd_bwd_ops:
            fwd_stage_index = forward_stage_index(op)
            fwd_stage_mb_index[fwd_stage_index] = (
                fwd_mb_index := fwd_stage_mb_index[fwd_stage_index]
            ) + 1
            rank_ops.append(
                _Action(fwd_stage_index, _ComputationType.FORWARD, fwd_mb_index)
            )
            bwd_stage_index = backward_stage_index(op)
            bwd_stage_mb_index[bwd_stage_index] = (
                bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
            ) + 1
            rank_ops.append(
                _Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
            )
            backward_op_ids.append(op)

            if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
                weight_stage_index = backward_stage_index(
                    backward_op_ids[weight_op_count]
                )
                weight_stage_mb_index[weight_stage_index] = (
                    weight_mb_index := weight_stage_mb_index[weight_stage_index]
                ) + 1
                rank_ops.append(
                    _Action(
                        weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
                    )
                )
                weight_op_count += 1
        # Cooldown phase
        else:
            # During cooldown phase, we need steps to align with 1f1b happening in other ranks
            # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None
            if not enable_zero_bubble:
                rank_ops.append(None)

            bwd_stage_index = backward_stage_index(op)
            bwd_stage_mb_index[bwd_stage_index] = (
                bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
            ) + 1
            rank_ops.append(
                _Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
            )
            backward_op_ids.append(op)

            if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
                weight_stage_index = backward_stage_index(
                    backward_op_ids[weight_op_count]
                )
                weight_stage_mb_index[weight_stage_index] = (
                    weight_mb_index := weight_stage_mb_index[weight_stage_index]
                ) + 1
                rank_ops.append(
                    _Action(
                        weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
                    )
                )
                weight_op_count += 1

    while enable_zero_bubble and weight_op_count < len(backward_op_ids):
        weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count])
        weight_stage_mb_index[weight_stage_index] = (
            weight_mb_index := weight_stage_mb_index[weight_stage_index]
        ) + 1
        rank_ops.append(
            _Action(weight_stage_index, _ComputationType.WEIGHT, weight_mb_index)
        )
        weight_op_count += 1

    return rank_ops


class ScheduleInterleaved1F1B(PipelineScheduleMulti):
    """
    The Interleaved 1F1B schedule.
    See https://arxiv.org/pdf/2104.04473 for details.
    Will perform one forward and one backward on the microbatches in steady
    state and supports multiple stages per rank. When microbatches are ready for
    multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch
    (also called "depth first").
    """
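    # Minimal construction sketch (assumes `stages` holds this rank's local stages
    # and that n_microbatches is a multiple of the pipeline group size):
    #
    #   schedule = ScheduleInterleaved1F1B(stages, n_microbatches=8, loss_fn=loss_fn)
    #   losses: List[torch.Tensor] = []
    #   schedule.step(inputs, target=target, losses=losses)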

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        self.pp_group_size = stages[0].group_size
        # TODO: is this limitation a must?
        if n_microbatches % self.pp_group_size != 0:
            raise ValueError(
                f"Interleaved 1F1B schedule requires the number of microbatches ({n_microbatches}) "
                f"to be a multiple of the number of pipeline ranks ({self.pp_group_size})."
            )

        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )

        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
        self.group = stages[0].group

        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(stage_index, computation_type, microbatch_index), ...]
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}

        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

    def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
        def get_rank_warmup_ops(rank):
            # Warmup operations for the last stage
            warmups_ops_last_stage = (self.n_local_stages - 1) * self.pp_group_size
            # Increment warmup operations by 2 for each hop away from the last stage
            warmup_ops = warmups_ops_last_stage + 2 * ((self.pp_group_size - 1) - rank)
            # We cannot have more warmup operations than there are microbatches, so cap it there
            return min(warmup_ops, self._n_microbatches * self.n_local_stages)

        warmup_ops = get_rank_warmup_ops(rank)
        microbatch_ops = self.n_local_stages * self._n_microbatches
        # fwd_bwd_ops should encompass the remaining forwards
        fwd_bwd_ops = microbatch_ops - warmup_ops
        # cooldown_ops should encompass the remaining backwards
        cooldown_ops = microbatch_ops - fwd_bwd_ops
        # total ops encompass both forward and backward ops
        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
        # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2

        logger.debug(
            "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s, total_ops %s",
            rank,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            total_ops,
        )

        # Calculates the stage index based on step and pp_group_size
        def forward_stage_index(step):
            # Get the local index from 0 to n_local_stages-1
            local_index = (step // self.pp_group_size) % self.n_local_stages
            return (local_index * self.pp_group_size) + rank

        def backward_stage_index(step):
            local_index = (
                self.n_local_stages
                - 1
                - ((step - warmup_ops) // self.pp_group_size) % self.n_local_stages
            )
            return (local_index * self.pp_group_size) + rank
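
        # Mapping sketch: with pp_group_size=4, n_local_stages=2, rank=1,
        # forward_stage_index(0..3) == 1 and forward_stage_index(4..7) == 5, i.e.
        # each local stage runs pp_group_size consecutive forward steps before the
        # schedule moves on to the next local stage.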

        return _get_1f1b_rank_ops(
            self.n_local_stages,
            self.pp_group_size,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            rank,
            forward_stage_index,
            backward_stage_index,
        )


class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
    """
    The Flexible Interleaved 1F1B schedule.

    This schedule is mostly similar to the interleaved 1F1B schedule.
    It differs by relaxing the requirement that num_microbatches % pp_size == 0.
    Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and
    it works as long as n_microbatches % num_rounds is 0. As a few examples:

    1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0.
    2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0.

    When enable_zero_bubble is True, we will use the ZB1P schedule in https://openreview.net/pdf?id=tuzTN0eIO5
    """
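    # Round arithmetic sketch mirroring __init__ below:
    #   number_of_rounds = max(1, n_microbatches // pp_group_size)
    #   microbatches_per_round = n_microbatches // number_of_rounds
    # e.g. pp_group_size=4, n_microbatches=10 -> 2 rounds of 5 microbatches each.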

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
        enable_zero_bubble: bool = False,
    ):
        self.pp_group_size = stages[0].group_size
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            use_full_backward=not enable_zero_bubble,
        )
        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
        self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
        self.microbatches_per_round = n_microbatches // self.number_of_rounds
        self.enable_zero_bubble = enable_zero_bubble
        if n_microbatches % self.number_of_rounds != 0:
            raise ValueError(
                "Flexible Interleaved 1F1B requires the number of microbatches to be a "
                f"multiple of the number of rounds ({self.number_of_rounds}), "
                f"but got {n_microbatches}."
            )
        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(stage_index, computation_type, microbatch_index), ...]
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

        # This function adds bubbles to the generated schedule based on dependencies of actions.
        # Note that the ZB1P schedule does not require bubbles to be added manually;
        # this is only useful when n_microbatches <= microbatches_per_round
        self.pipeline_order = self._add_bubbles_to_actions(
            self.n_local_stages * self.pp_group_size,
        )

    def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
        def get_rank_warmup_ops(rank):
            # Warmup operations for the last stage
            warmups_ops_last_stage = (
                self.n_local_stages - 1
            ) * self.microbatches_per_round
            # Increment warmup operations by the multiply factor (2, or 1 when zero
            # bubble is enabled) for each hop away from the last stage
            multiply_factor = 1 if self.enable_zero_bubble else 2
            warmup_ops = warmups_ops_last_stage + multiply_factor * (
                (self.pp_group_size - 1) - rank
            )

            # We cannot have more warmup operations than there are microbatches, so cap it there
            return min(warmup_ops, self._n_microbatches * self.n_local_stages)

        warmup_ops = get_rank_warmup_ops(rank)
        microbatch_ops = self.n_local_stages * self._n_microbatches
        # fwd_bwd_ops should encompass the remaining forwards
        fwd_bwd_ops = microbatch_ops - warmup_ops
        # cooldown_ops should encompass the remaining backwards
        cooldown_ops = microbatch_ops - fwd_bwd_ops
        # total ops encompass both forward and backward ops
        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
        # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2
        logger.debug(
            "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s, total_ops %s",
            rank,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            total_ops,
        )

        # Calculates the stage index based on step and pp_group_size
        def forward_stage_index(step):
            # Get the local index from 0 to n_local_stages-1
            local_index = (step // self.microbatches_per_round) % self.n_local_stages
            return (local_index * self.pp_group_size) + rank

        def backward_stage_index(step):
            local_index = (
                self.n_local_stages
                - 1
                - ((step - warmup_ops) // self.microbatches_per_round)
                % self.n_local_stages
            )
            return (local_index * self.pp_group_size) + rank

        if self.enable_zero_bubble:
            num_1f1b_microbatches = rank

            return _get_1f1b_rank_ops(
                self.n_local_stages,
                self.pp_group_size,
                warmup_ops,
                fwd_bwd_ops,
                cooldown_ops,
                rank,
                forward_stage_index,
                backward_stage_index,
                num_1f1b_microbatches,
                enable_zero_bubble=True,
            )

        return _get_1f1b_rank_ops(
            self.n_local_stages,
            self.pp_group_size,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            rank,
            forward_stage_index,
            backward_stage_index,
        )

    def _add_bubbles_to_actions(self, num_stages_global):
        actions = self.pipeline_order
        if not self.enable_zero_bubble:
            return actions

        def need_bubble(stage, op, microbatch, num_stages_global, seen_ops):
            if op == _ComputationType.FORWARD:
                if stage != 0 and (stage - 1, op, microbatch) not in seen_ops:
                    return True
            elif op == _ComputationType.BACKWARD:
                if stage == num_stages_global - 1:
                    return (stage, _ComputationType.FORWARD, microbatch) not in seen_ops
                return (stage + 1, op, microbatch) not in seen_ops
            return False

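        # Dependency rules encoded by need_bubble above: a FORWARD on stage s waits
        # for the same microbatch's FORWARD on stage s - 1, and a BACKWARD on stage s
        # waits for the BACKWARD on stage s + 1 (or, on the last stage, for its own
        # FORWARD).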
        seen_ops: Set[Tuple[int, _ComputationType, int]] = set()
        result: Dict[int, List[Optional[_Action]]] = {}
        next_pointer: Dict[int, int] = {}
        bubbles_added: Dict[int, int] = {}
        total_bubbles_added = 0

        for rank in range(self.pp_group_size):
            result[rank] = []
            next_pointer[rank] = 0
            bubbles_added[rank] = 0

        while True:
            should_stop = True

            temp_seen_ops: Set[Tuple[int, _ComputationType, int]] = set()

            for rank in range(self.pp_group_size):
                timestamp = next_pointer[rank]
                if timestamp >= len(actions[rank]):
                    continue

                should_stop = False

                if actions[rank][timestamp] is not None:
                    temp_action = actions[rank][timestamp]
                    assert temp_action is not None
                    stage_index, op, microbatch = temp_action
                    if not need_bubble(
                        stage_index, op, microbatch, num_stages_global, seen_ops
                    ):
                        result[rank].append(actions[rank][timestamp])
                        if microbatch is not None:
                            temp_seen_ops.add((stage_index, op, microbatch))
                        next_pointer[rank] += 1
                    else:
                        result[rank].append(None)
                        bubbles_added[rank] += 1
                        total_bubbles_added += 1
                else:
                    next_pointer[rank] += 1
                    result[rank].append(None)

            seen_ops.update(temp_seen_ops)
            if should_stop:
                break

        if total_bubbles_added > 0:
            logger.warning(
                "Non zero bubbles added: total_bubbles_added=%s bubbles_added=%s",
                total_bubbles_added,
                bubbles_added,
            )
        return result


class ScheduleInterleavedZeroBubble(ScheduleFlexibleInterleaved1F1B):
    """
    The Interleaved Zero Bubble schedule.
    See https://arxiv.org/pdf/2401.10241 for details.
    Will perform one forward and one backward on inputs for the microbatches in steady
    state and supports multiple stages per rank. Uses the backward-for-weights pass to
    fill in the pipeline bubble.
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            enable_zero_bubble=True,
        )


def get_schedule_class(schedule_name: str):
    """
    Maps a schedule name to its corresponding class object.

    Args:
        schedule_name (str): The name of the schedule.
    """
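    # Example: get_schedule_class("GPipe") returns the ScheduleGPipe class (not an
    # instance); the caller then constructs it with stages and n_microbatches.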
    schedule_map = {
        "1F1B": Schedule1F1B,
        "Interleaved1F1B": ScheduleInterleaved1F1B,
        "GPipe": ScheduleGPipe,
        "FlexibleInterleaved1F1B": ScheduleFlexibleInterleaved1F1B,
        "LoopedBFS": ScheduleLoopedBFS,
        "InterleavedZeroBubble": ScheduleInterleavedZeroBubble,
        "PipelineScheduleSingle": PipelineScheduleSingle,
        "PipelineScheduleMulti": PipelineScheduleMulti,
    }
    if schedule_name not in schedule_map:
        raise ValueError(f"Unknown schedule name: {schedule_name}")
    return schedule_map[schedule_name]
