xref: /aosp_15_r20/external/pytorch/test/fx/test_subgraph_rewriter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: fx"]
2
3import os
4import sys
5
6import torch
7from torch.fx import subgraph_rewriter, symbolic_trace
8from torch.fx.annotate import annotate
9
10# Make the helper files in test/ importable
11from torch.fx.experimental.rewriter import RewritingTracer
12
13
# Make the helper files in test/ importable: this file lives in test/fx/, so
# the test/ root is two directory levels above it.
pytorch_test_dir = os.path.normpath(
    os.path.join(os.path.realpath(__file__), os.pardir, os.pardir)
)
sys.path.append(pytorch_test_dir)
16from torch.testing._internal.jit_utils import JitTestCase
17
18
# These tests are collected through test/test_fx.py; running this module as a
# script is refused with a usage hint instead.
if __name__ == "__main__":
    _usage_hint = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_fx.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage_hint)
25
26
@torch.fx.wrap
def wrapped_gemm_bias_mul(a, b, bias):
    """Leaf replacement op: ``linear(a, b, bias)`` followed by a multiply by ``a``.

    ``torch.fx.wrap`` keeps symbolic tracing from tracing into this body, so a
    traced call shows up as one ``call_function`` node whose target is this
    function (the kwargs test looks for exactly that node).

    Returns:
        A pair ``(linear_out, linear_out * a)``.
    """
    linear_out = torch.nn.functional.linear(a, b, bias=bias)
    return linear_out, linear_out * a
32
33
@torch.fx.wrap
def wrapped_gemm_bias_mul_with_c(a, b, bias, c):
    """Leaf replacement op: ``linear(a, b, bias)`` followed by a multiply by ``c``.

    Same as ``wrapped_gemm_bias_mul`` except the multiplier is a separate input.
    Because of ``torch.fx.wrap``, tracing records a single ``call_function``
    node for it rather than its body.

    Returns:
        A pair ``(linear_out, linear_out * c)``.
    """
    linear_out = torch.nn.functional.linear(a, b, bias=bias)
    return linear_out, linear_out * c
