# xref: /aosp_15_r20/external/pytorch/test/inductor/test_fx_fusion.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: inductor"]
2from typing import Any, Callable
3
4import torch
5from torch._inductor.fx_passes.pre_grad import (
6    linear_permute_fusion,
7    linear_transpose,
8    permute_linear_fusion,
9    permute_matmul_fusion,
10    sink_cat_after_pointwise,
11    transpose_linear,
12    transpose_matmul,
13)
14from torch._inductor.test_case import run_tests, TestCase
15from torch.fx.passes.shape_prop import ShapeProp
16
17
# Callable type: accepts a GraphModule plus one extra argument of any type
# and produces a GraphModule.
PassFunc = Callable[
    [torch.fx.GraphModule, Any],
    torch.fx.GraphModule,
]
19
20
def chain_passes(
    *passes: Callable[[Any], torch.fx.GraphModule]
) -> Callable[[Any, Any], torch.fx.GraphModule]:
    """Compose ``passes`` into one callable that applies them left to right.

    Args:
        *passes: Single-argument callables, each invoked as ``pass_(module)``
            and expected to return a ``torch.fx.GraphModule`` (for example
            ``torch.fx.symbolic_trace`` followed by a graph-rewriting pass).

    Returns:
        ``parent_pass(module, input)``. ``module`` is the object handed to the
        first pass; it may be a plain function or ``nn.Module`` that the first
        pass traces. ``input`` is a sequence of example arguments. Before each
        pass, if the current object is already a ``torch.fx.GraphModule``,
        ``ShapeProp`` is run on it with ``*input``. ShapeProp records tensor
        metadata in ``node.meta``, which the next pass may consult. The result
        of the last pass is returned.
    """

    def parent_pass(module: Any, input: Any) -> torch.fx.GraphModule:
        for pass_ in passes:
            # The first object is typically not traced yet (e.g. a Python
            # function); shape propagation only applies once it is a
            # GraphModule.
            if isinstance(module, torch.fx.GraphModule):
                # ``input`` is unpacked positionally, so it should be a
                # sequence of example args (e.g. a list of tensors), not a
                # bare tensor.
                ShapeProp(module).propagate(*input)
            module = pass_(module)
        return module

    return parent_pass
30
31
def count_call(module: torch.fx.GraphModule, op: str, target_op: Any) -> int:
    """Count nodes in ``module.graph`` whose opcode and target both match.

    Args:
        module: Traced module whose ``graph.nodes`` are scanned.
        op: FX node opcode to match, e.g. ``"call_function"`` or
            ``"call_method"``.
        target_op: Value compared with ``==`` against each node's ``target``
            (a callable for ``call_function`` nodes, a method-name string for
            ``call_method`` nodes).

    Returns:
        The number of matching nodes (0 when none match).
    """
    return sum(
        1 for node in module.graph.nodes if node.op == op and node.target == target_op
    )
36
37
def count_call_function(module: torch.fx.GraphModule, target_op: Any) -> int:
    """Return how many ``call_function`` nodes in ``module`` target ``target_op``."""
    return count_call(module, op="call_function", target_op=target_op)
40
41
def count_call_method(module: torch.fx.GraphModule, target_op: Any) -> int:
    """Return how many ``call_method`` nodes in ``module`` target ``target_op``."""
    return count_call(module, op="call_method", target_op=target_op)
