xref: /aosp_15_r20/external/pytorch/test/test_fx_passes.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: fx.passes"]
2
3from dataclasses import dataclass
4import operator
5import logging
6import sys
7
8import torch
9from torch.fx._symbolic_trace import symbolic_trace
10
11from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
12from torch.fx.passes.operator_support import OperatorSupport
13from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
14from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
15
16from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
17from torch.testing._internal.jit_utils import JitTestCase
18
19logging.basicConfig(level=logging.WARNING)
20logger = logging.getLogger(__name__)
21
class TestModule(torch.nn.Module):
    """Small module mixing call_module (Linear), get_attr (self.param) and
    call_function (add / relu) nodes; used as the fusion target by
    test_fuser_util and test_fuser_util_xfail.

    NOTE(review): the local variable names in forward() become FX node names
    after symbolic_trace, and the partition test cases reference those names
    ("add", "add_1", ..., "linear2", "relu") -- do not rename them.
    """
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)
        self.param = torch.nn.Parameter(torch.rand(4, 4))

    def forward(self, a, b, c):
        add = a + b

        linear_1 = self.linear(add)

        add_1 = add + c
        add_2 = add_1 + self.param
        add_3 = add_1 + linear_1
        add_4 = add_2 + add_3

        linear_2 = self.linear2(add_4)

        add_5 = linear_2 + add_4
        add_6 = add_5 + a
        relu = add_6.relu()

        return add_4, add_6, relu
46
class TestDeepModule(torch.nn.Module):
    """Module whose traced graph is deeper than Python's recursion limit.

    Used by test_fuser_pass_deep_model to guard against DFS-style (recursive)
    graph traversals in passes, which would blow the interpreter stack.
    """
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, a, b, c):
        o = a + b
        o = o + 1.0

        # testing to avoid DFS uses in passes. Since Python has max recursion depth.
        for _ in range(sys.getrecursionlimit() + 1):
            o = o - c

        return o
