# xref: /aosp_15_r20/external/pytorch/test/fx/test_fx_split.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: fx"]
2
3from collections import defaultdict
4from typing import Dict, List, Tuple
5
6import torch
7from torch.fx.passes.split_utils import split_by_tags
8from torch.testing._internal.common_utils import TestCase
9
10
11class TestFXSplit(TestCase):
12    def test_split_preserve_node_meta(self):
13        class TestModule(torch.nn.Module):
14            def forward(self, x, y):
15                x = x + x
16                y = y * y
17                return x - y
18
19        gm = torch.fx.symbolic_trace(TestModule())
20        for node in gm.graph.nodes:
21            node.meta["name"] = node.name
22            if node.name == "add":
23                node.tag = "a"
24            elif node.name == "mul":
25                node.tag = "b"
26            elif node.name == "sub":
27                node.tag = "c"
28
29        split_gm = split_by_tags(gm, ["a", "b", "c"])
30        for m in split_gm.children():
31            for n in m.graph.nodes:
32                if n.op != "output":
33                    self.assertIn("name", n.meta)
34                    self.assertEqual(n.meta["name"], n.name)
35
36        # Validate that metadata is copied correctly for graph placeholder nodes
37        for node in split_gm.graph.nodes:
38            if node.op == "placeholder":
39                self.assertIn("name", node.meta)
40                self.assertEqual(node.meta["name"], node.name)
41
42
43class TestSplitByTags(TestCase):
44    class TestModule(torch.nn.Module):
45        def __init__(self) -> None:
46            super().__init__()
47            self.linear1 = torch.nn.Linear(2, 3)
48            self.linear2 = torch.nn.Linear(4, 5)
49            self.linear3 = torch.nn.Linear(6, 7)
50            self.linear4 = torch.nn.Linear(8, 6)
51
52        def forward(
53            self,
54            x1: torch.Tensor,
55            x2: torch.Tensor,
56            x3: torch.Tensor,
57        ) -> torch.Tensor:
58            v1 = self.linear1(x1)
59            v2 = self.linear2(x2)
60            v3 = self.linear3(x3)
61            v4 = torch.cat([v1, v2, v3])
62            return self.linear4(v4)
63
64    @staticmethod
65    def trace_and_tag(
66        module: torch.nn.Module, tags: List[str]
67    ) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]:
68        """
69        Test simple gm consists of nodes with tag (only show call_module nodes here):
70            linear1 - tag: "red"
71            linear2 - tag: "blue"
72            linear3, linear4 - tag: "green"
73
74        At the beginning we have:
75            gm:
76                linear1
77                linear2
78                linear3
79                linear4
80
81        split_gm = split_by_tags(gm, tags)
82
83        Then we have:
84            split_gm:
85                red:
86                    linear1
87                blue:
88                    linear2
89                green:
90                    linear3
91                    linear4
92        """
93        tag_node = defaultdict(list)
94        gm: torch.fx.GraphModule = torch.fx.symbolic_trace(module)
95
96        # Add tag to all nodes and build dictionary record tag to call_module nodes
97        for node in gm.graph.nodes:
98            if "linear1" in node.name:
99                node.tag = tags[0]
100                tag_node[tags[0]].append(node.name)
101            elif "linear2" in node.name:
102                node.tag = tags[1]
103                tag_node[tags[1]].append(node.name)
104            else:
105                node.tag = tags[2]
106                if node.op == "call_module":
107                    tag_node[tags[2]].append(node.name)
108        return gm, tag_node
109
110    def test_split_by_tags(self) -> None:
111        tags = ["red", "blue", "green"]
112        module = TestSplitByTags.TestModule()
113        gm, tag_node = TestSplitByTags.trace_and_tag(module, tags)
114        split_gm, orig_to_split_fqn_mapping = split_by_tags(
115            gm, tags, return_fqn_mapping=True
116        )
117        # Ensure split_gm has (and only has) ordered submodules named
118        # red_0, blue_1, green_2
119        for idx, (name, _) in enumerate(split_gm.named_children()):
120            if idx < len(tags):
121                self.assertTrue(
122                    name == tags[idx],
123                    f"split_gm has an incorrect submodule named {name}",
124                )
125
126        # Ensure each submodule has expected (ordered) call_module node(s).
127        # For example, a submodule named split_gm.red_0 has (and only has) linear1;
128        # split_gm.green_2 has (and only has) linear3 and linear4 with order
129        sub_graph_idx = 0
130        for sub_name, sub_graph_module in split_gm.named_children():
131            node_idx = 0
132            for node in sub_graph_module.graph.nodes:
133                if node.op != "call_module":
134                    continue
135                self.assertTrue(
136                    node.name == tag_node[f"{sub_name}"][node_idx],
137                    # pyre-fixme[61]: `name` is undefined, or not always defined.
138                    f"{sub_name} has incorrectly include {node.name}",
139                )
140                node_idx += 1
141            sub_graph_idx += 1
142
143        self.assertEqual(
144            orig_to_split_fqn_mapping,
145            {
146                "linear1": "red.linear1",
147                "linear2": "blue.linear2",
148                "linear3": "green.linear3",
149                "linear4": "green.linear4",
150            },
151            f"{orig_to_split_fqn_mapping=}",
152        )
153
154
155class TestSplitOutputType(TestCase):
156    class TestModule(torch.nn.Module):
157        def __init__(self) -> None:
158            super().__init__()
159            self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
160            self.relu = torch.nn.ReLU()
161
162        def forward(self, x):
163            conv = self.conv(x)
164            conv = conv * 0.5
165            relu = self.relu(conv)
166            return relu
167
168    @staticmethod
169    def trace_and_tag(
170        module: torch.nn.Module, inputs: torch.Tensor, tags: List[str]
171    ) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]:
172        """
173        Test simple gm consists of nodes with tag (only show call_module nodes here):
174            conv - tag: "red"
175            mul - tag: "blue"
176            relu - tag: "green"
177
178        At the beginning we have:
179            gm:
180                conv
181                mul
182                relu
183
184        split_gm = split_by_tags(gm, tags)
185
186        Then we have:
187            split_gm:
188                red:
189                    conv
190                blue:
191                    mul
192                green:
193                    relu
194        """
195        tag_node = defaultdict(list)
196        gm: torch.fx.GraphModule = torch.export.export(module, (inputs,)).module()
197        # Add tag to all nodes and build dictionary record tag to call_module nodes
198        for node in gm.graph.nodes:
199            if "conv" in node.name:
200                node.tag = tags[0]
201                tag_node[tags[0]].append(node.name)
202            elif "mul" in node.name:
203                node.tag = tags[1]
204                tag_node[tags[1]].append(node.name)
205            else:
206                node.tag = tags[2]
207                if node.op == "call_module":
208                    tag_node[tags[2]].append(node.name)
209        return gm, tag_node
210
211    def test_split_by_tags(self) -> None:
212        tags = ["red", "blue", "green"]
213        module = TestSplitOutputType.TestModule()
214
215        inputs = torch.randn((1, 3, 224, 224))
216
217        gm, tag_node = TestSplitOutputType.trace_and_tag(module, inputs, tags)
218        split_gm, orig_to_split_fqn_mapping = split_by_tags(
219            gm, tags, return_fqn_mapping=True
220        )
221
222        gm_output = module(inputs)
223        split_gm_output = split_gm(inputs)
224
225        self.assertTrue(type(gm_output) == type(split_gm_output))
226        self.assertTrue(torch.equal(gm_output, split_gm_output))
227