# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import collections
import unittest
from typing import List, Optional, Tuple

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_partitions_from_list_of_nodes,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import export
from torch.fx.node import Node
from torch.fx.passes.operator_support import OperatorSupportBase


class TestGraphPartition(unittest.TestCase):
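    """Tests for generate_partitions_from_list_of_nodes: nodes are grouped by
    their owning nn.Module, and the partitioner merges those groups into
    connected partitions, optionally constrained by an OperatorSupportBase."""
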
    def get_graph_module(
        self, module: torch.nn.Module, inputs: Tuple[torch.Tensor, ...]
    ) -> torch.fx.GraphModule:
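        # Export the eager module, then convert the exported program to the
        # Edge dialect (IR validity checking is disabled for these test graphs).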
        graph_module = (
            to_edge(
                export(module, inputs),
                compile_config=EdgeCompileConfig(
                    _check_ir_validity=False,
                ),
            )
            .exported_program()
            .graph_module
        )

        return graph_module

    # Hackily build per-module node lists from nn_module_stack metadata.
    def get_node_list(
        self,
        graph_module: torch.fx.GraphModule,
        supported_modules: List[str],
    ) -> List[List[Node]]:
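        # Map (fully qualified name, owning module) -> nodes that originate
        # from that module instance.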
        pattern_list_map = collections.defaultdict(list)
        placeholders = [
            node
            for node in graph_module.graph.nodes
            if node.op == "placeholder"
            and node.target != "x"  # hack: skip the user-input placeholder "x"
        ]
        for node in graph_module.graph.nodes:
            if "nn_module_stack" in node.meta:
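                # Each nn_module_stack value is a (fully qualified name,
                # module class) pair; the last entry is the innermost module
                # that produced this node.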
                module_values_list = list(node.meta["nn_module_stack"].values())
                full_qualified_name = module_values_list[-1][0]
                owning_module = module_values_list[-1][1]
                if owning_module in supported_modules:
                    pattern_list_map[(full_qualified_name, owning_module)].append(node)
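                    # Also pull placeholder arguments (e.g. parameters) that
                    # feed this node into the same group.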
                    for arg in node.args:
                        if isinstance(arg, Node) and arg in placeholders:
                            pattern_list_map[
                                (full_qualified_name, owning_module)
                            ].append(arg)

        return list(pattern_list_map.values())

    def extract_partition_list(
        self,
        graph_module: torch.fx.GraphModule,
        supported_modules: List[str],
        op_support: Optional[OperatorSupportBase] = None,
    ) -> List:
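        # Build per-module node lists, then let the canonical partitioner turn
        # them into partitions (optionally constrained by op_support).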
        node_list = self.get_node_list(graph_module, supported_modules)

        partition_list = generate_partitions_from_list_of_nodes(
            graph_module, node_list, op_support
        )

        return partition_list

    def test_partition_list_without_op_support_one_partition(self):
        """
        Check that all submodules are lowered into a single partition.
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(32, 32, 1)
                self.conv2 = torch.nn.Conv2d(32, 32, 1)
                self.conv3 = torch.nn.Conv2d(32, 32, 1)
                self.relu = torch.nn.ReLU()

            def forward(self, x: torch.Tensor):
                a = self.conv1(x)
                b = self.conv2(a)
                c = self.conv3(b)
                d = self.conv3(c)
                return self.relu(d)

        example_inputs = (torch.rand(1, 32, 16, 16),)
        test_module = TestModule()
        graph_module = self.get_graph_module(test_module, example_inputs)

        supported_module = [
            "torch.nn.modules.conv.Conv2d",
            "torch.nn.modules.activation.ReLU",
        ]
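        # Every node comes from a supported module (Conv2d or ReLU) and the
        # graph is a single connected chain, so one partition is expected.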
        partition_list = self.extract_partition_list(graph_module, supported_module)

        self.assertEqual(len(partition_list), 1)

    def test_partition_list_without_op_support_two_partitions(self):
        """
        Check that the graph is divided into two partitions when supported
        modules are provided but no OperatorSupportBase is given.
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(32, 32, 1)
                self.conv2 = torch.nn.Conv2d(32, 32, 1)
                self.conv3 = torch.nn.Conv2d(32, 32, 1)
                self.relu = torch.nn.ReLU()

            def forward(self, x: torch.Tensor):
                a = self.conv1(x)
                b = self.conv2(a)
                c = self.conv3(a + b)
                d = self.conv3(c)
                return self.relu(d)

        example_inputs = (torch.rand(1, 32, 16, 16),)
        test_module = TestModule()
        graph_module = self.get_graph_module(test_module, example_inputs)

        supported_module = [
            "torch.nn.modules.conv.Conv2d",
            "torch.nn.modules.activation.ReLU",
        ]
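        # The inline `a + b` add is not owned by any supported module, so it
        # splits the graph into two partitions.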
        partition_list = self.extract_partition_list(graph_module, supported_module)

        self.assertEqual(len(partition_list), 2)

        partition_1 = [
            "aten_convolution_default_2",
            "aten_convolution_default_3",
            "aten_relu_default",
            "p_conv3_bias",
            "p_conv3_weight",
        ]
        partition_2 = [
            "aten_convolution_default",
            "aten_convolution_default_1",
            "p_conv1_bias",
            "p_conv1_weight",
            "p_conv2_bias",
            "p_conv2_weight",
        ]

        # Extract sorted node names from each partition and compare them with
        # the expected node names.
        node_list_1 = sorted(node.name for node in partition_list[0].nodes)
        node_list_2 = sorted(node.name for node in partition_list[1].nodes)

        self.assertEqual(node_list_1, partition_1)
        self.assertEqual(node_list_2, partition_2)

    def test_graph_partition_with_op_support(self):
        """
        Check that the graph is divided into two partitions when both
        supported modules and an OperatorSupportBase are provided.
        """

        class TestOperatorSupport(OperatorSupportBase):
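            # Marks only edge-dialect div/add call_function nodes as
            # supported, independent of which module a node came from.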
            def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
                return node.op == "call_function" and node.target in [
                    exir_ops.edge.aten.div.Tensor,
                    exir_ops.edge.aten.add.Tensor,
                ]

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(32, 32, 1)
                self.conv2 = torch.nn.Conv2d(32, 32, 1)
                self.conv3 = torch.nn.Conv2d(32, 32, 1)
                self.relu = torch.nn.ReLU()

            def forward(self, x: torch.Tensor):
                a = self.conv1(x)
                b = self.conv2(a)
                c = self.conv3(a + b)
                d = self.conv3(c)
                c, _ = torch.max(c, dim=2)
                d, _ = torch.max(d, dim=2)
                e = d - c
                return self.relu(e)

        example_inputs = (torch.rand(1, 32, 16, 16),)
        test_module = TestModule()
        graph_module = self.get_graph_module(test_module, example_inputs)

        supported_module = [
            "torch.nn.modules.conv.Conv2d",
            "torch.nn.modules.activation.ReLU",
        ]
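        # The supported add node merges into the convolution partition, while
        # the unsupported max/sub ops separate relu into its own partition.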
        partition_list = self.extract_partition_list(
            graph_module, supported_module, TestOperatorSupport()
        )

        self.assertEqual(len(partition_list), 2)

        partition_1 = ["aten_relu_default"]
        partition_2 = [
            "aten_add_tensor",
            "aten_convolution_default",
            "aten_convolution_default_1",
            "aten_convolution_default_2",
            "aten_convolution_default_3",
            "p_conv1_bias",
            "p_conv1_weight",
            "p_conv2_bias",
            "p_conv2_weight",
            "p_conv3_bias",
            "p_conv3_weight",
        ]

        # Extract sorted node names from each partition and compare them with
        # the expected node names.
        node_list_1 = sorted(node.name for node in partition_list[0].nodes)
        node_list_2 = sorted(node.name for node in partition_list[1].nodes)

        self.assertEqual(node_list_1, partition_1)
        self.assertEqual(node_list_2, partition_2)