61
62
class TestPartitionFunctions:
    """Free functions exercising CapabilityBasedPartitioner corner cases
    (vertical/horizontal fusion, branching, cyclic-dependency avoidance,
    getitem handling, bookend non-compute ops).

    NOTE(review): local variable names become FX node names after
    symbolic_trace, and the expected partitions in TestFXGraphPasses refer to
    nodes by those names -- do not rename locals here.
    """

    @staticmethod
    def forward1(a, b, c):
        add = a + b
        add_1 = add + b
        add_2 = add_1 + c
        relu_1 = add_2.relu()
        add_3 = add_1 + add_2
        add_4 = add_1 + relu_1 + add_3
        relu_2 = add_4.relu()
        add_5 = relu_2 + add_4
        add_6 = add_5 + add_4
        return add_4, add_6

    @staticmethod
    def forward2(a, b, _):
        add = a + b
        add_1 = add + b
        relu_1 = add_1.relu()  # blocked by this
        add_3 = add_1 + relu_1
        add_4 = add_1 + add_3
        return add_4, add_1

    @staticmethod
    def forward3(a, b, c):
        add = a + b
        add_1 = a + c
        add_2 = b + c
        return add, add_1, add_2

    @staticmethod
    def forward4(a, b, c):
        add = a + b
        add_1 = a + c
        add_2 = b + c
        return torch.where(add > 0, add_1, add_2)

    @staticmethod
    def forward5(a, b, c):
        # add should be fused right branch, as left branch is not supported
        add = a + 1
        # left branch
        relu = add.relu()
        # right branch
        add_1 = add + 2
        return relu, add_1

    @staticmethod
    def forward6(a, b, c):
        # add should have its own partition, as neither branch is supported
        add = a + 1
        # left branch
        relu = add.relu()
        # right branch
        relu_1 = add.relu()
        return relu, relu_1

    @staticmethod
    def forward7(a, b, c):
        # both branches are supported, all adds should be fused together
        add = a + 1
        # left branch
        add_1 = add + 2
        # right branch is larger
        add_2 = add + 1
        add_3 = add_2 + 1
        return add_3, add_1

    @staticmethod
    def forward8(a, b, c):
        # both branches are in the same partition, add should join the same partition
        add = a + 1
        # left branch
        add_1 = add + 2
        # right branch
        add_2 = add + 1
        # left and right branch merges
        add_3 = add_2 + add_1

        return add_3

    @staticmethod
    def forward9(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add + 1
        # branch 2
        add_2 = add + 1
        # branch_3
        add_3 = add + 1
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward10(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add + 1
        # branch 2
        add_2 = add + 1
        # branch 3: depends on branch 2
        add_3 = add + add_2
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward11(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add.relu()
        # branch 2 depends on branch 1
        add_2 = add + add_1
        # branch 3
        add_3 = add.relu()
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward12(a, b, c):
        b0 = a + 1.0
        c0 = a + 1.5
        x0 = b0.relu()
        x1 = c0.relu()
        b1 = b0 + x1
        c1 = c0 + 1.2
        # c2 has dependency on x0 & b0, when we merge {c0, c1, c2}
        # this dependency should be updated to the fusion group and reflected
        # on the decision to not fuse b0 & b1, which forms a cyclic dependency in
        # the new graph
        c2 = x0 + c0
        return b1, c2

    @staticmethod
    def forward13(a, b, c):
        a0, a1, a2, a3 = a.split(1, 0)
        b1 = a0 + b
        c1 = a1 + c
        return b1 + c1

    @staticmethod
    def forward14(a, b, c):
        a0, a1 = torch.ops.aten.std_mean(a)
        out = a0 + 1.0
        return out

    @staticmethod
    def forward15(a, b, c):
        a0 = torch.ops.aten.view(a, [2, 2])
        a1 = torch.ops.aten.permute(a0, [1, 0])
        a2 = a1 + 1.0
        a3 = torch.ops.aten.permute(a2, [1, 0])
        a4 = a3 + 1.0
        a5 = torch.ops.aten.permute(a4, [1, 0])
        return torch.ops.aten.permute(a5, [1, 0])

    @staticmethod
    def forward16(a, b, c):
        a0 = a - 1.0
        a1 = torch.ops.aten.view(a0, [2, 2])
        a2 = torch.ops.aten.permute(a1, [1, 0])
        a3 = a2 + 1.0
        a4 = torch.ops.aten.permute(a3, [1, 0])
        a5 = a4 + 1.0
        a6 = torch.ops.aten.permute(a5, [1, 0])
        a7 = torch.ops.aten.permute(a6, [1, 0])
        return a7 - 1.0

    @staticmethod
    def forward17(a, b, c, d, e, f):
        a0 = a + b
        a1 = c + d
        a2 = e + f
        return a0, a1, a2

    @staticmethod
    def forward18(a, b, c):
        # var_mean is NOT in MockOperatorSupport's allowlist (std_mean is),
        # so no partition should be proposed here.
        a0, a1 = torch.ops.aten.var_mean(a)
        return a0
241
# A mock OperatorSupport class with a small fixed allowlist of call_function
# targets (add, getitem, and aten view / permute / std_mean).
class MockOperatorSupport(OperatorSupport):
    """Reports a node as supported iff it is a ``call_function`` whose target
    is in the allowlist below; everything else (call_module, get_attr,
    placeholders, outputs, other functions) is unsupported."""

    # Allowlist built once at class-creation time; targets are hashable
    # (operator builtins and aten OpOverloadPackets).
    _SUPPORTED_TARGETS = frozenset({
        operator.add,
        operator.getitem,
        torch.ops.aten.view,
        torch.ops.aten.permute,
        torch.ops.aten.std_mean,
    })

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        # Guard clause: only call_function nodes can ever be supported.
        if node.op != "call_function":
            return False
        return node.target in self._SUPPORTED_TARGETS
250
@instantiate_parametrized_tests
class TestFXGraphPasses(JitTestCase):
    """Tests for CapabilityBasedPartitioner and fuse_by_partitions: proposed
    partitions must match the expected node-name groups, and the fused graph
    must be numerically equivalent to the original function."""

    @parametrize("fn, expected_partition, bookend_non_compute_pass", [
        (TestPartitionFunctions.forward1, [["add_7", "add_6"], ["add_5", "add_4", "add_3"], ["add_2", "add_1", "add"]], False),
        (TestPartitionFunctions.forward2, [["add_3", "add_2"], ["add_1", "add"]], False),

        # 1 horizontal fusion with common producer
        (TestPartitionFunctions.forward3, [["add_2", "add_1", "add"]], False),
        (TestPartitionFunctions.forward4, [["add_2", "add_1", "add"]], False),

        # 2 branches cases
        (TestPartitionFunctions.forward5, [["add_1", "add"]], False),
        (TestPartitionFunctions.forward6, [["add"]], False),
        (TestPartitionFunctions.forward7, [["add_3", "add_2", "add", "add_1"]], False),
        (TestPartitionFunctions.forward8, [["add_3", "add_2", "add", "add_1"]], False),

        # 3 branch cases
        (TestPartitionFunctions.forward9, [['add_3', 'add_2', 'add_1', 'add']], False),
        (TestPartitionFunctions.forward10, [['add_3', 'add_2', 'add', 'add_1']], False),
        (TestPartitionFunctions.forward11, [['add_1'], ['add']], False),

        # 4 not necessarily the only partition, just to verify that there's no cyclic dependency after partition
        (TestPartitionFunctions.forward12, [["add_2", "add_3", "add_4"], ["add", "add_1"]], False),

        # 5 getitem special case
        (TestPartitionFunctions.forward13, [["add_2", "add_1", "add"]], False),
        (TestPartitionFunctions.forward14, [["add", "std_mean", "getitem", "getitem_1"]], False),

        # 6 bookend non_compute pass
        (TestPartitionFunctions.forward15, [["permute_1", "add_1", "add"]], True),
        (TestPartitionFunctions.forward15, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False),
        (TestPartitionFunctions.forward16, [["permute_1", "add_1", "add"]], True),
        (TestPartitionFunctions.forward16, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False),
        # should be empty partition, not a partition with empty nodes
        (TestPartitionFunctions.forward18, [], False),
    ])
    def test_partitioner(self, fn, expected_partition, bookend_non_compute_pass):
        """Partition `fn` with MockOperatorSupport, check the proposed
        partitions (as sets of node names, order-insensitive within a
        partition), then fuse and compare against eager execution."""
        traced = symbolic_trace(fn)

        non_compute_ops = []
        if bookend_non_compute_pass:
            non_compute_ops = ["torch.ops.aten.view", "torch.ops.aten.permute"]

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced,
                                                 supported_ops,
                                                 allows_single_node_partition=True,
                                                 non_compute_ops=non_compute_ops)
        partitions = partitioner.propose_partitions()
        if bookend_non_compute_pass:
            # strip leading/trailing non-compute ops (view/permute) in place
            partitioner.remove_bookend_non_compute_ops(partitions)

        partitions_name = [[node.name for node in partition.nodes] for partition in partitions]
        assert len(partitions_name) == len(expected_partition)
        for i in range(len(partitions_name)):
            assert set(partitions_name[i]) == set(expected_partition[i])

        fused_graph = partitioner.fuse_partitions(partitions)

        a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)

        expected = fn(a, b, c)
        result = fused_graph(a, b, c)
        torch.testing.assert_close(expected, result)

    @parametrize("fn, expected_partition", [
        (TestPartitionFunctions.forward17, [['add', 'add_1', 'add_2']]),
    ])
    def test_partitioner_independent_output(self, fn, expected_partition):
        """Adds with disjoint inputs/outputs should still land in one
        partition; fused graph must stay numerically equivalent."""
        traced = symbolic_trace(fn)

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced,
                                                 supported_ops,
                                                 allows_single_node_partition=True)
        partitions = partitioner.propose_partitions()
        partitions_name = [[node.name for node in partition.nodes] for partition in partitions]
        assert len(partitions_name) == len(expected_partition)
        for i in range(len(partitions_name)):
            assert set(partitions_name[i]) == set(expected_partition[i])

        fused_graph = partitioner.fuse_partitions(partitions)

        a, b, c, d, e, f = torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4)

        expected = fn(a, b, c, d, e, f)
        result = fused_graph(a, b, c, d, e, f)
        torch.testing.assert_close(expected, result)

    @parametrize("partition", [
        [['add', 'add_1'], ['add_5', 'add_6']],
        [['add', 'add_1', 'add_2']],  # vertical fusion
        [['add_2', 'add_3']],         # horizontal fusion
        [['add_3', 'add_4']],
        [['add_6', 'add_5']],     # arbitrary node order
        [['add_4', 'add_1', 'add_3', 'add_2']],           # arbitrary node order
        [['add_5', 'add_6'], ['add_1', 'add_2', 'add_3', 'add_4']],  # arbitrary partition order
        [['add_5', 'linear2']],   # includes call_function + call_module node
        [['add_6', 'relu']],   # includes call_function + call_module node
        [['param', 'add_2']],   # includes get_attr + call_module nodes
        [['param', 'add_1', 'linear']],   # includes get_attr + call_function + call_module nodes
        [["add", "linear", "add_1", "param", "add_2", "add_3", "add_4", "linear2", "add_5", "add_6", "relu"]],  # full graph
    ])
    def test_fuser_util(self, partition):
        """fuse_by_partitions on hand-picked legal partitions of TestModule
        must preserve the module's numerical behavior."""
        m = TestModule()
        gm = symbolic_trace(m)

        nodes_by_name = {node.name : node for node in gm.graph.nodes}

        partitions = []
        for node_names in partition:
            partitions.append([nodes_by_name[name] for name in node_names])

        fused_graph = fuse_by_partitions(gm, partitions)

        a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)

        expected = m(a, b, c)
        result = fused_graph(a, b, c)

        torch.testing.assert_close(expected, result)

    @parametrize("partition", [
        [['add', 'add_1'], ['add_1', 'add_5', 'add_6']],  # add_1 exists in multiple partitions
        [['add', 'add_1', 'add_3']],    # invalid partition: circular dependency
        [['add_4', 'add_5']],    # invalid partition: circular dependency
        [['relu', 'add_5']],    # invalid partition: circular dependency
    ])
    def test_fuser_util_xfail(self, partition):
        """Illegal partitions (overlapping or cycle-inducing) must raise."""
        m = TestModule()
        gm = symbolic_trace(m)

        nodes_by_name = {node.name : node for node in gm.graph.nodes}

        partitions = []
        for node_names in partition:
            partitions.append([nodes_by_name[name] for name in node_names])

        with self.assertRaises(Exception):
            fuse_by_partitions(gm, partitions)

    def test_fuser_pass_deep_model(self):
        """Smoke test: partitioning a graph deeper than the recursion limit
        must not crash (guards against recursive DFS in the partitioner)."""
        m = TestDeepModule()
        traced = symbolic_trace(m)

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced,
                                                 supported_ops,
                                                 allows_single_node_partition=True)
        # result intentionally unused -- only checking it completes
        partitions = partitioner.propose_partitions()
