xref: /aosp_15_r20/external/pytorch/test/inductor/test_custom_post_grad_passes.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: inductor"]
2import contextlib
3import operator
4from collections import defaultdict
5
6import torch
7import torch._inductor.pattern_matcher as pattern_matcher
8import torch.fx as fx
9from torch._dynamo.utils import counters
10from torch._inductor import config
11from torch._inductor.lowering import lowerings as L
12from torch._inductor.pattern_matcher import Arg, CallFunction, PatternMatcherPass
13from torch._inductor.test_case import run_tests, TestCase
14from torch.testing._internal.common_utils import IS_LINUX
15from torch.testing._internal.inductor_utils import HAS_CPU
16
17
18@config.patch({"freezing": True})
19class TestCustomPassBase(TestCase):
20    def _clone_inputs(self, inputs):
21        def clone(x):
22            if not isinstance(x, torch.Tensor):
23                return x
24            return x.clone()
25
26        return tuple(clone(x) for x in inputs)
27
28    def _test_common(
29        self,
30        mod,
31        inputs,
32        matcher_count,
33        matcher_nodes,
34        atol=1e-5,
35        rtol=1.3e-6,
36    ):
37        counters.clear()
38        maybe_autocast = contextlib.nullcontext()
39        with torch.no_grad(), maybe_autocast:
40            clone_inputs = self._clone_inputs(inputs)
41            expected = mod(*inputs)
42            actual = torch.compile(mod)(*clone_inputs)
43            torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
44            self.assertEqual(
45                counters["inductor"]["pattern_matcher_count"], matcher_count
46            )
47            self.assertEqual(
48                counters["inductor"]["pattern_matcher_nodes"],
49                matcher_nodes,
50            )
51
52
# Short handles for the operator namespaces referenced throughout this file.
aten = torch.ops.aten
mkldnn = torch.ops.mkldnn


def change_cos_pass(graph):
    """Retarget every ``aten.cos.default`` call in ``graph`` to ``aten.sin.default``.

    The graph is modified in place; nothing is returned.
    """

    def _is_cos_call(candidate):
        return (
            candidate.op == "call_function"
            and candidate.target == aten.cos.default
        )

    for cos_node in filter(_is_cos_call, graph.nodes):
        cos_node.target = aten.sin.default