39
40
class TestSubgraphRewriter(JitTestCase):
    """Tests for ``torch.fx.subgraph_rewriter``.

    Most tests follow the same recipe:

    1. Symbolically trace a module.
    2. Rewrite it with ``replace_pattern`` / ``replace_pattern_with_filters``.
    3. Lint the resulting graph.
    4. Compare its output against a hand-written ``comparison`` function that
       spells out the expected rewritten program.
    """

    def _assert_replacement_nodes_match(self, traced, matches, targets):
        """Assert that the matches' ``replacements`` are exactly the graph nodes
        whose ``target`` is one of ``targets``, in the same order.

        Args:
            traced: the rewritten ``GraphModule``.
            matches: the list returned by ``replace_pattern_with_filters``.
            targets: tuple of call targets expected to appear only in nodes
                inserted by the replacement.

        Returns:
            The number of replacement nodes found in the graph.
        """
        replacement_nodes_in_graph = [
            node for node in traced.graph.nodes if node.target in targets
        ]
        replacement_nodes_in_res = [r for m in matches for r in m.replacements]
        self.assertEqual(
            len(replacement_nodes_in_graph), len(replacement_nodes_in_res)
        )
        self.assertEqual(replacement_nodes_in_graph, replacement_nodes_in_res)
        return len(replacement_nodes_in_graph)

    def test_subgraph_rewriter_preserves_logic(self):
        """Replacing a pattern with itself must not change the module's output."""

        class M(torch.nn.Module):
            def forward(self, x):
                val = torch.neg(x) + torch.relu(x)
                return torch.add(val, val)

        def pattern(x):
            return torch.neg(x) + torch.relu(x)

        def comparison(x):
            val = torch.neg(x) + torch.relu(x)
            return torch.add(val, val)

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x = torch.rand(1, 3)

        # Replace `pattern` with the same pattern (shouldn't change
        # the underlying logic)
        subgraph_rewriter.replace_pattern(traced, pattern, pattern)

        traced.graph.lint()

        ref_output = comparison_fn(x)
        test_output = traced.forward(x)
        self.assertEqual(ref_output, test_output)

    def test_subgraph_rewriter_with_oneliner_pattern(self):
        """A single-op pattern (``neg``) is rewritten to a single-op replacement (``relu``)."""

        class M(torch.nn.Module):
            def forward(self, x):
                val = torch.neg(x)
                return torch.add(val, val)

        def pattern(x):
            return torch.neg(x)

        def replacement(x):
            return torch.relu(x)

        def comparison(x):
            val = torch.relu(x)
            return torch.add(val, val)

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x = torch.rand(1, 3)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_output = comparison_fn(x)
        test_output = traced.forward(x)
        self.assertEqual(ref_output, test_output)

    def test_subgraph_rewriter_with_trivial_replacement(self):
        """A replacement that just returns its input splices both ``add`` matches out.

        The rewriter must report two matches, and the second one must record
        no replacement nodes.
        """

        class M(torch.nn.Module):
            def forward(self, x):
                val = torch.neg(x)
                val = torch.add(val, val)
                return torch.add(val, val)

        def pattern(x):
            return torch.add(x, x)

        def replacement(x):
            return x

        def comparison(x):
            return torch.neg(x)

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x = torch.randn(1, 5)

        matches = subgraph_rewriter.replace_pattern_with_filters(
            traced, pattern, replacement, []
        )

        traced.graph.lint()

        ref_output = comparison_fn(x)
        test_output = traced.forward(x)
        self.assertEqual(ref_output, test_output)
        # Checked separately so a failure says which expectation broke.
        self.assertEqual(len(matches), 2)
        self.assertEqual(len(matches[1].replacements), 0)

    def test_subgraph_rewriter_single_pattern_match(self):
        """A multi-op pattern (``neg(x) + relu(x)``) matched once is replaced by ``relu``."""

        class M(torch.nn.Module):
            def forward(self, x):
                val = torch.neg(x) + torch.relu(x)
                return torch.add(val, val)

        def pattern(x):
            return torch.neg(x) + torch.relu(x)

        def replacement(x):
            return torch.relu(x)

        def comparison(x):
            val = torch.relu(x)
            return torch.add(val, val)

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x = torch.rand(1, 3)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_output = comparison_fn(x)
        test_output = traced.forward(x)
        self.assertEqual(ref_output, test_output)

    def test_subgraph_rewriter_multiple_pattern_match(self):
        """Every occurrence of the pattern is replaced, not just the first one."""

        class M(torch.nn.Module):
            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)

        def pattern(w1, w2):
            return torch.cat([w1, w2]).sum()

        def replacement(w1, w2):
            return torch.stack([w1, w2])

        def comparison(x, w1, w2):
            m1 = torch.stack([w1, w2])
            m2 = torch.stack([w1, w2])
            return x + torch.max(m1) + torch.max(m2)

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x = torch.rand(1, 3)
        w1 = torch.rand(1, 3)
        w2 = torch.rand(1, 3)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x, w1, w2)
        test_outs = traced.forward(x, w1, w2)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_graph_argument_order(self):
        """An identity rewrite of ``mm`` keeps its two inputs in the original order.

        The inputs are (3, 4) and (4, 5), so swapping them would fail on shape.
        """

        class M(torch.nn.Module):
            def forward(self, x, y):
                return torch.mm(x, y)

        def pattern(x, y):
            return torch.mm(x, y)

        def comparison(x, y):
            return torch.mm(x, y)

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x = torch.randn(3, 4)
        y = torch.randn(4, 5)

        subgraph_rewriter.replace_pattern(traced, pattern, pattern)

        traced.graph.lint()

        ref_outs = comparison_fn(x, y)
        test_outs = traced.forward(x, y)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_correct_output_replacement(self):
        """Only the matched ``relu(x)`` is rewritten; the ``neg(y)`` branch is untouched."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                val = torch.neg(y) + torch.relu(x)
                return torch.add(val, val)

        def pattern(x):
            return torch.relu(x)

        def replacement(x):
            return torch.neg(x)

        def comparison(x, y):
            val = torch.neg(y) + torch.neg(x)
            return torch.add(val, val)

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x = torch.randn(4, 4)
        y = torch.randn(4, 4)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x, y)
        test_outs = traced.forward(x, y)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_traced_as_callable(self):
        """Pattern and replacement may be passed as already-traced ``GraphModule``s."""

        class M(torch.nn.Module):
            def forward(self, x):
                val = torch.neg(x) + torch.relu(x)
                return torch.add(val, val)

        class Pattern(torch.nn.Module):
            def forward(self, x):
                return torch.neg(x) + torch.relu(x)

        class Replacement(torch.nn.Module):
            def forward(self, x):
                return torch.sigmoid(x)

        def comparison(x):
            val = torch.sigmoid(x)
            return torch.add(val, val)

        traced = symbolic_trace(M())
        traced_pattern = symbolic_trace(Pattern())
        traced_replacement = symbolic_trace(Replacement())
        comparison_fn = symbolic_trace(comparison)

        x = torch.randn(3, 4)

        subgraph_rewriter.replace_pattern(traced, traced_pattern, traced_replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x)
        test_outs = traced.forward(x)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_pattern_is_entire_graph(self):
        """A pattern covering the whole graph is replaced wholesale."""

        class M(torch.nn.Module):
            def forward(self, x):
                a = torch.neg(x)
                return torch.add(a, a)

        def pattern(x):
            a = torch.neg(x)
            return torch.add(a, a)

        def replacement(x):
            a = torch.sigmoid(x)
            return torch.cat([a, a])

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(replacement)

        x = torch.randn(3, 4)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x)
        test_outs = traced.forward(x)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_pattern_output_pattern_node_can_have_users_that_are_not_matched(
        self,
    ):
        """The pattern's output node may have users outside the match.

        Here ``neg(y)`` and the subtraction both consume the matched ``relu``;
        after the rewrite they consume the replacement's output instead.
        """

        class M(torch.nn.Module):
            def forward(self, x):
                y = torch.relu(x)
                return torch.neg(y) - y

        def pattern(x):
            return torch.relu(x)

        def replacement(x):
            return torch.sigmoid(x)

        def comparison(x):
            y = torch.sigmoid(x)
            return torch.neg(y) - y

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x = torch.randn(3, 4)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x)
        test_outs = traced.forward(x)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_internal_pattern_nodes_cannot_have_users_that_are_not_matched(
        self,
    ):
        """No rewrite happens when an internal pattern node has users outside the match.

        ``m1`` and ``m2`` each feed two ``addmm`` calls, so every candidate
        match leaves an internal node with an outside user.
        """

        class M(torch.nn.Module):
            def forward(self, x, w1, w2, b1, b2):
                m0 = torch.cat([w1, w2])
                m1 = torch.cat([w1, w2])
                m2 = torch.cat([x, b2])
                t0 = torch.addmm(b1, m1, m2.t())
                t1 = torch.sum(w1, 1)
                t2 = torch.addmm(b1, m1, m2.t())
                return torch.sum(t1), torch.sum(t2)

        def pattern(x, w1, w2, b1, b2):
            m1 = torch.cat([w1, w2])
            m2 = torch.cat([x, b2])
            return torch.addmm(b1, m1, m2.t())

        def replacement(x, w1, w2, b1, b2):
            return torch.cat([x, w1, w2])

        traced = symbolic_trace(M())

        # Result should be [] since no matches can be found
        res = subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        self.assertEqual(res, [])

    def test_subgraph_rewriter_placeholder_matching(self):
        """
        This tests that a placeholder Node can be matched to a Node with
        a different number of input Nodes. In the example below, the
        original traced Module looks like this:

            opcode         target                                                      args                      kwargs
            -------------  ----------------------------------------------------------  ------------------------  --------
            placeholder    x                                                           ()                        {}
            call_function  <built-in function add>                                     (x, 3)                    {}
            call_method    dequantize                                                  (add,)                    {}
            call_function  <built-in method sigmoid of type object at 0x7f7c1f440fe0>  (dequantize,)             {}
            call_method    to                                                          (sigmoid, torch.float16)  {}
            output         output                                                      (to,)                     {}

        while the pattern we want to match looks like this:

            opcode         target                                                      args                      kwargs
            -------------  ----------------------------------------------------------  ------------------------  --------
            placeholder    x                                                           ()                        {}
            call_method    dequantize                                                  (x,)                      {}
            call_function  <built-in method sigmoid of type object at 0x7f7c1f440fe0>  (dequantize,)             {}
            call_method    to                                                          (sigmoid, torch.float16)  {}
            output         output                                                      (to,)                     {}

        Here, we want to be able to match the original graph's
        `call_function.add` Node with the pattern graph's
        `placeholder.x` Node.

        Credit to Jerry Zhang (GitHub: jerryzh168) for this test case
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.dtype = torch.float16

            def forward(self, x):
                x += 3
                x = x.dequantize()
                x = torch.sigmoid(x)
                dtype = self.dtype
                x = x.to(dtype)
                return x

        def pattern(x):
            x = x.dequantize()
            x = torch.sigmoid(x)
            x = x.to(torch.float16)
            return x

        def replacement(x):
            return x

        def comparison(x):
            return x + 3

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x = torch.randn(3, 4)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x)
        test_outs = traced.forward(x)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_replaces_referenced_submodules(self):
        """The rewrite updates the submodules the traced module owns.

        Submodules used only by the pattern (``sigmoid``) disappear. Those used
        by the replacement (``tanh``) become available. A name shared by both
        (``submod``) remains a ``ReLU``.
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sigmoid = torch.nn.Sigmoid()
                self.submod = torch.nn.ReLU()

            def forward(self, x):
                x = x + 1
                return self.submod(self.sigmoid(x))

        class Pattern(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sigmoid = torch.nn.Sigmoid()
                self.submod = torch.nn.ReLU()

            def forward(self, x):
                return self.submod(self.sigmoid(x))

        class Replacement(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.tanh = torch.nn.Tanh()
                self.submod = torch.nn.ReLU()

            def forward(self, x):
                return self.submod(self.tanh(x))

        class Comparison(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.tanh = torch.nn.Tanh()
                self.submod = torch.nn.ReLU()

            def forward(self, x):
                x = x + 1
                return self.submod(self.tanh(x))

        traced = symbolic_trace(M())
        comparison = Comparison()

        x = torch.randn(3, 4)

        subgraph_rewriter.replace_pattern(traced, Pattern(), Replacement())

        traced.graph.lint()

        ref_outs = comparison(x)
        test_outs = traced.forward(x)
        self.assertEqual(ref_outs, test_outs)

        # Must not raise: the replacement's submodule was copied over.
        traced.get_submodule("tanh")
        with self.assertRaisesRegex(AttributeError, "has no attribute"):
            traced.get_submodule("sigmoid")

        submod = traced.get_submodule("submod")
        self.assertEqual(type(submod), torch.nn.ReLU)

    def test_subgraph_rewriter_annotations_int(self):
        """Both tracers record ``int`` as the placeholder type.

        ``y: int = x`` is traced with ``RewritingTracer``, and
        ``annotate(x, int)`` is traced with ``symbolic_trace``.
        """

        class M1(torch.nn.Module):
            def forward(self, x):
                y: int = x
                return torch.add(x, y)

        class M2(torch.nn.Module):
            def forward(self, x):
                y = annotate(x, int)
                return torch.add(x, y)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(M1())

        module = M2()
        symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
        for n, m in zip(symbolic_traced.graph.nodes, graph.nodes):
            if n.op == "placeholder":
                # unittest assertions, unlike bare `assert`, survive `python -O`.
                self.assertEqual(n.type, int)
                self.assertEqual(m.type, int)

    def test_subgraph_rewriter_replace_consecutive_submodules(self):
        """Back-to-back matches of a single-op pattern are all replaced."""

        def f(x):
            x = torch.sigmoid(x)
            x = torch.sigmoid(x)
            return torch.sigmoid(x)

        def pattern(x):
            return torch.sigmoid(x)

        def replacement(x):
            return torch.exp(x)

        def comparison(x):
            x = torch.exp(x)
            x = torch.exp(x)
            return torch.exp(x)

        traced = symbolic_trace(f)
        comparison_fn = symbolic_trace(comparison)

        x = torch.randn(3, 4)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x)
        test_outs = traced.forward(x)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_with_overlapping_matches(self):
        """Four chained ``sigmoid``s become two ``neg``s via a two-op pattern.

        The overlapping candidate matches are resolved into two
        non-overlapping rewrites.
        """

        def f(x):
            x = torch.sigmoid(x)
            x = torch.sigmoid(x)
            x = torch.sigmoid(x)
            return torch.sigmoid(x)

        def pattern(x):
            x = torch.sigmoid(x)
            x = torch.sigmoid(x)
            return x

        def replacement(x):
            return torch.neg(x)

        def comparison(x):
            x = torch.neg(x)
            return torch.neg(x)

        traced = symbolic_trace(f)
        comparison_fn = symbolic_trace(comparison)

        x = torch.randn(3, 4)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x)
        test_outs = traced.forward(x)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_replace_with_multiple_outputs(self):
        """Multi-output patterns are rewired output-by-output to the replacement's outputs."""

        def f(x):
            y = torch.sigmoid(x)
            z = torch.relu(x)
            return y + z

        def pattern(a):
            b = torch.sigmoid(a)
            c = torch.relu(a)
            return b, c

        def replacement(x):
            return torch.exp(x), torch.abs(x)

        def comparison(x):
            y = torch.exp(x)
            z = torch.abs(x)
            return y + z

        traced = symbolic_trace(f)
        comparison_fn = symbolic_trace(comparison)

        x = torch.randn(3, 4)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x)
        test_outs = traced.forward(x)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_replace_with_duplicated_outputs(self):
        """A replacement may return the same node for several of its outputs."""

        def f(x1, x2):
            x = x1 - x2
            y = torch.sigmoid(x)
            z = torch.relu(x)
            return y + z

        def pattern(a1, a2):
            a = a1 - a2
            b = torch.sigmoid(a)
            c = torch.relu(a)
            return b, c, a

        def replacement(x1, x2):
            y1 = torch.exp(x1)
            y2 = torch.abs(x2)
            return y2, y2, y1

        def comparison(x1, x2):
            y2 = torch.abs(x2)
            return y2 + y2

        traced = symbolic_trace(f)
        comparison_fn = symbolic_trace(comparison)

        x1 = torch.randn(3, 4)
        x2 = torch.randn(3, 4)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x1, x2)
        test_outs = traced.forward(x1, x2)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_with_unused_args(self):
        """Graph inputs the pattern doesn't touch (``z``) remain as placeholders."""

        class M(torch.nn.Module):
            def forward(self, x, y, z):
                return x + y

        def pattern(x, y):
            return x + y

        def replacement(x, y):
            return x - y

        def comparison(x1, x2, x3):
            return x1 - x2

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x1 = torch.randn(3, 4)
        x2 = torch.randn(3, 4)
        x3 = torch.randn(3, 4)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()
        placeholder_nodes = [n for n in traced.graph.nodes if n.op == "placeholder"]
        self.assertEqual(len(placeholder_nodes), 3)

        ref_outs = comparison_fn(x1, x2, x3)
        test_outs = traced.forward(x1, x2, x3)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_call_method(self):
        """A pattern built from ``call_method`` nodes matches and can be replaced by the identity."""

        class M(torch.nn.Module):
            def forward(self, x):
                x = x.dequantize()
                x = x.sigmoid()
                x = x.to(torch.float16)
                return x

        def pattern(x):
            x = x.dequantize()
            x = x.sigmoid()
            x = x.to(torch.float16)
            return x

        def replacement(x):
            return x

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(replacement)

        x1 = torch.randn(3, 4)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x1)
        test_outs = traced.forward(x1)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_nodes_with_kwargs(self):
        """Pattern nodes with keyword arguments (``bias=``) match.

        The replacement's wrapped leaf function then appears in the graph.
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w0 = torch.nn.Parameter(torch.empty([128, 128]))
                self.b0 = torch.nn.Parameter(torch.empty([128]))

            def forward(self, in0):
                lin_res = torch.nn.functional.linear(in0, self.w0, bias=self.b0)
                mul_res = in0 * lin_res
                sum_res = mul_res + in0
                return sum_res

        def pattern(a, b, bias):
            lin_res = torch.nn.functional.linear(a, b, bias=bias)
            mul_res = a * lin_res
            return lin_res, mul_res

        def replacement(a, b, bias):
            lin_res, mul_res = wrapped_gemm_bias_mul(a, b, bias)
            return lin_res, mul_res

        traced = symbolic_trace(M())
        matches = subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        self.assertEqual(len(matches), 1)

        found_replacement_node = any(
            node.target == wrapped_gemm_bias_mul for node in traced.graph.nodes
        )
        self.assertTrue(found_replacement_node)

    def test_subgraph_rewriter_local_revert(self):
        """A failed candidate match must be reverted without blocking later matches."""
        # The following model has 3 anchors that are candidates for the given pattern.
        # Anchors 1 and 3 are real matches, but anchor 2 is not.
        # The subgraph rewriter should be able to revert the changes made while matching anchor 2.
        # The final match at anchor 3 should be successful.

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w0 = torch.nn.Parameter(torch.empty([128, 128]))
                self.b0 = torch.nn.Parameter(torch.empty([128]))
                self.w1 = torch.nn.Parameter(torch.empty([128, 128]))
                self.b1 = torch.nn.Parameter(torch.empty([128]))
                self.w2 = torch.nn.Parameter(torch.empty([128, 128]))
                self.b2 = torch.nn.Parameter(torch.empty([128]))
                self.w3 = torch.nn.Parameter(torch.empty([128, 128]))
                self.b3 = torch.nn.Parameter(torch.empty([128]))
                self.w4 = torch.nn.Parameter(torch.empty([128, 128]))
                self.b4 = torch.nn.Parameter(torch.empty([128]))

            def forward(self, in0, in1):
                lin_res_1 = torch.nn.functional.linear(in1, self.w0, bias=self.b0)
                lin_res_2 = torch.nn.functional.linear(lin_res_1, self.w1, bias=self.b1)
                # potential match at anchor 1
                mul_res_1 = in1 * lin_res_2
                sum_res_1 = mul_res_1 + in1
                lin_res_3 = torch.nn.functional.linear(sum_res_1, self.w2, bias=self.b2)
                sigmoid_res_1 = torch.sigmoid(lin_res_3)
                # potential match at anchor 2
                mul_res_2 = lin_res_3 * sigmoid_res_1
                lin_res_4 = torch.nn.functional.linear(in0, self.w3, bias=self.b3)
                lin_res_5 = torch.nn.functional.linear(lin_res_4, self.w4, bias=self.b4)
                # potential match at anchor 3
                mul_res_3 = in0 * lin_res_5
                sum_res_2 = mul_res_3 + in0
                cat_res = torch.cat(
                    [mul_res_2, sum_res_2],
                    dim=1,
                )
                return cat_res

        def gemm_bias_mul_pattern_with_c(a, b, bias, c):
            lin_res = torch.nn.functional.linear(a, b, bias=bias)
            mul_res = c * lin_res
            return lin_res, mul_res

        def gemm_bias_mul_replacement_with_c(a, b, bias, c):
            lin_res, mul_res = wrapped_gemm_bias_mul_with_c(a, b, bias, c)
            return lin_res, mul_res

        traced = symbolic_trace(M())
        matches = subgraph_rewriter.replace_pattern(
            traced, gemm_bias_mul_pattern_with_c, gemm_bias_mul_replacement_with_c
        )

        self.assertEqual(len(matches), 2)

        replacement_nodes_found = sum(
            1
            for node in traced.graph.nodes
            if node.target == wrapped_gemm_bias_mul_with_c
        )
        self.assertEqual(replacement_nodes_found, 2)

    def test_replace_pattern_with_filters(self):
        """``match_filters`` can reject matches.

        Without a filter both ``dequantize -> add -> relu -> quantize`` chains
        are replaced. With ``second_input_is_scalar``, only the chain whose
        ``add`` takes a Python scalar is replaced.
        """

        class M(torch.nn.Module):
            def forward(self, x, scale, zero_point):
                # Match, second input to add is a scalar
                x = x.dequantize()
                x = torch.add(x, 2)
                x = x.relu()
                x = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)

                y = x + 1
                # NOT a match, second input to add is NOT a scalar
                x = x.dequantize()
                x = torch.add(x, y)
                x = x.relu()
                x = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)

                return x

        def BinaryOpScalarReLUPattern(x, num, scale, zero_point):
            x = x.dequantize()
            x = torch.add(x, num)
            x = x.relu()
            x = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
            return x

        def BinaryOpScalarReLUReplacement(x, num, scale, zero_point):
            x = torch.mul(x, num)
            return x

        def second_input_is_scalar(match, original_graph, pattern_graph):
            """Accept the match only if the pattern's second input (``num``) is
            bound to an ``int``/``float`` literal rather than a graph node.
            """
            # Index the placeholder list directly. The previous counter loop
            # left `num_node` unbound when there were fewer than two inputs.
            placeholders = [
                node for node in pattern_graph.nodes if node.op == "placeholder"
            ]
            num_node = placeholders[1]
            return isinstance(match.nodes_map[num_node], (int, float))

        # match without filter, should find 2 matches
        traced = symbolic_trace(M())
        matches = subgraph_rewriter.replace_pattern_with_filters(
            traced, BinaryOpScalarReLUPattern, BinaryOpScalarReLUReplacement, None
        )
        self.assertEqual(len(matches), 2)
        self.assertEqual(
            self._assert_replacement_nodes_match(traced, matches, (torch.mul,)), 2
        )

        # match with filter, should find 1 match
        traced = symbolic_trace(M())
        matches = subgraph_rewriter.replace_pattern_with_filters(
            traced,
            BinaryOpScalarReLUPattern,
            BinaryOpScalarReLUReplacement,
            [second_input_is_scalar],
        )
        self.assertEqual(len(matches), 1)
        self.assertEqual(
            self._assert_replacement_nodes_match(traced, matches, (torch.mul,)), 1
        )

    def test_matching_pattern_with_list_type_arg(self):
        """Pattern placeholders can match list literals passed as op arguments."""

        class M(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten._reshape_alias_copy.default(x, [1, 2], [3, 4])

        def pattern(x, arg0, arg1):
            return torch.ops.aten._reshape_alias_copy.default(x, arg0, arg1)

        def replacement(x, arg0, arg1):
            return torch.ops.aten._reshape_alias_copy.default(x, arg1, arg0)

        traced = symbolic_trace(M())
        matches = subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        self.assertEqual(len(matches), 1)

        self.assertExpectedInline(
            traced.code.strip(),
            """\