402
@dataclass
class TestCase:
    """One SubgraphMatcher configuration plus its expected match count.

    NOTE(review): despite the name, this is a plain parameter record, not a
    unittest.TestCase.
    """
    match_output: bool            # whether pattern output nodes must match
    match_placeholder: bool       # whether pattern placeholders must match
    num_matches: int              # expected number of matches found
    remove_overlapping_matches: bool = True  # dedup overlapping matches
409
class SingleNodePattern:
    """Single-op pattern (neg) matched inside a two-op graph."""
    @staticmethod
    def forward(x):
        val = torch.neg(x)
        return torch.add(val, val)

    @staticmethod
    def pattern(a):
        return torch.neg(a)

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 0),
        TestCase(False, True, 1),
        TestCase(True, True, 0)
    ]
class SimplePattern:
    """cat+sum pattern occurring three times, twice directly on placeholders."""
    @staticmethod
    def forward(x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w2, w1]).sum()
        m3 = torch.cat([m1, m2]).sum()
        return x + torch.max(m1) + torch.max(m2) + m3

    @staticmethod
    def pattern(a, b):
        return torch.cat([a, b]).sum()

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 3),
        TestCase(True, False, 0),
        TestCase(False, True, 2),
        TestCase(True, True, 0)
    ]
446
class SimpleFullGraphMatching:
    """Pattern identical to the whole target graph; matches under every
    match_output/match_placeholder combination."""
    @staticmethod
    def forward(x):
        a = torch.neg(x)
        return torch.add(a, a)

    @staticmethod
    def pattern(x):
        a = torch.neg(x)
        return torch.add(a, a)

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 1),
        TestCase(True, True, 1)
    ]
