xref: /aosp_15_r20/external/executorch/exir/backend/test/test_utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch import exir
11from executorch.exir import to_edge
12from executorch.exir.backend.backend_api import to_backend
13from executorch.exir.backend.partitioner import Partitioner, PartitionResult
14from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
15from executorch.exir.backend.utils import (
16    format_delegated_graph,
17    get_delegates,
18    get_non_lowered_nodes,
19    is_identical_graph,
20)
21
22from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
23from torch.export import export, ExportedProgram
24from torch.fx import symbolic_trace
25from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
26from torch.library import Library
27
# Edge-dialect overloads of the per-tensor quantize / dequantize ops from the
# ``quantized_decomposed`` namespace.
# NOTE(review): neither alias is referenced in the code visible in this file;
# confirm no external importer relies on them before removing.
T_QuantPerTensor, T_DQuantPerTensor = (
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
)
30
31
class TestUtils(unittest.TestCase):
    """Tests for the graph helpers in ``executorch.exir.backend.utils``.

    Covers ``is_identical_graph``, subgraph matching on edge graphs,
    ``to_backend`` validation of partitioner results, and the
    ``get_non_lowered_nodes`` / ``get_delegates`` / ``format_delegated_graph``
    inspection helpers on a program lowered with ``AddMulPartitionerDemo``.
    """

    def test_identical_graph_with_unused_args(self):
        class MyModule(torch.nn.Module):
            def forward(self, x, y):
                # y is deliberately unused.
                return x

        graph_module: torch.fx.GraphModule = symbolic_trace(MyModule())
        # A graph must compare identical to itself.
        self.assertTrue(is_identical_graph(graph_module, graph_module))

    def test_identical_graph_with_used_args(self):
        class MyModule(torch.nn.Module):
            def forward(self, x, y):
                return x, y

        graph_module: torch.fx.GraphModule = symbolic_trace(MyModule())
        self.assertTrue(is_identical_graph(graph_module, graph_module))

    def test_identical_graph_for_linear(self):
        graph_module: torch.fx.GraphModule = symbolic_trace(torch.nn.Linear(10, 10))
        self.assertTrue(is_identical_graph(graph_module, graph_module))

    def test_identical_graph_for_composite_module(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        graph_module: torch.fx.GraphModule = symbolic_trace(MyModule())
        self.assertTrue(is_identical_graph(graph_module, graph_module))

    def test_not_identical_graph_for_args(self):
        class MyModule1(torch.nn.Module):
            def forward(self, x, y):
                # y is deliberately unused.
                return x + 1

        class MyModule2(torch.nn.Module):
            def forward(self, x, y):
                return x + 1, y + 2

        def to_edge_graph_module(module: torch.nn.Module) -> torch.fx.GraphModule:
            # Export with two (3, 4) inputs and return the edge-dialect graph.
            return (
                to_edge(export(module, (torch.rand(3, 4), torch.rand(3, 4))))
                .exported_program()
                .graph_module
            )

        graph_module_1 = to_edge_graph_module(MyModule1())
        graph_module_2 = to_edge_graph_module(MyModule2())
        self.assertFalse(is_identical_graph(graph_module_1, graph_module_2))

    def test_match_attrs(self):
        class LargeModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.ones(3, 3))
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, x):
                a = x + self.weight
                b = self.linear(x)
                return a, b

        inputs = (torch.ones(3, 3),)
        # IR validity checks are disabled for both the target and the pattern.
        compile_config = exir.EdgeCompileConfig(_check_ir_validity=False)

        large_model = (
            to_edge(export(LargeModel(), inputs), compile_config=compile_config)
            .exported_program()
            .graph_module
        )
        pattern = (
            to_edge(export(torch.nn.Linear(3, 3), inputs), compile_config=compile_config)
            .exported_program()
            .graph_module.graph
        )

        match_result = SubgraphMatcher(pattern).match(large_model.graph)

        # Only the linear sub-graph should match, exactly once; the
        # parameter add must not be mistaken for it.
        self.assertEqual(len(match_result), 1)

    def test_invalid_partitioner_without_partitioner(self):
        """
        Tests that ``to_backend`` rejects a partitioner whose
        ``PartitionResult`` carries ``partition_tags=None`` instead of a
        mapping of tags to delegation specs.
        """

        class InvalidPartitioner(Partitioner):
            """
            Returns the program untouched with ``partition_tags=None``, which
            is what makes this partitioner invalid.
            """

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                return PartitionResult(
                    tagged_exported_program=edge_exported_program, partition_tags=None
                )

        edge = to_edge(
            export(
                torch.nn.Linear(3, 3),
                (torch.randn(3, 3),),
            )
        )

        error_msg = r"needs a `partition_tags` field containing a mapping of tags to delegate spec"
        with self.assertRaisesRegex(AssertionError, error_msg):
            _ = to_backend(edge.exported_program(), InvalidPartitioner())

    # Registers a custom op in the ``test_lib`` library when the class body
    # executes.
    # NOTE(review): no test visible in this file calls ``q_linear``; confirm
    # nothing relies on this registration before removing it.
    test_lib = Library("test_lib", "DEF")

    @staticmethod
    @bind_pattern_to_op(
        test_lib, "test_q_linear(Tensor x, Tensor weight, Tensor bias) -> Tensor"
    )
    def q_linear(x, weight, bias):
        return x

    def _lower_mm_add_sub_model(self):
        """Build, export, and lower the mm/add/sub model shared by several tests.

        The model computes ``mm -> add -> sub -> mm -> add`` on three (2, 2)
        random inputs. It is converted to edge dialect and lowered with
        ``AddMulPartitionerDemo``.

        Returns:
            The object returned by ``to_edge(...).to_backend(...)``.
        """

        class Model(torch.nn.Module):
            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = y + b
                a = z - a
                y = torch.mm(a, x)
                z = y + b
                return z

        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
        return to_edge(export(Model(), inputs)).to_backend(AddMulPartitionerDemo())

    def test_get_non_lowered_nodes(self):
        edge = self._lower_mm_add_sub_model()
        non_lowered_nodes = get_non_lowered_nodes(edge.exported_program().graph)
        # Only the sub node is expected to stay un-lowered.
        self.assertEqual(len(non_lowered_nodes), 1)

    def test_get_delegates(self):
        edge = self._lower_mm_add_sub_model()
        delegates = get_delegates(edge.exported_program().graph)
        # There will be 2 delegates: (mm + add) -> sub -> (mm + add)
        self.assertEqual(len(delegates), 2)

    def test_print_delegted_graph(self):
        edge = self._lower_mm_add_sub_model()

        graph_str = format_delegated_graph(edge.exported_program().graph_module)
        self.assertIn(
            "BackendWithCompilerDemo",
            graph_str,
            "Expect to find the backend id in the graph format string",
        )
        self.assertIn(
            "executorch.exir.dialects.edge._ops.aten.mm.default",
            graph_str,
            "Expect to see the aten.mm in the delegated graph",
        )
275