# mypy: allow-untyped-defs
import copy
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Type, Union

import torch.fx
from torch.fx._compatibility import compatibility
from torch.fx.graph import map_arg
from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module

from .tools_common import NodeList

__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"]


@compatibility(is_backward_compatible=False)
def getattr_recursive(obj, name):
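    """
    Gets an attribute from `obj` following a dot-separated path `name`, returning
    None if any attribute along the path is missing. For example (names are
    illustrative), getattr_recursive(m, "sub.linear") returns m.sub.linear.
    """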
    for layer in name.split("."):
        if hasattr(obj, layer):
            obj = getattr(obj, layer)
        else:
            return None
    return obj


@compatibility(is_backward_compatible=False)
def setattr_recursive(obj, attr, value):
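    """
    Sets an attribute on `obj` following a dot-separated path `attr`. For example
    (names are illustrative), setattr_recursive(m, "sub.linear", new_mod) is
    equivalent to m.sub.linear = new_mod. Intermediate attributes along the path
    must already exist.
    """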
    if "." not in attr:
        setattr(obj, attr, value)
    else:
        layer = attr.split(".")
        setattr_recursive(getattr(obj, layer[0]), ".".join(layer[1:]), value)


@compatibility(is_backward_compatible=False)
@dataclass
class Component:
    """
    A component serves as a container for a subgraph we want to create afterwards.
    """

    graph: torch.fx.Graph
    order: int
    name: str

    # Stores the placeholder nodes in `graph`.
    input_placeholders: List = field(default_factory=list)

    # Stores the nodes in the original graph that are placeholders in `graph`.
    orig_inputs: List = field(default_factory=list)

    # Stores the nodes in the original graph that are outputs in `graph`.
    orig_outputs: List = field(default_factory=list)

    # Mapping from get_attr node in the original graph to get_attr node in `graph`.
    getattr_maps: Dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
    constructor_args: List[str] = field(default_factory=list)
    gm: Optional[torch.fx.GraphModule] = None


@compatibility(is_backward_compatible=False)
def split_by_tags(
    gm: torch.fx.GraphModule,
    tags: List[str],
    return_fqn_mapping: bool = False,
    return_tuple: bool = False,
    GraphModuleCls: Type[torch.fx.GraphModule] = torch.fx.GraphModule,
) -> Union[torch.fx.GraphModule, Tuple[torch.fx.GraphModule, Dict[str, str]]]:
    """
    Splits a GraphModule using tags on its graph nodes. We honor the order of
    tags. For example, given tags = ["a", "b", "c"], the function will create
    the initial submodules in the order "a", "b", "c".

    To set a tag:
    gm.graph.nodes[idx].tag = "mytag"

    This will result in all nodes with the same tag being extracted and placed in
    their own submodule. For placeholder, output and get_attr nodes, the tag is
    ignored. Placeholder and output nodes are created when needed, while get_attr
    nodes get copied to the submodules where they are used.

    Given the following module def:

    class SimpleModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear1 = torch.nn.Linear(...)
            self.linear2 = torch.nn.Linear(...)
            self.linear3 = torch.nn.Linear(...)

        def forward(self, in1, in2):
            r1 = self.linear1(in1)
            r2 = self.linear2(in2)
            r3 = torch.cat([r1, r2])
            return self.linear3(r3)

    Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split:

    ro:
    def forward(self, in1):
        self = self.root
        linear1 = self.linear1(in1)
        return linear1

    main:
    def forward(self, in2, linear1):
        self = self.root
        linear2 = self.linear2(in2)
        cat_1 = torch.cat([linear1, linear2])
        linear3 = self.linear3(cat_1)
        return linear3

    The whole module after split:
    def forward(self, in1, in2):
        self = self.root
        ro_0 = self.ro_0(in1)
        main_1 = self.main_1(in2, ro_0)
        return main_1
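
    A minimal usage sketch reusing SimpleModule from above (illustrative; the
    tagging rule is an assumption, not part of this API):

    gm = torch.fx.symbolic_trace(SimpleModule())
    for node in gm.graph.nodes:
        if node.op in ("call_module", "call_function", "call_method"):
            # Send linear1 to the "ro" submodule and every other callable node to "main".
            node.tag = "ro" if node.target == "linear1" else "main"
    split_gm = split_by_tags(gm, ["ro", "main"])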

    Returns:
        split_gm: the GraphModule after split.
        orig_to_split_fqn_mapping: a map between the original fqn and the fqn
            after split for call_module and get_attr.
    """

    def flatten(x: torch.fx.node.Argument) -> NodeList:
        """
        Collects the nodes in x into a list and returns the list.
        """
        r: NodeList = []
        map_arg(x, r.append)
        return r

    # Mapping from node in the original module to node in the created submodule.
    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Mapping from node in the original module or created submodules to
    # the corresponding component.
    node_to_component: Dict[torch.fx.Node, Component] = {}

    # Mapping from tag to the corresponding component.
    tag_to_component: Dict[str, Component] = {}

    # Stores all components.
    all_components: List[Component] = []

    # Stores nodes that will be used in the main graph.
    used_in_main: Dict[torch.fx.Node, None] = {}

    # Main graph after split.
    main_g = torch.fx.Graph()

    # Mapping from node in the original module to node in the main graph after split.
    main_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Output node of the original module.
    output_node: Optional[torch.fx.Node] = None

    # Create a component for each tag; we don't expect to create other components afterwards.
    for tag in tags:
        comp = Component(torch.fx.Graph(), len(all_components), f"{tag}")
        all_components.append(comp)
        tag_to_component[tag] = comp

    # Traverse the nodes in the original graph and process them.
    for node in gm.graph.nodes:
        if node.op == "output":
            if output_node is not None:
                raise RuntimeError("Multiple output nodes in graph!")
            output_node = node
            continue

        # Placeholders in the original graph get copied to the main graph.
        if node.op == "placeholder":
            main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type)
            main_remapping[node].meta = copy.copy(node.meta)
            continue

        # Get_attr nodes are ignored because we are not tagging them.
        # Instead, we copy them directly to the submodules that use them.
        if node.op == "get_attr":
            continue

        # Now we process callable nodes, which are nodes with op of call_module,
        # call_function or call_method. Every callable node should be tagged.
        assert hasattr(node, "tag")

        upstream_components = [
            node_to_component[x]
            for x in flatten(node.args) + flatten(node.kwargs)
            if x.op not in {"placeholder", "get_attr"}
        ]

        comp = tag_to_component[node.tag]
        node_to_component[node] = comp

        # Max order of upstream components.
        mx = max((c.order for c in upstream_components), default=0)

        # Expect the component for `node` to have an order no lower than that of
        # its upstream components.
        assert comp.order >= mx

        # Map an input of `node` to the corresponding node in the component's graph.
        def remap_func(x):
            # If the input is a get_attr node, copy it to the current component's graph.
            # Returns the get_attr node in the current component's graph.
            if x.op == "get_attr":
                if x not in comp.getattr_maps:
                    comp.getattr_maps[x] = comp.graph.get_attr(
                        x.target, type_expr=x.type
                    )
                return comp.getattr_maps[x]

            # If the input is not a placeholder, it should have been put into a
            # component already. If it's in the current component then we return
            # the corresponding node in the component.
            if x.op != "placeholder" and node_to_component[x] == comp:
                return node_remapping[x]

            # If the input is a placeholder or it's in another component, we want
            # to make it a placeholder in the current component's graph.
            if x not in comp.orig_inputs:
                comp.orig_inputs.append(x)
                placeholder = comp.graph.placeholder(x.name, type_expr=x.type)
                placeholder.meta = copy.copy(x.meta)
                comp.input_placeholders.append(placeholder)
                used_in_main[x] = None

            return comp.input_placeholders[comp.orig_inputs.index(x)]

        n = comp.graph.node_copy(node, remap_func)
        n.tag = node.tag  # type: ignore[attr-defined]
        node_remapping[node] = n
        node_to_component[n] = comp

    if output_node is None:
        raise RuntimeError("Graph had no output node!")

    for x in flatten(output_node.args[0]):
        if x.op == "get_attr":
            # We don't need a component mapping for "get_attr" nodes that are
            # consumed by the output. We only need to make sure we create
            # corresponding counterparts in the resulting graph.
            main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type)
        else:
            # All component results consumed by the output node should be
            # marked as "used in main".
            used_in_main[x] = None

    # If a node is used in the main graph then we mark it as an output in the
    # component it belongs to.
    for n in used_in_main:
        if n.op != "placeholder":
            node_to_component[n].orig_outputs.append(n)

    # Now we create a GraphModule for each component.
    orig_to_split_fqn_mapping: Dict[str, str] = {}
    for comp in all_components:
        outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))

        if return_tuple:
            comp.graph.output(outs)
        else:
            # Take care of the args of the FX output node. If there's a single
            # output then the output node args is like (output_single,), else
            # if there are multiple outputs then the output node args is like
            # ((output_0, output_1, ...),).
            comp.graph.output(outs[0] if len(outs) == 1 else outs)

        comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module(
            gm, subgraph=comp.graph, comp_name=comp.name
        )
        orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping)

        # Create a call_module node in the main graph.
        main_node = main_g.call_module(
            comp.name,
            args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)),
            kwargs=None,
        )

        if len(outs) == 1 and not return_tuple:
            main_remapping[comp.orig_outputs[0]] = main_node
        else:
            for i, o in enumerate(comp.orig_outputs):
                # Use Proxy to record the getitem access on the component's output.
                main_remapping[o] = torch.fx.Proxy(main_node)[i].node  # type: ignore[index]

    main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__))
    main_root = HolderModule({comp.name: comp.gm for comp in all_components})
    main_g._codegen = gm.graph._codegen

    # If the output node consumes a get_attr directly in the original graph,
    # then we need to make sure that get_attr is copied to the new graph.
    for x in flatten(output_node.args[0]):
        if x.op == "get_attr":
            setattr(main_root, x.name, getattr_recursive(gm, x.target))  # type: ignore[arg-type]

    result_gm = GraphModuleCls(main_root, main_g)
    if return_fqn_mapping:
        return result_gm, orig_to_split_fqn_mapping

    return result_gm