465
class DiamondShapePatternTestCase:
    """Diamond-shaped pattern (one producer feeding two branches that merge)."""
    @staticmethod
    def forward(x):
        a = torch.neg(x)

        a = a.relu()
        left = a.sigmoid()
        right = a.relu()
        out = left + right

        return out

    @staticmethod
    def pattern(a):
        a = a.relu()
        left = a.sigmoid()
        right = a.relu()
        out = left + right
        return out

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]
493
class NonFullyContainedMatches:
    """A match whose intermediate (m3) leaks outside the subgraph must be
    rejected; a fully contained copy of the same pattern must still match."""
    @staticmethod
    def forward(x, w1, w2, b1, b2):
        # fully contained matched subgraph
        m1 = torch.cat([w1, w2])
        m2 = torch.cat([x, b2])
        t0 = torch.addmm(b1, m1, m2.t())
        t0_sum = torch.sum(t0)   # use of t0 is not leaking

        # leaking matched subgraph, m3 is leaked
        m3 = torch.cat([w1, w2])
        m4 = torch.cat([x, b2])
        t1 = torch.addmm(b1, m3, m4.t())
        m3_sum = torch.sum(m3)

        return t0_sum, m3_sum

    @staticmethod
    def pattern(x, w1, w2, b1, b2):
        m1 = torch.cat([w1, w2])
        m2 = torch.cat([x, b2])
        return torch.addmm(b1, m1, m2.t())

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),

        TestCase(True, False, 0),

        TestCase(False, True, 1),     # leaked use of a placeholder is not leaking
    ]
