# Owner(s): ["module: fx"]

import os
import sys
from typing import Callable

import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace
from torch.fx.experimental.proxy_tensor import make_fx


pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
import unittest

from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
    SubgraphMatcherWithNameNodeMap,
)
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
from torch.testing._internal.jit_utils import JitTestCase


class WrapperModule(torch.nn.Module):
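    """Wraps a callable in an nn.Module so it can be traced and exported."""
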
    def __init__(self, fn: Callable):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


class TestMatcher(JitTestCase):
    def test_subgraph_matcher_with_attributes(self):
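        """Pattern get_attr nodes match model attributes even when parameter
        names and shapes differ."""
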
        class LargeModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._weight = torch.nn.Parameter(torch.ones(3, 3))
                self._bias = torch.nn.Parameter(torch.ones(3, 3))

            def forward(self, x):
                return torch.ops.aten.addmm.default(self._bias, x, self._weight)

        # LargeModel graph:
        # opcode         name           target              args                 kwargs
        # -------------  -------------  ------------------  -------------------  --------
        # placeholder    x              x                   ()                   {}
        # get_attr       _bias          _bias               ()                   {}
        # get_attr       _weight        _weight             ()                   {}
        # call_function  addmm_default  aten.addmm.default  (_bias, x, _weight)  {}
        # output         output         output              (addmm_default,)     {}
        large_model_graph = symbolic_trace(LargeModel()).graph

        class PatternModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
                self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))

            def forward(self, x):
                return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)

        pattern_graph = symbolic_trace(PatternModel()).graph

        subgraph_matcher = SubgraphMatcher(pattern_graph)
        match_result = subgraph_matcher.match(large_model_graph)
        self.assertEqual(len(match_result), 1)

    def test_subgraph_matcher_with_list(self):
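        """A placeholder in the pattern can stand in for a literal in the
        target graph."""
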
        def original(x, y):
            return torch.ops.aten.view(x, [5, y.shape[0]])

        original_graph = symbolic_trace(original).graph

        def pattern(x, y, z):
            return torch.ops.aten.view(x, [z, y.shape[0]])

        pattern_graph = symbolic_trace(pattern).graph

        subgraph_matcher = SubgraphMatcher(pattern_graph)
        match_result = subgraph_matcher.match(original_graph)
        self.assertEqual(len(match_result), 1)

    def test_subgraph_matcher_with_list_bad(self):
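        """List arguments whose lengths differ from the pattern's do not match,
        even when the pattern uses placeholders."""
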
        def original(x, y):
            return torch.ops.aten._reshape_alias_copy.default(
                x, [1, y.shape[0]], [y.shape[1], y.shape[1]]
            )

        original_graph = symbolic_trace(original).graph

        def pattern(x, y, b):
            return torch.ops.aten._reshape_alias_copy.default(
                x, [b, y.shape[0], y.shape[1]], [y.shape[1]]
            )

        pattern_graph = symbolic_trace(pattern).graph

        subgraph_matcher = SubgraphMatcher(pattern_graph)
        match_result = subgraph_matcher.match(original_graph)
        self.assertEqual(len(match_result), 0)

    def test_subgraph_matcher_ignore_literals(self):
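        """With ignore_literals=True, differing constants no longer prevent a
        match between otherwise identical graphs."""
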
        def original(x):
            return x + 1

        original_graph = make_fx(original)(torch.ones(3, 3)).graph
        original_graph.eliminate_dead_code()

        def pattern(x):
            return x + 2

        pattern_graph = make_fx(pattern)(torch.ones(4, 4)).graph
        pattern_graph.eliminate_dead_code()

        # By default, literals must match exactly (here 1 vs. 2), so no match.
        subgraph_matcher = SubgraphMatcher(pattern_graph)
        match_result = subgraph_matcher.match(original_graph)
        self.assertEqual(len(match_result), 0)

        # Ignoring literals, the two graphs are structurally identical.
        subgraph_matcher = SubgraphMatcher(pattern_graph, ignore_literals=True)
        match_result = subgraph_matcher.match(original_graph)
        self.assertEqual(len(match_result), 1)

    def test_variadic_arg_matching(self):
        inputs = (torch.randn(20, 16, 50, 32),)

        def maxpool(x, kernel_size, stride, padding, dilation):
            return torch.ops.aten.max_pool2d_with_indices.default(
                x, kernel_size, stride, padding, dilation
            )

        maxpool_graph = symbolic_trace(maxpool).graph

        maxpool_matcher = SubgraphMatcher(maxpool_graph)
        match_result = maxpool_matcher.match(maxpool_graph)
        self.assertEqual(len(match_result), 1)

        # Graph contains only the "stride" argument
        maxpool_s = torch.nn.MaxPool2d(kernel_size=2, stride=1).eval()
        maxpool_s_graph = make_fx(maxpool_s)(*inputs).graph
        match_s_result = maxpool_matcher.match(maxpool_s_graph)
        self.assertEqual(len(match_s_result), 1)

        # Graph contains only the "padding" argument
        maxpool_p = torch.nn.MaxPool2d(kernel_size=2, padding=1)
        maxpool_p_graph = make_fx(maxpool_p)(*inputs).graph
        match_p_result = maxpool_matcher.match(maxpool_p_graph)
        self.assertEqual(len(match_p_result), 1)

        # Graph contains only the "stride" and "padding" arguments
        maxpool_sp = torch.nn.MaxPool2d(kernel_size=2, stride=1, padding=1)
        maxpool_sp_graph = make_fx(maxpool_sp)(*inputs).graph
        match_sp_result = maxpool_matcher.match(maxpool_sp_graph)
        self.assertEqual(len(match_sp_result), 1)

    @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
    def test_split_to_graph_and_name_node_map(self):
        """Test the internal helper that splits a pattern's dict output into a
        separate name-to-node map."""
        from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
            _split_to_graph_and_name_node_map,
        )

        def pattern(x, weight):
            conv = F.conv2d(x, weight)
            relu = F.relu(conv)
            relu_mul_by_two = relu * 2
            return relu, relu_mul_by_two, {"conv": conv, "relu": relu}

        from torch._export import capture_pre_autograd_graph

        example_inputs = (
            torch.randn(1, 3, 3, 3) * 10,
            torch.randn(3, 3, 3, 3),
        )
        pattern_gm = capture_pre_autograd_graph(WrapperModule(pattern), example_inputs)
        before_split_res = pattern_gm(*example_inputs)
        pattern_gm, name_node_map = _split_to_graph_and_name_node_map(pattern_gm)
        after_split_res = pattern_gm(*example_inputs)
        # The non-dict outputs must be unchanged by the split.
        self.assertEqual(before_split_res[0], after_split_res[0])
        self.assertEqual(before_split_res[1], after_split_res[1])

    @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
    def test_matcher_with_name_node_map_function(self):
        """Test SubgraphMatcherWithNameNodeMap with a function pattern."""

        def target_graph(x, weight):
            x = x * 2
            weight = weight * 3
            conv = F.conv2d(x, weight)
            relu = F.relu(conv)
            relu2 = relu * 2
            return relu + relu2

        def pattern(x, weight):
            conv = F.conv2d(x, weight)
            relu = F.relu(conv)
            relu_mul_by_two = relu * 2
            return relu, relu_mul_by_two, {"conv": conv, "relu": relu}
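
        # Convention: besides its regular outputs, the pattern returns a dict
        # mapping names to internal pattern nodes; SubgraphMatcherWithNameNodeMap
        # exposes the matched target-graph counterparts on each match as
        # `name_node_map`, so passes can reach nodes (like "conv" here) that are
        # not pattern outputs.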

        from torch._export import capture_pre_autograd_graph

        example_inputs = (
            torch.randn(1, 3, 3, 3) * 10,
            torch.randn(3, 3, 3, 3),
        )
        pattern_gm = capture_pre_autograd_graph(WrapperModule(pattern), example_inputs)
        matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
        target_gm = capture_pre_autograd_graph(
            WrapperModule(target_graph), example_inputs
        )
        internal_matches = matcher.match(target_gm.graph)
        for internal_match in internal_matches:
            name_node_map = internal_match.name_node_map
            assert "conv" in name_node_map
            assert "relu" in name_node_map
            name_node_map["conv"].meta["custom_annotation"] = "annotation"
            # Check that we correctly annotated the target graph module
            for n in target_gm.graph.nodes:
                if n == name_node_map["conv"]:
                    assert (
                        "custom_annotation" in n.meta
                        and n.meta["custom_annotation"] == "annotation"
                    )

    @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
    def test_matcher_with_name_node_map_module(self):
        """Test SubgraphMatcherWithNameNodeMap with a module pattern."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        class Pattern(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                linear = self.linear(x)
                # Note: we can't put "weight": self.linear.weight in the dictionary,
                # since nn.Parameter is not an allowed output type in dynamo
                return linear, {"linear": linear, "x": x}
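
        # Note that the name-to-node map may include input placeholders (here
        # "x") as well as intermediate nodes.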

        from torch._export import capture_pre_autograd_graph

        example_inputs = (torch.randn(3, 5),)
        pattern_gm = capture_pre_autograd_graph(Pattern(), example_inputs)
        matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
        target_gm = capture_pre_autograd_graph(M(), example_inputs)
        internal_matches = matcher.match(target_gm.graph)
        for internal_match in internal_matches:
            name_node_map = internal_match.name_node_map
            assert "linear" in name_node_map
            assert "x" in name_node_map
            name_node_map["linear"].meta["custom_annotation"] = "annotation"
            # Check that we correctly annotated the target graph module
            for n in target_gm.graph.nodes:
                if n == name_node_map["linear"]:
                    assert (
                        "custom_annotation" in n.meta
                        and n.meta["custom_annotation"] == "annotation"
                    )


if __name__ == "__main__":
    run_tests()