# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import copy
import logging
import operator
from collections import defaultdict
from enum import Enum
from inspect import Parameter, Signature, signature
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.fx as fx
from torch.distributed import ProcessGroup
from torch.export import ExportedProgram
from torch.export.unflatten import (
    _assign_attr,
    _AttrKind,
    _sink_params,
    InterpreterModule,
)
from torch.fx.node import map_aggregate
from torch.fx.passes.split_module import split_module

from ._backward import _null_coalesce_accumulate, stage_backward
from ._unflatten import _outline_submodules
from ._utils import PipeInfo
from .stage import _PipelineStage


logger = logging.getLogger(__name__)

# TODO:
# 1. investigate gradient sync for shared parameters. how does DDP do it?
# 2. Add parameter movement to split_module


def _find_loss_from_output_and_spec(output_val, spec_val):
    if spec_val is False:
        return None
    if spec_val is True:
        if not isinstance(output_val, fx.Node):
            raise RuntimeError(
                f"Loss spec must specify a dynamic value but got {output_val}"
            )
        return output_val

    if isinstance(spec_val, (tuple, list)):
        if not isinstance(output_val, (tuple, list)):
            raise RuntimeError(
                f"Output value {output_val} must match type of loss specification "
                f"{spec_val}"
            )
        if len(output_val) != len(spec_val):
            raise RuntimeError(
                f"Output value {output_val} must match length of loss specification "
                f"{spec_val}"
            )
        for out, spec in zip(output_val, spec_val):
            loss_val = _find_loss_from_output_and_spec(out, spec)
            if loss_val is not None:
                return loss_val
        raise RuntimeError(f"Did not find loss value in specification {spec_val}")

    if isinstance(spec_val, dict):
        if not isinstance(output_val, dict):
            raise RuntimeError(
                f"Output value {output_val} must match type of loss specification "
                f"{spec_val}"
            )
        if set(output_val.keys()) != set(spec_val.keys()):
            raise RuntimeError(
                f"Output value {output_val} must match keys of loss specification "
                f"{spec_val}"
            )
        for k in spec_val:
            loss_val = _find_loss_from_output_and_spec(output_val[k], spec_val[k])
            if loss_val is not None:
                return loss_val
        raise RuntimeError(f"Did not find loss value in specification {spec_val}")

    raise RuntimeError(f"Unsupported type {type(spec_val)} in loss specification")


def _find_loss_output(mod: torch.nn.Module, g: fx.Graph, output_loss_value_spec):
    output_nodes = [n for n in g.nodes if n.op == "output"]
    assert len(output_nodes) == 1
    output_node = output_nodes[0]
    output_val = output_node.args[0]
    generated_spec: Any = None

    if isinstance(mod, TrivialLossWrapper):
        # TrivialLossWrapper is pre-defined by PiPPy.
        # It has loss as the only output so we can safely assume the first output arg is the loss.
        assert len(output_node.args) == 1
        loss_node = output_val
        generated_spec = TrivialLossWrapper.loss_spec
    elif output_loss_value_spec is None:
        # Use default spec, i.e. search for "loss" in output values
        if isinstance(output_val, dict) and "loss" in output_val.keys():
            loss_node = output_val["loss"]
            generated_spec = {k: k == "loss" for k in output_val}
        else:
            loss_node = None
            generated_spec = None
    else:
        loss_node = _find_loss_from_output_and_spec(output_val, output_loss_value_spec)
        generated_spec = output_loss_value_spec

    return loss_node, output_node, generated_spec


def _insert_stage_symbolic_backward(
    g: fx.Graph,
    loss_node: fx.Node,
    output_node: fx.Node,
):
    # Collect metadata about tuple output values. TODO: move this to split_module or FX IR
    tuples: Dict[fx.Node, Tuple] = {}
    for node in reversed(g.nodes):
        if node.op == "call_function":
            # In the forward pass, we only emit placeholders, module calls, and
            # getitem calls. If we see any other call_function target in this
            # (forward-only) code, there is a bug.
            assert node.target == operator.getitem, (
                "Found non-getitem call in forward pass. "
                "Please report a bug to PiPPy"
            )
            assert (
                len(node.args) == 2
            ), "Found malformed getitem call. Please report a bug to PiPPy"
            indexed_value, node_idx = tuple(node.args)

            # indexed_value is a collection that we are indexing into. It could
            # exist in the tuples map if we've processed another `getitem`
            # already.
            existing_list_size = (
                len(tuples[indexed_value]) if indexed_value in tuples else -1
            )
            new_list_size = max(node_idx + 1, existing_list_size)

            reconstructed_list = [None for _ in range(new_list_size)]

            # Copy over existing elements if present
            if indexed_value in tuples:
                for i, val in enumerate(tuples[indexed_value]):
                    reconstructed_list[i] = val

            # Populate value represented by this node
            reconstructed_list[node_idx] = node

            tuples[indexed_value] = tuple(reconstructed_list)

    # Keep track of nodes that dominate the loss node.
    # We will only emit backward operations for nodes that can contribute
    # to the specified loss value.
    live_nodes = {loss_node: None}
    val_to_grad: Dict[fx.Node, Optional[fx.Node]] = {loss_node: None}

    def assign_or_accumulate_grad(forward_node, grad_value):
        if forward_node in val_to_grad and forward_node.op != "placeholder":
            grad_value = g.call_function(
                _null_coalesce_accumulate,
                (val_to_grad[forward_node], grad_value),
            )
        val_to_grad[forward_node] = grad_value

    with g.inserting_before(output_node):
        for node in reversed(g.nodes):
            if node not in live_nodes:
                continue

            def add_to_live_nodes(n):
                live_nodes.setdefault(n, None)

            fx.node.map_arg(node.args, add_to_live_nodes)
            fx.node.map_arg(node.kwargs, add_to_live_nodes)
            if node.op == "call_module":
                output_grads: Union[Tuple[Optional[fx.Node], ...], Optional[fx.Node]]
                if node in tuples:
                    stage_output = tuples[node]
                    output_grads = tuple(val_to_grad.get(n, None) for n in tuples[node])
                    outputs_with_grads_idxs = [
                        i for i, n in enumerate(tuples[node]) if n in live_nodes
                    ]
                else:
                    stage_output = (node,)
                    output_grads = val_to_grad[node]
                    outputs_with_grads_idxs = [0]

                output_grads = (
                    (output_grads,)
                    if not isinstance(output_grads, tuple)
                    else output_grads
                )

                grad_call = g.call_function(
                    stage_backward,
                    kwargs={
                        "stage_output": stage_output,
                        "output_grads": output_grads,
                        "input_values": list(node.all_input_nodes),
                        "outputs_with_grads_idxs": outputs_with_grads_idxs,
                    },
                )
                # Insert backward stage debug info
                kwargs_copy = dict(grad_call.kwargs)
                grad_call.kwargs = kwargs_copy

                grad_call_proxy = fx.Proxy(grad_call)
                grads = grad_call_proxy.node

                input_nodes = list(node.all_input_nodes)
                grads_proxy = fx.Proxy(grads)
                for i, input_node in enumerate(input_nodes):
                    assign_or_accumulate_grad(input_node, grads_proxy[i].node)  # type: ignore[index]

    return g


class PipeSequential(torch.nn.Sequential):
    @staticmethod
    def from_sequential(sequential_instance: torch.nn.Sequential):
        return PipeSequential(*[copy.copy(m) for m in sequential_instance])

    def forward(self, input):
        for i, module in enumerate(self):
            input = module(input)
            if i != len(self) - 1:
                pipe_split()
        return input


class LossWrapper(torch.nn.Module):
    """
    LossWrapper is a convenient abstract class that allows you to wrap up both
    your model as well as its loss function and specify the connectivity between
    the inputs, model, loss function, and output value. Example::

        class MyModelWrapper(LossWrapper):
            def forward(self, x, targets):
                model_out = self.module(x)
                loss_value = self.loss_fn(model_out, targets)
                return loss_value

    The above example defines a connectivity where we expect the forward/loss/backward
    training procedure to take two arguments (x and targets), pass x into the module
    to get the output of the feedforward computation, pass the model output and the
    targets value into the loss function, and get and return the loss value, which will
    be backpropagated by PiPPy. The above class would then be instantiated like::

        model = ... # instantiate the model
        loss_fn = torch.nn.MSELoss() # for the sake of demonstration

        wrapper = MyModelWrapper(model, loss_fn)
        pipe = Pipe.from_tracing(wrapper, ...)

    """

    def __init__(self, module, loss_fn):
        super().__init__()
        self.module = module
        self.loss_fn = loss_fn

    def forward(self, *args, **kwargs):
        raise NotImplementedError(
            "This instance of LossWrapper does not have an overridden "
            "forward(). Please implement forward() to specify the arguments, "
            "connection between the module and loss, and loss output "
            "value."
        )


class TrivialLossWrapper(LossWrapper):
    def forward(self, x, targets):
        model_out = self.module(x)
        return self.loss_fn(model_out, targets)

    loss_spec = True
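

# A minimal usage sketch (illustrative only; the helper below is hypothetical and
# not referenced elsewhere in this file): `TrivialLossWrapper` packages a model
# and a loss function so that the wrapper's single return value is the loss that
# Pipe should backpropagate (`loss_spec = True` means "the entire output is the
# loss").
def _example_trivial_loss_wrapper() -> torch.Tensor:
    model = torch.nn.Linear(8, 2)
    wrapper = TrivialLossWrapper(model, torch.nn.MSELoss())
    x, targets = torch.randn(4, 8), torch.randn(4, 2)
    # Running the wrapper eagerly returns a scalar loss tensor; under
    # `Pipe.from_tracing`, this value is what the generated backward stages use.
    return wrapper(x, targets)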


# Pipe model representation
#
# Pipe can be thought of as an `nn.Sequential++`. That is to say: it specifies
# a single topological ordering of pipeline "stages" that, when run in series,
# constitutes all of the operations of the program. However, unlike `nn.Sequential`,
# Pipe allows non-local usages of values, so long as those uses still respect
# topological ordering. In particular:
#
# 1. Non-local activations. This type of usage can appear in, for example, skip
#    connections. These values will be directly transmitted from the "def" stage
#    to all stages that use them, skipping intermediate stages. During autograd,
#    gradients will be propagated back through this skip connection in the
#    reverse of how activations were propagated in the forward pass (see the
#    sketch below).
# 2. Non-local parameter/module invocations. This occurs when a parameter is used
#    in a stage downstream of where it is resident. These values can be carried
#    forward similarly to (1), but in addition one might want to replicate the
#    value on multiple stages. Gradients for these shared parameters will be
#    accumulated separately on each stage, but there will be an additional
#    gradient accumulation before the optimizer step.
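

# A minimal sketch of case (1) above (illustrative only; this class is
# hypothetical and not used by the library): `skip` is produced in the first
# stage but consumed again after the split, so Pipe forwards it across the
# stage boundary and routes its gradient back along the same edge.
class _ExampleSkipConnection(torch.nn.Module):
    def __init__(self, d: int = 16):
        super().__init__()
        self.lin0 = torch.nn.Linear(d, d)
        self.lin1 = torch.nn.Linear(d, d)

    def forward(self, x):
        skip = self.lin0(x)  # defined in stage 0
        y = torch.relu(skip)
        pipe_split()  # stage boundary
        return self.lin1(y) + skip  # non-local use of `skip` in stage 1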


# Register `_pipe_split()` as a custom operator. This is required for Export to
# preserve this marker in the graph.
torch.library.define("pippy::_pipe_split", "() -> ()")


@torch.library.impl("pippy::_pipe_split", "BackendSelect")
def _pipe_split():
    return None


@torch.library.register_fake("pippy::_pipe_split")  # type: ignore[no-redef]
def _pipe_split():  # noqa: F811
    return None


# Add an alias for convenience
aten_pipe_split_alias = torch.ops.pippy._pipe_split.default

# Ask Export to preserve the `_pipe_split` op.
# See examples in pytorch/torch/fx/node.py
fx.node._side_effectful_functions.add(aten_pipe_split_alias)


# User facing API
def pipe_split():
    """
    pipe_split is a special operator that marks the boundary between stages in a
    module; the tracer splits the module into stages at each such call. It is a
    no-op if your annotated module is run eagerly.

    Example:
        >>> # xdoctest: +SKIP
        >>> def forward(self, x):
        >>>     x = torch.mm(x, self.mm_param)
        >>>     x = torch.relu(x)
        >>>     pipe_split()
        >>>     x = self.lin(x)
        >>>     return x

    The above example will be split into two stages.
    """
    return torch.ops.pippy._pipe_split()


class MultiUseParameterConfig(Enum):
    TRANSMIT = 1
    REPLICATE = 2


MultiUseParamSpec = Union[MultiUseParameterConfig, Dict[str, MultiUseParameterConfig]]
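

# An illustrative spec (the qualified names below are hypothetical): a
# per-parameter mapping may mix strategies, e.g. replicate a shared embedding
# weight on every stage that uses it while transmitting another parameter from
# its owning stage.
_example_multi_use_spec: MultiUseParamSpec = {
    "embed.weight": MultiUseParameterConfig.REPLICATE,
    "head.bias": MultiUseParameterConfig.TRANSMIT,
}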


class DetachExecutor(fx.Interpreter):
    """
    Special interpreter to run the split_gm in testing that detaches all inputs to
    a module invocation. This is needed so that the values at the stage boundary
    are leaf tensors in autograd execution.
    """

    def __init__(self, module, garbage_collect_values=True):
        garbage_collect_values = False
        super().__init__(module, garbage_collect_values)
        self.value_remap = {}

    def run(self, *args, initial_env=None):
        self.value_remap = {}
        return super().run(*args, initial_env=initial_env)

    def call_module(self, target, args, kwargs):
        def detach_tensors(a):
            if isinstance(a, torch.Tensor) and a.requires_grad:
                if a not in self.value_remap:
                    new_val = a.detach().requires_grad_(True)
                    self.value_remap[a] = new_val
                return self.value_remap[a]
            else:
                return a

        """
        def dont_traverse_size(a):
            return type(a) != torch.Size
        """

        args = map_aggregate(
            args,
            detach_tensors,  # dont_traverse_size
        )
        kwargs = map_aggregate(
            kwargs,
            detach_tensors,  # dont_traverse_size
        )

        return super().call_module(target, args, kwargs)

    def call_function(self, target, args, kwargs):
        # HACK to reroute saved input tensors to point to the detach()ed version
        if target == stage_backward:
            kwargs = dict(kwargs)
            kwargs["input_values"] = [
                self.value_remap.get(v, v) for v in kwargs["input_values"]
            ]
        return super().call_function(target, args, kwargs)
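

# A minimal illustration (hypothetical helper, not used by this file) of the
# trick DetachExecutor relies on: detaching a tensor and re-enabling grad makes
# it an autograd leaf, so each stage's backward stops at its input boundary and
# gradients for earlier stages are injected explicitly via `stage_backward`.
def _example_detach_boundary(t: torch.Tensor) -> torch.Tensor:
    boundary = t.detach().requires_grad_(True)
    assert boundary.is_leaf  # gradients now accumulate in `boundary.grad`
    return boundary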


class _NodeReference:
    def __init__(self, name):
        self.name = name

    name: str


class _LinearNodeList:
    def __init__(self, node_list):
        self.serialize_node_list = []
        for node in node_list:
            node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name))  # type: ignore[arg-type,return-value]
            node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name))  # type: ignore[arg-type,return-value]
            serialize_node = fx.Node(
                graph=None,  # type: ignore[arg-type]
                name=node.name,
                op=node.op,
                target=node.target,
                args=node_args,  # type: ignore[arg-type]
                kwargs=node_kwargs,  # type: ignore[arg-type]
                return_type=node.type,
            )
            serialize_node.meta = copy.copy(node.meta)
            self.serialize_node_list.append(serialize_node)

    def to_graph(self):
        graph = fx.Graph()

        ref_str_to_node: Dict[str, fx.Node] = {}

        def ref_to_node(arg):
            if isinstance(arg, _NodeReference):
                return ref_str_to_node[arg.name]
            else:
                return arg

        for node in self.serialize_node_list:
            node_args = map_aggregate(node.args, ref_to_node)
            node_kwargs = map_aggregate(node.kwargs, ref_to_node)
            deser_node = graph.create_node(
                op=node.op,
                target=node.target,
                args=node_args,  # type: ignore[arg-type]
                kwargs=node_kwargs,  # type: ignore[arg-type]
                name=node.name,
                type_expr=node.type,
            )
            ref_str_to_node[node.name] = deser_node

        return graph


def _direct_serialization_deserialize(body, nodes):
    """
    Custom `__reduce__` method for serialization.
    DO AS I SAY -- NOT AS I DO. This violates the principle that
    GraphModules serialize via code export & re-tracing. We allow
    for this here because **PIPE STAGES SHOULD NOT BE PERSISTED
    TO DISK -- THIS IS ONLY FOR TRANSMISSION VIA RPC**. Persisting
    these instances to disk will expose internal implementation
    details of `fx.Graph` and related data structures and is
    NOT advised.
    """

    class DummyModule(torch.nn.Module):
        def __init__(self, body):
            super().__init__()
            self.__dict__.update(body)

    dummy = DummyModule(body)

    return fx.GraphModule(dummy, nodes.to_graph())


def _direct_serialization_reduce(self):
    serialization_dict = dict(self.__dict__)
    serialization_dict.pop("_graph")
    return (
        _direct_serialization_deserialize,
        (serialization_dict, _LinearNodeList(self.graph.nodes)),
    )


def _modify_graph_op_device(
    gm: torch.fx.GraphModule,
    new_device: torch.device,
):
    """
    Modify the device argument of all "call_function" nodes in the graph. This
    is useful for moving the graph to a different device, in particular for
    generator ops like torch.ones.
    """
    modified = False
    for node in gm.graph.nodes:
        if node.op == "call_function":
            if "device" in node.kwargs and node.kwargs["device"] != new_device:
                logger.debug(
                    f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}"  # noqa: G004
                )
                node.update_kwarg("device", new_device)
                modified = True
        elif node.op == "call_module":
            # Recursively modify "device" in submodules
            submod = gm.get_submodule(node.target)
            if isinstance(submod, torch.fx.GraphModule):
                _modify_graph_op_device(submod, new_device)
            elif isinstance(submod, InterpreterModule):
                # If unflattening has been performed, we need to access its graph module by `.graph_module`
                _modify_graph_op_device(submod.graph_module, new_device)
            else:
                logger.warning(
                    f"Skipping device modification for submodule {node.target} because it is a {type(submod)}"  # noqa: G004
                )

    if modified:
        gm.recompile()
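

# A small sketch of the kind of graph this pass rewrites (illustrative only; the
# helper below is hypothetical). The hand-built graph stands in for what tracing
# would produce: a `call_function` node for `torch.ones` with a baked-in
# "device" kwarg, which `_modify_graph_op_device` rewrites before recompiling.
def _example_retarget_device(new_device: torch.device) -> fx.GraphModule:
    g = fx.Graph()
    x = g.placeholder("x")
    ones = g.call_function(torch.ones, args=(4,), kwargs={"device": torch.device("cpu")})
    out = g.call_function(operator.add, args=(x, ones))
    g.output(out)
    gm = fx.GraphModule(torch.nn.Module(), g)
    _modify_graph_op_device(gm, new_device)  # the "device" kwarg now points at `new_device`
    return gm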


class Pipe(torch.nn.Module):
    def __init__(
        self,
        split_gm: fx.GraphModule,
        num_stages: int,
        has_loss_and_backward: bool,
        loss_spec,
    ):
        # TODO: is there a way not to hard wire init?
        torch.nn.Module.__init__(self)
        self.split_gm: fx.GraphModule = split_gm
        self.executor: DetachExecutor = DetachExecutor(self.split_gm)
        self.num_stages: int = num_stages
        self.has_loss_and_backward = has_loss_and_backward
        self.loss_spec = loss_spec

        for node in split_gm.graph.nodes:
            assert (
                node.op in {"call_module", "placeholder", "output"}
                or (node.op, node.target) == ("call_function", operator.getitem)
                or (node.op, node.target) == ("call_method", "backward")
                or (node.op, node.target) == ("call_function", stage_backward)
                or (node.op, node.target)
                == ("call_function", _null_coalesce_accumulate)
            ), node

        # Detect replicated parameters so we know that we have to do an additional allreduce
        # before applying the optimizer
        #
        # Note that this also handles the case where there were multiple calls to a single
        # module from different stages, regardless of whether that module invocation
        # was handled by the logic above.

        # Map parameter value to a dictionary that maps the user pipeline module
        # to the local qualname within that module
        params_to_users: Dict[torch.nn.Parameter, Dict[str, str]] = {}

        for m_qualname, mod in self.split_gm.named_children():
            for p_qualname, param in mod.named_parameters():
                params_to_users.setdefault(param, {})
                params_to_users[param][m_qualname] = p_qualname

        self.replicated_params: List[Dict[str, str]] = [
            use_mapping
            for _, use_mapping in params_to_users.items()
            if len(use_mapping) > 1
        ]

        # We must break the aliasing relationship between the replicated parameters for correct
        # numerics in reference runs. If we do not do this, the autograd tape in separate stages
        # will have a reference to the same tensor value and will erroneously apply gradient
        # updates multiple times. Therefore, for each replicated parameter set, we deepcopy the
        # values so that we have separate instances.
        for param_mapping in self.replicated_params:
            for submod_name, param_qualname in param_mapping.items():
                submod = getattr(self.split_gm, submod_name)
                atoms = param_qualname.split(".")
                for atom in atoms[:-1]:
                    submod = getattr(submod, atom)
                setattr(submod, atoms[-1], copy.deepcopy(getattr(submod, atoms[-1])))

        def throw(self, *args, **kwargs):
            raise RuntimeError(
                "To run pipeline locally, invoke the Pipe object directly, not `split_gm`"
            )

        self.split_gm.forward = throw

        # Make submodules use custom direct-serialized GraphModule
        i = 0
        while True:
            try:
                name = f"submod_{i}"
                submod = getattr(self.split_gm, name)
                submod.__class__.__reduce__ = _direct_serialization_reduce
                i += 1
            except AttributeError:
                break

    def forward(self, *args, **kwargs):
        executor_args = args
        if len(kwargs) > 0:
            parameters = []
            for node in self.split_gm.graph.nodes:
                if node.op == "placeholder":
                    if node.args and len(node.args) > 0:
                        parameters.append(
                            Parameter(
                                node.target,
                                Parameter.POSITIONAL_OR_KEYWORD,
                                default=node.args[0],
                            )
                        )
                    else:
                        parameter_kind = Parameter.POSITIONAL_OR_KEYWORD
                        param_name = node.target
                        if node.target.startswith("**"):
                            parameter_kind = Parameter.VAR_KEYWORD  # type: ignore[assignment]
                            param_name = param_name[2:]
                        elif node.target.startswith("*"):
                            parameter_kind = Parameter.VAR_POSITIONAL  # type: ignore[assignment]
                            param_name = param_name[1:]
                        parameters.append(Parameter(param_name, parameter_kind))
            signature = Signature(parameters)
            ba = signature.bind(*args, **kwargs)
            ba.apply_defaults()
            executor_args = ba.arguments.values()  # type: ignore[assignment]

        res = self.executor.run(*executor_args)

        return res

    def get_stage_module(self, stage_idx: int) -> torch.nn.Module:
        """
        Return a stage module corresponding to `stage_idx` of the `pipe`.
        """
        if stage_idx < 0 or stage_idx >= self.num_stages:
            raise ValueError(f"Invalid stage index {stage_idx}!")
        return getattr(self.split_gm, f"submod_{stage_idx}")

    @staticmethod
    def _number_and_count_forward_stages(gm: fx.GraphModule):
        num_stages = 0
        found_idxs: Dict[int, None] = {}
        for node in gm.graph.nodes:
            if node.op == "call_module" and node.target.startswith("submod_"):
                node.meta["stage_idx"] = int(node.target[len("submod_") :])
                found_idxs.setdefault(node.meta["stage_idx"])
                num_stages += 1

        # this assert will fail if a split point is inserted before the first layer, which creates an empty first submodule
        # Update: the following assert may fail against some torch versions >=
        # 2.2.0, as:
        # submod_0, submod_1, submod_2, ...
        # may be named as
        # submod_0, submod_2, submod_4, ...
        # TODO: investigate
        # assert all(i in found_idxs for i in range(num_stages))

        return num_stages

    @staticmethod
    def _from_traced(
        mod: torch.nn.Module,
        exported_program: ExportedProgram,
        multi_use_param_spec: Optional[MultiUseParamSpec] = None,
        output_loss_value_spec=None,
        split_policy: Optional[
            Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
        ] = None,
    ):
        """
        The ``output_loss_value_spec`` value can be specified to disambiguate
        which value in the output of `forward` is the loss value on which PiPPy should apply
        backpropagation. For example, if your ``forward`` returns a tuple ``(loss, model_out)``,
        you can specify ``output_loss_value_spec=(True, False)``. Or, if your ``forward`` returns
        a dict ``{'loss': loss_value, 'model_out': model_out}``, you can specify
        ``output_loss_value_spec={'loss': True, 'model_out': False}``.
        """

        traced = exported_program.module()

        if split_policy is not None:
            logger.info("Auto-splitting model")
            traced = split_policy(traced)  # type: ignore[arg-type]

        logger.debug(traced.print_readable(print_output=False))

        # Deduplicate `get_attr` nodes that refer to the same parameter. Downstream code for moving
        # parameters relies on the invariant that parameter accesses happen once. This is not necessarily
        # the case (especially with custom tracers), so fix that up here.
        get_attr_nodes: Dict[str, fx.Node] = {}
        for node in traced.graph.nodes:
            if node.op == "get_attr":
                get_attr_nodes.setdefault(node.target, node)

                if get_attr_nodes[node.target] != node:
                    node.replace_all_uses_with(get_attr_nodes[node.target])
                    traced.graph.erase_node(node)

        # Avoid looking at the next node by keeping track of the previous pipe_split
        prev_pipe_split_idx = -1
        pipe_split_nodes_to_erase = set()
        for i, node in enumerate(traced.graph.nodes):
            if (node.op, node.target) == ("call_function", pipe_split):
                if prev_pipe_split_idx == i - 1:
                    pipe_split_nodes_to_erase.add(node)
                prev_pipe_split_idx = i

        for node in pipe_split_nodes_to_erase:
            traced.graph.erase_node(node)

        traced.recompile()

        part_idx = 0

        def split_callback(n: fx.Node):
            nonlocal part_idx
            if (n.op, n.target) == (
                "call_function",
                aten_pipe_split_alias,
            ):
                logger.debug(f"Found pipe_split {part_idx}")  # noqa: G004
                part_idx += 1
            return part_idx

        # TODO: what does split do with module invocations? does it move the modules
        # into the submodules?
        split = split_module(traced, mod, split_callback)  # type: ignore[arg-type]
        # a (custom) tracer can produce dead code like orphan get_attr nodes
        split.graph.eliminate_dead_code()

        # peephole to remove pipe_split
        for submodule in split.modules():
            if isinstance(submodule, fx.GraphModule):
                for node in submodule.graph.nodes:
                    if (node.op, node.target) == (
                        "call_function",
                        aten_pipe_split_alias,
                    ):
                        submodule.graph.erase_node(node)
                submodule.recompile()

        for name, submodule in split.named_children():
            if isinstance(submodule, fx.GraphModule):
                new_submod = _outline_submodules(submodule.graph)
                # Replace old submod
                split.register_module(name, new_submod)

        # TODO: backport this into split_module
        def delete_user_reference(node, user):
            """
            Delete reference of `node` from `user`'s arg list.
            Args:
                - node: a `get_attr` node at root.
                - user: a submodule node that uses `node`.
            """
            assert len(user.kwargs) == 0
            use_idxs = [i for i, arg in enumerate(user.args) if arg == node]
            assert len(use_idxs) == 1
            args_copy = list(user.args)
            args_copy.pop(use_idxs[0])
            user.args = tuple(args_copy)
            logger.debug(
                f"Deleted {node} from user {user}, arg index = {use_idxs[0]}"  # noqa: G004
            )

        # A list of param referrals for deferred deletion.
        # To be accumulated in `move_param_to_callee`.
        to_delete = []

        def _recursive_getattr_with_parent(mod, fqn):
            # Return the parent module and the attribute for a nested FQN
            # (None entries where the lookup fails)
            atoms = fqn.split(".")
            for atom in atoms[:-1]:
                if not hasattr(mod, atom):
                    return None, None
                mod = getattr(mod, atom)
            if not hasattr(mod, atoms[-1]):
                return mod, None
            attr = getattr(mod, atoms[-1])
            return mod, attr

        def move_param_to_callee(
            root,
            callee_name,
            param_fqn,
        ):
            """
            Move a parameter from the root module to a submodule.
            Args:
                root: The root module.
                callee_name: The name of the submodule to move the parameter to.
                param_fqn: The fully qualified name of the parameter to move.
            """
            # `atoms` is a list of strings representing the path to the
            # parameter in the original model
            atoms = param_fqn.split(".")
            mod_itr, param_val = _recursive_getattr_with_parent(split, param_fqn)
            # Check whether the attribute is a buffer or a parameter
            is_buffer = atoms[-1] in mod_itr._buffers

            # Check whether the attribute is a tensor
            assert isinstance(param_val, torch.Tensor), (
                f"Expected '{param_fqn}' to be {torch.Tensor} but got {type(param_val)}."
                + (
                    f" It might happen if module '{param_fqn}' was passed to some 'leaf function' "
                    f"(see https://pytorch.org/docs/stable/fx.html#fx.wrap). Please inspect "
                    f"usages of '{param_fqn}' in the traced graph."
                    if isinstance(param_val, torch.nn.Module)
                    else ""
                )
            )

            # Get submodule
            callee = root.get_submodule(callee_name)
            assert not hasattr(
                callee, param_fqn
            ), f"Module {callee_name} already has a parameter named {param_fqn}"

            # Assign the parameter to the submodule
            if is_buffer:
                _assign_attr(
                    param_val,
                    callee,
                    param_fqn,
                    attr_kind=_AttrKind.BUFFER,
                    persistent=True,  # TODO: handle non-persistent buffer
                )
            else:
                _assign_attr(
                    param_val,
                    callee,
                    param_fqn,
                    attr_kind=_AttrKind.PARAMETER,
                )
            logger.debug(f"Moved parameter {param_fqn} to {callee_name}")  # noqa: G004

            # Next step is to replace placeholder of submodule with a get_attr.
            # Those placeholders are created by `split_module` inside each
            # submodule.
            # Update: this step is now moved to `_sink_params` because
            # `_sink_params` can do it recursively (i.e. for modules inside
            # submodule)

            to_delete.append((mod_itr, atoms[-1]))

        # Get the list of all `get_attr` nodes in the root graph
        attr_nodes = list(filter(lambda n: n.op == "get_attr", split.graph.nodes))
        for node in attr_nodes:
            # Check whether the parameter is used in only one submodule
            if len(node.users) > 1:
                logger.info(
                    f"Parameter {node.target} used in multiple stages: {node.users}."  # noqa: G004
                )
            for user in node.users:
                assert user.op == "call_module"
                # Move parameter into submodule
                move_param_to_callee(
                    split,
                    user.target,
                    node.target,
                )

        # [aliasing] store tensor id -> list of FQNs, built from state dict
        # Also assign non-persistent buffers
        id_to_fqns: Dict[int, Set[str]] = defaultdict(set)
        for fqn, tensor in mod.state_dict(keep_vars=True).items():
            id_to_fqns[id(tensor)].add(fqn)
        for fqn, tensor in mod.named_buffers():
            id_to_fqns[id(tensor)].add(fqn)

        # After moving the params to their corresponding hierarchies, we also
        # need to move the `get_attr` nodes from the root of the graph to those
        # hierarchies.
        # [aliasing] use id -> fqn mapping to list out all valid FQNs
        inputs_to_state: Dict[str, List[str]] = {}
        for attr in attr_nodes:
            _, tensor = _recursive_getattr_with_parent(mod, attr.target)
            fqns = list(id_to_fqns[id(tensor)])
            if fqns:
                inputs_to_state[attr.name] = fqns
            elif attr.target in exported_program.constants:  # lifted constants
                inputs_to_state[attr.name] = [attr.target]

        # [aliasing] for each submodule split, assign attributes on FQNs that may be used.
        # We determine this based on whether or not the FQN attribute parent exists.
        # i.e. if the last submodule exists, assign the attribute.
        added_attributes: Dict[str, List[str]] = defaultdict(list)
        for fqn, tensor in mod.state_dict(keep_vars=True).items():
            for name, submod in split.named_children():
                if isinstance(submod, fx.GraphModule):
                    parent, child = _recursive_getattr_with_parent(submod, fqn)
                    if (
                        parent and child is None
                    ):  # parent exists, attribute doesn't -> assign
                        added_attributes[name].append(fqn)
                        setattr(parent, fqn.split(".")[-1], tensor)

        # Deferred deletion: remove the original attributes (to params) from the
        # root GraphModule
        for mod_itr, last_atom in to_delete:
            try:
                delattr(mod_itr, last_atom)
            except AttributeError:
                # This is expected if the parameter is used in multiple stages
                pass

        # This is done by (1) `_sink_params` at each submodule;
        for name, submod in split.named_children():
            if isinstance(submod, fx.GraphModule):
                _sink_params(submod, inputs_to_state, [])
                submod.graph.lint()
                submod.recompile()

        # [aliasing] This step is not strictly necessary, but helps reduce parameter usage/memory.
        # After the _sink_params() routine has run, clean up unused attributes that we previously added.
        # Determine this based on the get_attr nodes - if not used, remove it.
        for name, attributes in added_attributes.items():
            submod = getattr(split, name)
            unused_attributes = set(attributes)
            # track used attributes in the submodule, running DFS on subgraph hierarchy
            stack = [("", submod)]  # (scope, submodule)
            while stack:
                scope, _mod = stack.pop()
                if isinstance(_mod, (fx.GraphModule, InterpreterModule)):
                    for node in _mod.graph.nodes:
                        if node.op == "get_attr":
                            # get_attr might access a deeper-level attribute
                            fqn = scope + "." + node.target if scope else node.target
                            if fqn in unused_attributes:  # used; drop from unused set
                                unused_attributes.remove(fqn)
                for _name, _submod in _mod.named_children():
                    stack.append((scope + "." + _name if scope else _name, _submod))
            # delete unused attributes
            for attr in unused_attributes:
                mod_itr, atoms = submod, attr.split(".")
                for atom in atoms[:-1]:
                    mod_itr = getattr(mod_itr, atom)
                delattr(mod_itr, atoms[-1])

        for node in attr_nodes:
            # And (2): remove `get_attr` node from submod's arg list
            for user in copy.copy(node.users):
                assert user.op == "call_module"
                delete_user_reference(node, user)
            # And (3): remove the `get_attr` node from the root graph.
            split.graph.erase_node(node)

        split.delete_all_unused_submodules()
        split.graph.lint()
        split.recompile()

        num_stages = Pipe._number_and_count_forward_stages(split)

        has_loss_and_backward = False
        generated_loss_spec = output_loss_value_spec

        if output_loss_value_spec is not None:
            loss_node, output_node, generated_loss_spec = _find_loss_output(
                mod, split.graph, output_loss_value_spec
            )
            if loss_node is not None:
                _insert_stage_symbolic_backward(
                    split.graph,
                    loss_node,
                    output_node,
                )
                split.recompile()
                has_loss_and_backward = True
                logger.debug("Pipeline is in training mode, backward pass generated")
            else:
                raise RuntimeError(
                    f"Did not find any loss value according to {output_loss_value_spec=}"
                )
        else:
            logger.debug("Pipeline is in inference mode, backward pass not generated")

        logger.debug("Full pipe model:\n" f"{split}")  # noqa: G004

        return Pipe(
            split,
            num_stages,
            has_loss_and_backward,
            generated_loss_spec,
        )

    def print_readable(self):
        """
        Print the pipe in a human-readable format.
        This will print both the root pipe and each stage module.
        """
        self.split_gm.print_readable()

    @staticmethod
    def _trace_with_export(
        mod: torch.nn.Module,
        example_args: Tuple[Any, ...],
        example_kwargs: Optional[Dict[str, Any]] = None,
    ) -> ExportedProgram:
        logger.info("Tracing model ...")
        try:
            ep = torch.export.export(
                mod,
                example_args,
                example_kwargs,
            )
        except Exception as e:
            raise RuntimeError(
                "It seems that we cannot capture your model as a full graph. "
                "Typical reasons include graph breaks, data/shape-dependent "
                "control flow, or missing meta kernels for custom operators. "
                "You can use our manual pipeline interfaces, or try to fix the "
                "graph breaks, see https://pytorch.org/docs/stable/export.html"
            ) from e

        return ep

    @staticmethod
    def from_tracing(
        mod: torch.nn.Module,
        example_args: Tuple[Any, ...],
        example_kwargs: Optional[Dict[str, Any]] = None,
        split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
    ):
        # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across
        # stages instead of TRANSMIT'ting it
        multi_use_param_spec = MultiUseParameterConfig.REPLICATE

        # Figure out which output is loss from output_chunk_spec
        output_loss_value_spec: Any = None
        # Deprecated
        """
        if output_chunk_spec is not None:
            output_loss_value_spec = map_aggregate(
                output_chunk_spec, lambda v: isinstance(v, _LossReducer)
            )
        """

        # Trace with export
        exported_program = Pipe._trace_with_export(
            mod,
            example_args,
            example_kwargs,
        )

        pipe = Pipe._from_traced(
            mod,
            exported_program,
            multi_use_param_spec,
            output_loss_value_spec=output_loss_value_spec,
            split_policy=split_policy,
        )

        # Users want the first pipeline stage to accept kwargs if the original
        # program does. This is controlled by the `_codegen` field of the graph,
        # so we make a copy here. Note: we only want the input spec and not the
        # output spec, because the output spec is for the last stage. Maybe a
        # TODO? Not sure yet.
        split = pipe.split_gm
        traced = exported_program.module()
        submod0 = next(iter(split.children()))
        submod0_sign = signature(submod0.forward)
        model_sign = signature(traced.forward)
        if len(model_sign.parameters) != len(submod0_sign.parameters):
            # We don't change the signature of the first stage if it takes a
            # different number of args than the original model
            logger.info(
                f"Original model takes {len(model_sign.parameters)} args but the "  # noqa: G004
                f"first pipeline stage takes {len(submod0_sign.parameters)}. "
                "Please provide args to respective pipeline stages."
            )
        else:
            # Support kwargs for the first stage
            submod0.graph._codegen = copy.deepcopy(traced.graph._codegen)
            # `_replace` is actually not "private" or internal. Based on this doc:
            # To prevent conflicts with field names, the method and attribute names
            # start with an underscore
            submod0.graph._codegen.pytree_info = (
                submod0.graph._codegen.pytree_info._replace(out_spec=None)
            )
            submod0.recompile()

        return pipe

    def __str__(self):
        return self.split_gm.__str__()

    def __repr__(self):
        return self.split_gm.__repr__()

    def info(self) -> PipeInfo:
        """
        Get information about the pipe.

        Returns
        -------
        PipeInfo
            A dataclass containing information about the pipe.
        """
        return PipeInfo(
            graph=self.split_gm.graph,
            num_stages=self.num_stages,
            has_loss_and_backward=self.has_loss_and_backward,
        )

    def build_stage(
        self,
        stage_index: int,
        device: torch.device,
        group: Optional[ProcessGroup] = None,
    ) -> _PipelineStage:
        """
        Create a `PipelineStage` given a stage index and distributed group.
        The `PipelineStage` can run with `PipelineSchedule`s.
        """
        # Find stage module
        stage_module = self.get_stage_module(stage_index)

        # Move ops argument to device
        # Today, the PT2 tracer does not treat `x.device` as a symbolic device;
        # instead, the device at tracing time gets burned into the generated
        # code. Here we provide a workaround for users to manually modify the
        # "device" kwarg of operations. Such operations may include:
        # `torch.ones`, `torch.zeros`, `torch.rand`, etc.
        if isinstance(stage_module, torch.fx.GraphModule):
            _modify_graph_op_device(stage_module, device)
        else:
            logger.warning(
                f"Expected a `torch.fx.GraphModule` but got {type(stage_module)}"  # noqa: G004
            )

        # Detach pipe info
        # Note: be careful what's included in `pipe_info`. We don't want to keep
        # a reference to `Pipe` or `Pipe.split_gm` which stops python from
        # recycling them. When python recycles them, other stage modules (which
        # are irrelevant to current rank) can be automatically freed.
        pipe_info = self.info()
        return _PipelineStage(stage_module, stage_index, pipe_info, device, group)


class SplitPoint(Enum):
    BEGINNING = 1
    END = 2


# For backward compatibility, we kept the PipeSplitWrapper class because `class
# SplitPoint` used to be defined in this class.
class PipeSplitWrapper:
    # Create a class alias for BC
    SplitPoint = SplitPoint


def _split_before_forward(self, *args, **kwargs):
    pipe_split()
    return self._orig_forward(*args, **kwargs)


def _split_after_forward(self, *args, **kwargs):
    try:
        return self._orig_forward(*args, **kwargs)
    finally:
        pipe_split()


def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]):
    # TODO: make this implementation out-of-place?
    for qualname, split_type in spec.items():
        atoms = qualname.split(".")
        predecessor_module = mod
        for i, atom in enumerate(atoms[:-1]):
            try:
                predecessor_module = getattr(predecessor_module, atom)
            except AttributeError as e:
                raise AttributeError(
                    f"Specified target {qualname} referenced "
                    f'nonexistent module {".".join(atoms[: i + 1])}'
                ) from e

        mod_to_wrap = getattr(predecessor_module, atoms[-1])
        mod_to_wrap._orig_forward = mod_to_wrap.forward
        if split_type == SplitPoint.BEGINNING:
            mod_to_wrap.forward = MethodType(_split_before_forward, mod_to_wrap)
        elif split_type == SplitPoint.END:
            mod_to_wrap.forward = MethodType(_split_after_forward, mod_to_wrap)
        else:
            raise ValueError("Unknown split point type.")
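

# An illustrative usage sketch (the qualified names below are hypothetical):
# splitting a model into three stages by fully qualified submodule name.
# SplitPoint.BEGINNING inserts the split before the named submodule runs;
# SplitPoint.END inserts it after the submodule returns.
def _example_annotate_split_points(mod: torch.nn.Module) -> None:
    annotate_split_points(
        mod,
        {
            "layers.4": SplitPoint.BEGINNING,  # stage boundary before layers.4
            "layers.8": SplitPoint.BEGINNING,  # stage boundary before layers.8
        },
    )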


def pipeline(
    module: torch.nn.Module,
    mb_args: Tuple[Any, ...],
    mb_kwargs: Optional[Dict[str, Any]] = None,
    split_spec: Optional[Dict[str, SplitPoint]] = None,
    split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
) -> Pipe:
    """
    Split a module based on a specification.

    See `Pipe` for more details.

    Arguments
    ---------
    module:
        The module to be split.
    mb_args:
        Example positional inputs, in micro-batch form.
    mb_kwargs:
        Example keyword inputs, in micro-batch form. (default: `None`)
    split_spec:
        A dictionary using submodule names as split markers. (default: `None`)
    split_policy:
        The policy to use for splitting the module. (default: `None`)

    Returns
    -------
    A pipeline representation of class `Pipe`.
    """
    if split_spec is not None and split_policy is not None:
        raise ValueError(
            "Cannot specify both `split_spec` and `split_policy`. Please use only one of them."
        )

    if split_spec is not None:
        # Annotate split points in the module based on user spec
        annotate_split_points(module, split_spec)
        return Pipe.from_tracing(
            mod=module,
            example_args=mb_args,
            example_kwargs=mb_kwargs,
        )
    else:
        # Use split policy
        return Pipe.from_tracing(
            mod=module,
            example_args=mb_args,
            example_kwargs=mb_kwargs,
            split_policy=split_policy,
        )
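

# An end-to-end usage sketch (illustrative only; the model, shapes, and
# stage/rank wiring below are hypothetical). A split spec marks the stage
# boundary, the micro-batch example inputs drive export-based tracing, and each
# rank then materializes only its own stage via `Pipe.build_stage`.
def _example_build_pipeline(rank: int, device: torch.device) -> _PipelineStage:
    model = torch.nn.Sequential(
        torch.nn.Linear(16, 16),
        torch.nn.ReLU(),
        torch.nn.Linear(16, 16),
    )
    pipe = pipeline(
        model,
        mb_args=(torch.randn(4, 16),),  # one micro-batch of example inputs
        split_spec={"2": SplitPoint.BEGINNING},  # new stage starts at the last Linear
    )
    return pipe.build_stage(rank, device)  # stage module for this rank only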