525
class ChainRepeatedPattern:
    """sigmoid(sigmoid(x)) pattern in a chain of four sigmoids: 3 overlapping
    matches, 2 after overlap removal."""
    @staticmethod
    def forward(x):
        x = torch.sigmoid(x)
        x = torch.sigmoid(x)
        x = torch.sigmoid(x)
        return torch.sigmoid(x)

    @staticmethod
    def pattern(x):
        return torch.sigmoid(torch.sigmoid(x))

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 3, remove_overlapping_matches=False),
        TestCase(False, False, 2, remove_overlapping_matches=True),
        TestCase(True, False, 1),
        TestCase(False, True, 1),
        TestCase(True, True, 0)
    ]
546
class QuantizationModel:
    """dequantize -> sigmoid -> to(float16) pattern preceded by an extra add
    in the target, so placeholder matching fails."""
    @staticmethod
    def forward(x):
        x += 3
        x = x.dequantize()
        x = torch.sigmoid(x)
        x = x.to(torch.float16)
        return x

    @staticmethod
    def pattern(x):
        x = x.dequantize()
        x = torch.sigmoid(x)
        x = x.to(torch.float16)
        return x

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]
570
class MultipleOutputsWithDependency:
    """Multi-output pattern where one output feeds the other."""
    @staticmethod
    def forward(x):
        y = x.relu()
        z = y.sigmoid()
        return z, y

    @staticmethod
    def pattern(a):
        b = a.relu()
        c = b.sigmoid()
        return b, c     # outputs have data dependency

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 0),
        TestCase(False, True, 1),
        TestCase(True, True, 0)
    ]