def forward(self, x):
    _reshape_alias_copy_default_1 = torch.ops.aten._reshape_alias_copy.default(x, [3, 4], [1, 2]);  x = None
    return _reshape_alias_copy_default_1""",
        )  # noqa: B950

    def test_replacement_with_attrs(self):
        """Patterns and replacements that read tensor attributes of their modules still match."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1])
                self.b = torch.tensor([2])

            def forward(self, x):
                return x + self.a - self.b

        class Pattern(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1])

            def forward(self, x):
                return x + self.a

        class Replacement(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.c = torch.tensor([3])

            def forward(self, x):
                return x - self.c

        traced = symbolic_trace(M())
        matches = subgraph_rewriter.replace_pattern(traced, Pattern(), Replacement())
        self.assertEqual(len(matches), 1)

    def test_matching_variable_arguments(self):
        """An explicit argument in the pattern matches a call that relies on its default.

        The pattern passes ``padding=[0, 0]``; the traced call omits padding.
        """

        class M(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.max_pool2d_with_indices.default(
                    x, [2, 2], stride=[2, 2]
                )

        def pattern(x, kernel_size, stride):
            # default padding is [0, 0]
            return torch.ops.aten.max_pool2d_with_indices.default(
                x, kernel_size, stride, padding=[0, 0]
            )

        traced = symbolic_trace(M())
        matches = subgraph_rewriter.replace_pattern(traced, pattern, pattern)

        self.assertEqual(len(matches), 1)

    def test_replaced_nodes(self):
        """Each match's ``replacements`` lists exactly the nodes inserted by the replacement."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                return torch.add(x, y)

        def pattern(x, y):
            return torch.add(x, y)

        def replacement(x, y):
            return torch.sub(torch.mul(x, y), y)

        traced = symbolic_trace(M())
        matches = subgraph_rewriter.replace_pattern_with_filters(
            traced, pattern, replacement
        )

        self.assertEqual(
            self._assert_replacement_nodes_match(
                traced, matches, (torch.sub, torch.mul)
            ),
            2,
        )
983