# Owner(s): ["module: fx"]

import os
import sys
import unittest

import torch


pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch._dynamo.eval_frame import is_dynamo_supported
from torch.fx.passes.tools_common import legalize_graph
from torch.fx.passes.utils.source_matcher_utils import (
    check_subgraphs_connected,
    get_source_partitions,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.jit_utils import JitTestCase


class TestSourceMatcher(JitTestCase):
    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_module_partitioner_linear_relu_linear(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(3, 3)
                self.relu = torch.nn.ReLU()
                self.linear2 = torch.nn.Linear(3, 5)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear1(x)
                x = self.relu(x)
                x = self.linear2(x)
                return x

        inputs = (torch.randn(3, 3),)
        gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs)
        gm.graph.eliminate_dead_code()

        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.Linear, torch.nn.ReLU]
        )

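        # linear1 is called twice and linear2 once, so three Linear partitions
        # are expected. check_subgraphs_connected(a, b) is directional: it is
        # True only when partition `a` feeds partition `b`, which here holds
        # only for the second linear1 call feeding the ReLU.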
        self.assertEqual(len(module_partitions), 2)
        self.assertEqual(len(module_partitions[torch.nn.Linear]), 3)
        self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1)

        self.assertFalse(
            check_subgraphs_connected(
                module_partitions[torch.nn.Linear][0],
                module_partitions[torch.nn.ReLU][0],
            )
        )
        self.assertTrue(
            check_subgraphs_connected(
                module_partitions[torch.nn.Linear][1],
                module_partitions[torch.nn.ReLU][0],
            )
        )
        self.assertFalse(
            check_subgraphs_connected(
                module_partitions[torch.nn.Linear][2],
                module_partitions[torch.nn.ReLU][0],
            )
        )

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_module_partitioner_conv_relu_maxpool(self):
        class M(torch.nn.Module):
            def __init__(self, constant_tensor: torch.Tensor) -> None:
                super().__init__()
                self.constant_tensor = constant_tensor
                self.conv1 = torch.nn.Conv2d(
                    in_channels=3, out_channels=16, kernel_size=3, padding=1
                )
                self.conv2 = torch.nn.Conv2d(
                    in_channels=16, out_channels=16, kernel_size=3, padding=1
                )
                self.conv3 = torch.nn.Conv2d(
                    in_channels=16, out_channels=16, kernel_size=3, padding=1
                )
                self.relu = torch.nn.ReLU()
                self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                a = self.conv1(x)
                b = self.conv2(a)
                c = a + self.constant_tensor
                z = self.conv3(b + c)
                return self.maxpool(self.relu(z))

        inputs = (torch.randn(1, 3, 256, 256),)
        gm, _ = torch._dynamo.export(M(torch.ones(1, 16, 256, 256)), aten_graph=True)(
            *inputs
        )
        gm.graph.eliminate_dead_code()

        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.Conv2d, torch.nn.ReLU, torch.nn.MaxPool2d]
        )

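        # Each Conv2d module yields its own partition. Only conv3's output
        # reaches the ReLU, and the ReLU output feeds the MaxPool2d (not the
        # reverse direction).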
        self.assertEqual(len(module_partitions), 3)
        self.assertEqual(len(module_partitions[torch.nn.Conv2d]), 3)
        self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1)
        self.assertEqual(len(module_partitions[torch.nn.MaxPool2d]), 1)

        self.assertFalse(
            check_subgraphs_connected(
                module_partitions[torch.nn.Conv2d][0],
                module_partitions[torch.nn.ReLU][0],
            )
        )
        self.assertFalse(
            check_subgraphs_connected(
                module_partitions[torch.nn.Conv2d][1],
                module_partitions[torch.nn.ReLU][0],
            )
        )
        self.assertTrue(
            check_subgraphs_connected(
                module_partitions[torch.nn.Conv2d][2],
                module_partitions[torch.nn.ReLU][0],
            )
        )
        self.assertFalse(
            check_subgraphs_connected(
                module_partitions[torch.nn.MaxPool2d][0],
                module_partitions[torch.nn.ReLU][0],
            )
        )
        self.assertTrue(
            check_subgraphs_connected(
                module_partitions[torch.nn.ReLU][0],
                module_partitions[torch.nn.MaxPool2d][0],
            )
        )

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_module_partitioner_functional_conv_relu_conv(self):
        class FunctionalConv2d(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x, weight, bias):
                return torch.nn.functional.conv2d(
                    x,
                    weight,
                    bias,
                    self.stride,
                    self.padding,
                    self.dilation,
                    self.groups,
                )

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = FunctionalConv2d()
                self.conv2 = FunctionalConv2d()

            def forward(self, x, weight, bias):
                x = self.conv1(x, weight, bias)
                x = torch.nn.functional.relu(x)
                x = self.conv2(x, weight, bias)
                return x

        inputs = (torch.randn(1, 3, 5, 5), torch.rand(3, 3, 3, 3), torch.rand(3))
        gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs)
        gm.graph.eliminate_dead_code()

        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.functional.conv2d]
        )

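        # Both FunctionalConv2d wrappers resolve to torch.nn.functional.conv2d,
        # so the result has a single key with two partitions.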
        self.assertEqual(len(module_partitions), 1)
        self.assertEqual(len(module_partitions[torch.nn.functional.conv2d]), 2)

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_module_partitioner_functional_linear_relu_linear(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, weight, bias):
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.relu(x)
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.relu(x)
                return x

        inputs = (torch.randn(1, 5), torch.rand((5, 5)), torch.zeros(5))
        gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs)
        gm.graph.eliminate_dead_code()

        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.functional.linear, torch.nn.functional.relu]
        )

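        # Four functional.linear calls and two functional.relu calls each
        # produce their own partition under the corresponding functional key.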
        self.assertEqual(len(module_partitions), 2)
        self.assertEqual(len(module_partitions[torch.nn.functional.linear]), 4)
        self.assertEqual(len(module_partitions[torch.nn.functional.relu]), 2)

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_legalize_slice(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
                b = x.item()
                torch._check_is_size(b)
                torch._check(b + 1 < y.size(0))
                return y[: b + 1]

        ep = torch.export.export(M(), (torch.tensor(4), torch.randn(10)))
        fake_inputs = [
            node.meta["val"] for node in ep.graph.nodes if node.op == "placeholder"
        ]
        gm = ep.module()
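        # Interpret the exported module on fake tensors, re-sort it with
        # legalize_graph (a topological re-sort of the nodes), and interpret it
        # again to check that the data-dependent slice and its _check guards
        # still run cleanly after the re-sort.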
        with fake_inputs[0].fake_mode:
            torch.fx.Interpreter(gm).run(*fake_inputs)
        legalized_gm = legalize_graph(gm)
        with fake_inputs[0].fake_mode:
            torch.fx.Interpreter(legalized_gm).run(*fake_inputs)

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    @parametrize("strict", (True, False))
    def test_module_partitioner_linear_relu_linear_torch_fn_export(self, strict: bool):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(3, 3)
                self.relu = torch.nn.ReLU()
                self.linear2 = torch.nn.Linear(3, 5)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear1(x)
                x = self.relu(x)
                x = self.linear2(x)
                return x

        inputs = (torch.randn(3, 3),)
        gm = torch.export.export(M(), inputs, strict=strict).module()
        gm.graph.eliminate_dead_code()

        # Remove "source_fn_stack" meta to let partitioner use "torch_fn" only.
        # TODO: remove this after we fix "torch_fn". T199561090
        for node in gm.graph.nodes:
            node.meta["source_fn_stack"] = None

        module_partitions = get_source_partitions(gm.graph, ["linear", "relu"])

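        # With "source_fn_stack" cleared, partitions are keyed by the "torch_fn"
        # operator-name strings; the expected structure mirrors the module-based
        # linear/relu/linear test above.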
        self.assertEqual(len(module_partitions), 2)
        self.assertEqual(len(module_partitions["linear"]), 3)
        self.assertEqual(len(module_partitions["relu"]), 1)

        self.assertFalse(
            check_subgraphs_connected(
                module_partitions["linear"][0],
                module_partitions["relu"][0],
            )
        )
        self.assertTrue(
            check_subgraphs_connected(
                module_partitions["linear"][1],
                module_partitions["relu"][0],
            )
        )
        self.assertFalse(
            check_subgraphs_connected(
                module_partitions["linear"][2],
                module_partitions["relu"][0],
            )
        )

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    @parametrize("strict", (True, False))
    def test_module_partitioner_conv_relu_maxpool_torch_fn_export(self, strict: bool):
        class M(torch.nn.Module):
            def __init__(self, constant_tensor: torch.Tensor) -> None:
                super().__init__()
                self.constant_tensor = constant_tensor
                self.conv1 = torch.nn.Conv2d(
                    in_channels=3, out_channels=16, kernel_size=3, padding=1
                )
                self.conv2 = torch.nn.Conv2d(
                    in_channels=16, out_channels=16, kernel_size=3, padding=1
                )
                self.conv3 = torch.nn.Conv2d(
                    in_channels=16, out_channels=16, kernel_size=3, padding=1
                )
                self.relu = torch.nn.ReLU()
                self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                a = self.conv1(x)
                b = self.conv2(a)
                c = a + self.constant_tensor
                z = self.conv3(b + c)
                return self.maxpool(self.relu(z))

        inputs = (torch.randn(1, 3, 256, 256),)
        gm = torch.export.export(
            M(torch.ones(1, 16, 256, 256)), inputs, strict=strict
        ).module()
        gm.graph.eliminate_dead_code()

        # Remove "source_fn_stack" meta to let partitioner use "torch_fn" only.
        # TODO: remove this after we fix "torch_fn". T199561090
        for node in gm.graph.nodes:
            node.meta["source_fn_stack"] = None

        module_partitions = get_source_partitions(
            gm.graph, ["conv2d", "relu", "max_pool2d"]
        )

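        # Same expectations as the module-based conv/relu/maxpool test, but
        # keyed by "torch_fn" operator names.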
        self.assertEqual(len(module_partitions), 3)
        self.assertEqual(len(module_partitions["conv2d"]), 3)
        self.assertEqual(len(module_partitions["relu"]), 1)
        self.assertEqual(len(module_partitions["max_pool2d"]), 1)

        self.assertFalse(
            check_subgraphs_connected(
                module_partitions["conv2d"][0],
                module_partitions["relu"][0],
            )
        )
        self.assertFalse(
            check_subgraphs_connected(
                module_partitions["conv2d"][1],
                module_partitions["relu"][0],
            )
        )
        self.assertTrue(
            check_subgraphs_connected(
                module_partitions["conv2d"][2],
                module_partitions["relu"][0],
            )
        )
        self.assertFalse(
            check_subgraphs_connected(
                module_partitions["max_pool2d"][0],
                module_partitions["relu"][0],
            )
        )
        self.assertTrue(
            check_subgraphs_connected(
                module_partitions["relu"][0],
                module_partitions["max_pool2d"][0],
            )
        )

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    @parametrize("strict", (True, False))
    def test_module_partitioner_functional_conv_relu_conv_torch_fn_export(
        self, strict: bool
    ):
        class FunctionalConv2d(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x, weight, bias):
                return torch.nn.functional.conv2d(
                    x,
                    weight,
                    bias,
                    self.stride,
                    self.padding,
                    self.dilation,
                    self.groups,
                )

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = FunctionalConv2d()
                self.conv2 = FunctionalConv2d()

            def forward(self, x, weight, bias):
                x = self.conv1(x, weight, bias)
                x = torch.nn.functional.relu(x)
                x = self.conv2(x, weight, bias)
                return x

        inputs = (torch.randn(1, 3, 5, 5), torch.rand(3, 3, 3, 3), torch.rand(3))
        gm = torch.export.export(M(), inputs, strict=strict).module()
        gm.graph.eliminate_dead_code()

        # Remove "source_fn_stack" meta to let partitioner use "torch_fn" only.
        # TODO: remove this after we fix "torch_fn". T199561090
        for node in gm.graph.nodes:
            node.meta["source_fn_stack"] = None

        module_partitions = get_source_partitions(gm.graph, ["conv2d"])

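        # Both functional conv2d calls are grouped under the "conv2d" torch_fn
        # key, yielding two partitions.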
        self.assertEqual(len(module_partitions), 1)
        self.assertEqual(len(module_partitions["conv2d"]), 2)

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    @parametrize("strict", (True, False))
    def test_module_partitioner_functional_linear_relu_linear_torch_fn_export(
        self, strict: bool
    ):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, weight, bias):
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.relu(x)
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.relu(x)
                return x

        inputs = (torch.randn(1, 5), torch.rand((5, 5)), torch.zeros(5))
        gm = torch.export.export(M(), inputs, strict=strict).module()
        gm.graph.eliminate_dead_code()

        # Remove "source_fn_stack" meta to let partitioner use "torch_fn" only.
        # TODO: remove this after we fix "torch_fn". T199561090
        for node in gm.graph.nodes:
            node.meta["source_fn_stack"] = None

        module_partitions = get_source_partitions(gm.graph, ["linear", "relu"])

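        # Four functional.linear calls and two functional.relu calls, this time
        # keyed by "torch_fn" operator names.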
        self.assertEqual(len(module_partitions), 2)
        self.assertEqual(len(module_partitions["linear"]), 4)
        self.assertEqual(len(module_partitions["relu"]), 2)


instantiate_parametrized_tests(TestSourceMatcher)