591
class MultipleOutputsWithoutDependency:
    """Multi-output pattern whose outputs are independent siblings of the
    same relu."""
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        z = x.sum()
        y = x.sigmoid()

        out = y.sigmoid() + z.sum()
        return out

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        c = a.sum()
        return b, c

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 0),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]
619
class MultipleOutputsMultipleOverlappingMatches:
    """Duplicate sum/sigmoid consumers produce 4 overlapping matches,
    collapsing to 1 after overlap removal."""
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        z = x.sum()
        z1 = x.sum()
        y = x.sigmoid()
        y1 = x.sigmoid()

        return z + z1 + y + y1

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        c = a.sum()
        return a, b, c

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 4, remove_overlapping_matches=False),
        TestCase(False, False, 1, remove_overlapping_matches=True),
    ]
646
class MultipleOutputsMultipleNonOverlappingMatches:
    """Two back-to-back copies of the relu/sigmoid/sum subgraph.

    NOTE(review): only one match is expected here per the checked-in
    expectation; confirm against SubgraphMatcher semantics before changing.
    """
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        z = x.sum()
        y = x.sigmoid()

        x = x.relu()
        z1 = x.sum()
        y1 = x.sigmoid()

        return z + z1 + y + y1

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        c = a.sum()
        return b, c

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
    ]
674
class MultipleOutputsIdenticalAnchor:
    """Pattern with two identical anchor nodes (two sigmoids of the same
    producer)."""
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        y = x.sigmoid()
        y1 = x.sigmoid()

        return y, y1

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        b1 = a.sigmoid()
        return b, b1

    test_cases = [
        # match_output, match_placeholder, num_matches
        # (False, False, 2),  # FIXME: currently still matches to 2, should fix to 1
        TestCase(True, False, 1),
        TestCase(False, True, 0),
    ]
700
701
class MultipleOutputsHorizontalPattern:
    """Two independent outputs (relu and sigmoid) of the same placeholder."""
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        y1 = x.relu()
        y2 = x.sigmoid()

        return y1, y2

    @staticmethod
    def pattern(a):
        b1 = a.relu()
        b2 = a.sigmoid()

        return b1, b2

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]
727
class MultiOutputWithWithInvalidMatches:
    """Structurally similar graph that must NOT match the linear+mul pattern
    (the mul's second operand is derived from lin_res, not a free input)."""
    @staticmethod
    def forward(x):
        res0 = torch.nn.functional.linear(x, torch.rand(3, 3))
        res1 = torch.sigmoid(res0)
        res2 = res0 * res1
        res3 = torch.sum(res2, dim=1)
        return res3

    @staticmethod
    def pattern(a, b, c):
        lin_res = torch.nn.functional.linear(a, b)
        mul_res = lin_res * c
        return lin_res, mul_res

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 0),
        TestCase(True, False, 0),
        TestCase(False, True, 0),
    ]
749
class QuantizationFp8Pattern:
    """Dequant/add/quant pattern over custom fp8 library ops; setup()/tearDown()
    register and drop the ops and are invoked by test_subgraph_matcher."""
    @classmethod
    def setup(cls):
        # registers the custom ops used by forward/pattern below
        cls.quantization = torch.library.Library("fp8_quantization", "DEF")  # noqa: TOR901
        cls.quantization.define("quantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")
        cls.quantization.define("dequantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")

    @classmethod
    def tearDown(cls):
        del cls.quantization

    @staticmethod
    def forward(self, arg0_1, arg1_1):
        # traced as a free function: `self` becomes a placeholder and
        # `self._scale_0` an attribute access on it
        qt = torch.ops.fp8_quantization
        _scale_0 = self._scale_0
        quantize_per_tensor_affine_fp8 = qt.quantize_per_tensor_affine_fp8(arg0_1, 0, _scale_0)
        dequantize_per_tensor_affine_fp8 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8, 0, _scale_0)
        _scale_1 = self._scale_0
        quantize_per_tensor_affine_fp8_1 = qt.quantize_per_tensor_affine_fp8(arg1_1, 0, _scale_1)
        dequantize_per_tensor_affine_fp8_1 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_1, 0, _scale_1)
        add = torch.ops.aten.add.Tensor(dequantize_per_tensor_affine_fp8, dequantize_per_tensor_affine_fp8_1)
        _scale_2 = self._scale_0
        quantize_per_tensor_affine_fp8_2 = qt.quantize_per_tensor_affine_fp8(add, 0, _scale_2)
        dequantize_per_tensor_affine_fp8_2 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_2, 0, _scale_2)
        return dequantize_per_tensor_affine_fp8_2

    @staticmethod
    def pattern(a, a_dtype, a_scale, b, b_dtype, b_scale, out_scale):
        qt = torch.ops.fp8_quantization
        a = qt.dequantize_per_tensor_affine_fp8(a, a_dtype, a_scale)
        b = qt.dequantize_per_tensor_affine_fp8(b, b_dtype, b_scale)
        output = torch.ops.aten.add.Tensor(a, b)

        # NOTE(review): bare attribute access below -- creates no FX node
        # during tracing; presumably leftover, confirm before removing.
        qt.dequantize_per_tensor_affine_fp8

        output = qt.quantize_per_tensor_affine_fp8(output, a_dtype, out_scale)
        return output

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
    ]