44
45
class TestFxFusion(TestCase):
    """Tests for the pre-grad FX rewrite passes.

    Each test does three things:

    - traces a small function or module with ``torch.fx.symbolic_trace``;
    - applies one rewrite pass through ``chain_passes``;
    - checks the node counts of the rewritten graph, and that it agrees
      numerically with eager execution.
    """

    def _assert_fused_once(self, gm, removed_op, fused_op):
        """Check ``removed_op`` is gone from ``gm`` and ``fused_op`` appears once."""
        self.assertEqual(count_call_function(gm, removed_op), 0)
        self.assertEqual(count_call_function(gm, fused_op), 1)

    def test_sink_cat_after_pointwise(self):
        def test_kwarg(x, y):
            return torch.cat([x, y], dim=-1).view(-1).view(128).tanh()

        def test_arg(x, y):
            return torch.cat([x, y], -1).view(-1).view(128).tanh()

        def test_arg2(x, y):
            return torch.cat([x, y]).view(-1).view(128).tanh()

        def test_kwarg2(x, y):
            return torch.cat(tensors=[x, y], dim=0).tanh()

        def test_kwarg3(x, y):
            return torch.cat(tensors=[x, y], dim=0).view(128).tanh()

        run_pass = chain_passes(torch.fx.symbolic_trace, sink_cat_after_pointwise)
        example_inputs = [torch.randn(8, 8), torch.randn(8, 8)]
        cases = (test_kwarg, test_arg, test_arg2, test_kwarg2, test_kwarg3)
        for fn in cases:
            gm = run_pass(fn, example_inputs)
            torch.testing.assert_close(fn(*example_inputs), gm(*example_inputs))
            # After sinking the cat, one tanh per cat input is expected
            # (two inputs here).
            self.assertEqual(count_call_method(gm, "tanh"), 2)

    def test_linear_permute_fusion(self):
        class LinearThenPermute(torch.nn.Module):
            def __init__(self, k: int, n: int, has_bias: bool):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.randn(n, k))
                self.has_bias = has_bias
                if has_bias:
                    self.bias = torch.nn.Parameter(torch.randn(n))

            def forward(self, input: torch.Tensor):
                # ``has_bias`` is a plain Python bool, so tracing follows
                # exactly one branch and records a single linear call.
                if not self.has_bias:
                    projected = torch.nn.functional.linear(input, self.weight)
                else:
                    projected = torch.nn.functional.linear(
                        input, self.weight, self.bias
                    )
                return projected.permute(0, 2, 1)

        m, k, n = 16, 8, 4
        run_pass = chain_passes(torch.fx.symbolic_trace, linear_permute_fusion)
        for has_bias in (True, False):
            model = LinearThenPermute(k, n, has_bias).eval()
            x = torch.randn(6, m, k)
            gm = run_pass(model, [x])
            # linear followed by permute should become one linear_transpose.
            self._assert_fused_once(gm, torch.nn.functional.linear, linear_transpose)
            torch.testing.assert_close(model(x), gm(x))

    def test_permute_linear_fusion(self):
        class PermuteThenLinear(torch.nn.Module):
            def __init__(self, k: int, n: int, has_bias: bool):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.randn(n, k))
                self.has_bias = has_bias
                if has_bias:
                    self.bias = torch.nn.Parameter(torch.randn(n))

            def forward(self, input: torch.Tensor):
                swapped = input.permute(0, 2, 1)
                if not self.has_bias:
                    return torch.nn.functional.linear(swapped, self.weight)
                return torch.nn.functional.linear(swapped, self.weight, self.bias)

        m, k, n = 16, 8, 4
        run_pass = chain_passes(torch.fx.symbolic_trace, permute_linear_fusion)
        for has_bias in (True, False):
            model = PermuteThenLinear(k, n, has_bias).eval()
            x = torch.randn(6, k, m)
            gm = run_pass(model, [x])
            # permute followed by linear should become one transpose_linear.
            self._assert_fused_once(gm, torch.nn.functional.linear, transpose_linear)
            torch.testing.assert_close(model(x), gm(x))

    def test_permute_bmm_fusion(self):
        class PermuteThenBmm(torch.nn.Module):
            def __init__(self, batch: int, k: int, n: int):
                super().__init__()
                self.other = torch.randn(batch, k, n)

            def forward(self, input: torch.Tensor):
                return torch.bmm(input.permute(0, 2, 1), self.other)

        batch, m, k, n = 6, 16, 8, 4
        run_pass = chain_passes(torch.fx.symbolic_trace, permute_matmul_fusion)
        model = PermuteThenBmm(batch, k, n).eval()
        x = torch.randn(batch, k, m)
        gm = run_pass(model, [x])
        # permute feeding bmm should become one transpose_matmul.
        self._assert_fused_once(gm, torch.bmm, transpose_matmul)
        torch.testing.assert_close(model(x), gm(x))
155
156
if __name__ == "__main__":
    # Executed as a script: hand control to the inductor test harness runner.
    run_tests()
159