61
62
class TestPostGradCustomPrePostPass(TestCustomPassBase):
    """Tests for user-supplied inductor passes installed via ``config.patch``.

    Covers the ``joint_custom_pre_pass`` / ``joint_custom_post_pass`` hooks,
    the ``post_grad_custom_pre_pass`` / ``post_grad_custom_post_pass`` hooks,
    and the ``pre_grad_custom_pass`` hook.
    """

    # Pattern modeled on mkldnn fusion's pattern_matcher
    # (torch/_inductor/fx_passes/mkldnn_fusion.py), applied here to the
    # custom post-grad passes.
    def _register_mkldnn_conv_relu_fusion(self, custom_pass_dict):
        """Register a conv+relu lowering pattern on ``custom_pass_dict``.

        Args:
            custom_pass_dict: the pass object handed to
                ``pattern_matcher.register_lowering_pattern`` as ``pass_dict``
                (the tests pass a ``_CustomPass`` instance).
        """
        # pattern: relu applied to a single-user _convolution_pointwise call
        # that takes ten positional arguments.
        conv_pattern = CallFunction(
            mkldnn._convolution_pointwise.default,
            *(Arg() for _ in range(10)),
            _users=1,
        )
        pattern = CallFunction(aten.relu, conv_pattern)

        def dummy_check(m):
            # Accept every structural match; no extra filtering is needed.
            return True

        @pattern_matcher.register_lowering_pattern(
            pattern, dummy_check, pass_dict=custom_pass_dict
        )
        def fn(match, *args, **kwargs):
            # Replace the last three matched arguments with "relu", [], "" and
            # lower through the existing _convolution_pointwise lowering.
            # NOTE(review): presumably these are the op's fused-activation
            # arguments -- confirm against the mkldnn op schema.
            computation_args = list(args)[:-3] + ["relu", [], ""]
            return L[mkldnn._convolution_pointwise.default](*computation_args)

    # custom post grad pass
    class _CustomPass(PatternMatcherPass):
        """``PatternMatcherPass`` that can be called directly on a graph."""

        def __call__(self, g: torch.fx.graph.Graph):
            self.apply(g)

    # case model
    class _ConvReLU(torch.nn.Module):
        """3x3 ``Conv2d`` (stride 1, padding 1) followed by ``relu``."""

        def __init__(self, ic, oc):
            super().__init__()
            self.conv = torch.nn.Conv2d(ic, oc, kernel_size=3, stride=1, padding=1)

        def forward(self, x):
            x1 = self.conv(x)
            return x1.relu()

    def _check_conv_relu_fusion(self, pass_key):
        """Run the conv+relu fusion check with a custom pass under ``pass_key``.

        Args:
            pass_key: name of the config entry that receives the
                ``_CustomPass`` instance, either
                ``"post_grad_custom_pre_pass"`` or
                ``"post_grad_custom_post_pass"``. The other hook is set to
                None and ``pattern_matcher`` is set to False.
        """
        patches = {
            # leave custom pass only in post_grad_passes()
            "pattern_matcher": False,
            "post_grad_custom_pre_pass": None,
            "post_grad_custom_post_pass": None,
        }
        # define pattern match as custom post grad opt pass
        patches[pass_key] = self._CustomPass()
        with config.patch(patches):
            # init mkldnn fusion on custom_matcher
            self._register_mkldnn_conv_relu_fusion(getattr(config, pass_key))

            mod = self._ConvReLU(16, 16).eval()
            x = torch.randn((1, 16, 56, 56), dtype=torch.float32)

            match_count = 1
            match_nodes = 2
            other_match_count = 1  # conv prepack weight
            other_match_nodes = 1  # conv prepack weight
            self._test_common(
                mod,
                (x,),
                match_count + other_match_count,
                match_nodes + other_match_nodes,
            )

    def test_custom_joint_pass_pre(self):
        """``joint_custom_pre_pass`` rewrites cos to sin in the compiled function."""
        with config.patch(joint_custom_pre_pass=change_cos_pass):

            def g(x):
                return x.sin().sin().sin()

            def f(x):
                return x.cos().cos().cos()

            x = torch.randn(8, dtype=torch.float32)
            torch.testing.assert_close(torch.compile(f)(x), g(x))

    def test_custom_joint_pass_post(self):
        """``joint_custom_post_pass`` rewrites cos to sin in the compiled function."""
        with config.patch(joint_custom_post_pass=change_cos_pass):

            def g(x):
                return x.sin().sin().sin()

            def f(x):
                return x.cos().cos().cos()

            x = torch.randn(8, dtype=torch.float32)
            torch.testing.assert_close(torch.compile(f)(x), g(x))

    def test_custom_pre_pass(self):
        """Conv+relu fusion registered on ``post_grad_custom_pre_pass``."""
        self._check_conv_relu_fusion("post_grad_custom_pre_pass")

    def test_custom_post_pass(self):
        """Conv+relu fusion registered on ``post_grad_custom_post_pass``."""
        self._check_conv_relu_fusion("post_grad_custom_post_pass")

    def test_custom_pre_grad_pass(self):
        """``pre_grad_custom_pass`` can merge several mm calls sharing a RHS."""
        saved_graph = [None]

        def merge_mm_shared_rhs(graph: fx.Graph):
            """
            Bad POC of merging mm with a shared RHS.
            i.e. [mm(x, W), mm(x2, W)] => mm(cat(x, x2), W).split()

            Isn't actually safe for a couple reasons. For example, it doesn't handle the
            case where the LHS inputs depend on each other
            """
            saved_graph[0] = graph

            # Group mm nodes by their RHS argument. Nodes are visited in graph
            # order, so each group is already ordered by graph position.
            mms_by_rhs = defaultdict(list)
            for node in graph.nodes:
                if node.target == torch.mm:
                    mms_by_rhs[node.args[1]].append(node)

            for rhs, group in mms_by_rhs.items():
                if len(group) == 1:
                    continue
                lhs_vals = [mm_node.args[0] for mm_node in group]
                with graph.inserting_before(group[0]):
                    new_cat = graph.create_node(
                        "call_function", torch.cat, args=(lhs_vals, 0)
                    )
                    new_mm = graph.create_node(
                        "call_function", torch.mm, args=(new_cat, rhs)
                    )
                    split_vals = graph.create_node(
                        "call_function",
                        torch.split,
                        args=(
                            new_mm,
                            [lhs.meta["example_value"].shape[0] for lhs in lhs_vals],
                        ),
                    )
                # Turn each original mm into a getitem on the split result so
                # its existing users keep working.
                for idx, mm_node in enumerate(group):
                    mm_node.target = operator.getitem
                    mm_node.args = (split_vals, idx)

        @config.patch(pre_grad_custom_pass=merge_mm_shared_rhs)
        def inner_test():
            @torch.compile
            def f(W, nested_seqs):
                outs = [torch.mm(s, W) for s in nested_seqs]
                return outs

            W = torch.randn(16, 16, dtype=torch.bfloat16)
            nested_seqs = [
                torch.randn(rows, 16, dtype=torch.bfloat16) for rows in [4, 8, 5, 3]
            ]

            f(W, nested_seqs)
            # unittest assertions (not bare ``assert``) so the checks survive
            # ``python -O`` and report useful failure messages.
            self.assertIsNotNone(saved_graph[0])
            remaining_mms = [n for n in saved_graph[0].nodes if n.target == torch.mm]
            self.assertEqual(len(remaining_mms), 1)

        inner_test()
264
265
if __name__ == "__main__":
    # Run only on Linux with a CPU backend and an mkldnn-enabled torch build.
    mkldnn_cpu_ready = (
        IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available()
    )
    if mkldnn_cpu_ready:
        run_tests()
269