792
class NoAnchorFound:
    # This test case is for a pattern where no matching anchor is found in the target graph.
    # `anchor` is the starting point of the pattern matching; it's usually the boundary returning nodes.
    @staticmethod
    def forward(x):
        x = x + 1
        return x

    @staticmethod
    def pattern(a):
        b1 = a.relu()
        return b1

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 0),
        TestCase(True, False, 0),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]
813
@instantiate_parametrized_tests
class TestFXMatcherUtils(JitTestCase):
    """Runs SubgraphMatcher over every pattern class above, checking the
    match count and that every non-skipped pattern node appears in each
    match's nodes_map."""

    @parametrize("test_model", [
        SingleNodePattern,
        SimplePattern,
        SimpleFullGraphMatching,
        DiamondShapePatternTestCase,
        NonFullyContainedMatches,
        ChainRepeatedPattern,
        QuantizationModel,
        MultipleOutputsWithDependency,
        MultipleOutputsWithoutDependency,
        MultipleOutputsMultipleOverlappingMatches,
        MultipleOutputsMultipleNonOverlappingMatches,
        MultipleOutputsIdenticalAnchor,
        MultipleOutputsHorizontalPattern,
        MultiOutputWithWithInvalidMatches,
        QuantizationFp8Pattern,
        NoAnchorFound,
    ])
    def test_subgraph_matcher(self, test_model):
        """Match `test_model.pattern` against `test_model.forward` for every
        configured TestCase and validate the results."""

        # Optional per-model setup hook (e.g. QuantizationFp8Pattern registers
        # its custom library ops before tracing).
        setup = getattr(test_model, "setup", None)
        if callable(setup):
            setup()

        traced = symbolic_trace(test_model.forward)
        pattern_traced = symbolic_trace(test_model.pattern)

        for test_case in test_model.test_cases:

            matcher = SubgraphMatcher(pattern_traced.graph,
                                      match_output=test_case.match_output,
                                      match_placeholder=test_case.match_placeholder,
                                      remove_overlapping_matches=test_case.remove_overlapping_matches)
            matches = matcher.match(traced.graph)

            assert len(matches) == test_case.num_matches

            for match in matches:
                for node in pattern_traced.graph.nodes:
                    # placeholders/outputs are only part of a match when the
                    # corresponding match_* flag is set
                    if not test_case.match_placeholder and node.op == "placeholder":
                        continue
                    if not test_case.match_output and node.op == "output":
                        continue
                    assert node in match.nodes_map

        tearDown = getattr(test_model, "tearDown", None)
        # BUGFIX: this previously guarded on `callable(setup)`, which would
        # invoke tearDown() (possibly None) whenever the model defined setup,
        # and would skip teardown for a model defining only tearDown.
        if callable(tearDown):
            tearDown()
865
866
867if __name__ == "__main__":
868    run_tests()
869