xref: /aosp_15_r20/external/pytorch/test/test_mkldnn_fusion.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: mkldnn"]
2import itertools
3import unittest
4from typing import NamedTuple, List
5
6import torch
7from torch import nn
8
9from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
10from torch.testing._internal.jit_utils import JitTestCase
11
12from test_tensorexpr import warmup_and_run_forward
13
# Graph node kind that the tests below count to decide whether a fused
# subgraph was produced.
FUSION_GROUP = "prim::TensorExprGroup"
15
class PointwisePostOp(NamedTuple):
    """Describe one pointwise post-op exercised by the fused-op tests.

    The tests run ``pointwise_module`` in eager mode to get a reference result
    and forward ``attr``, ``scalars`` and ``algorithm`` unchanged to the
    ``torch.ops.mkldnn.*_pointwise`` call under test.
    """

    # Post-op name handed to the fused op, e.g. "relu" or "gelu".
    attr: str
    # Eager module computing the same pointwise function.
    pointwise_module: nn.Module
    # Extra scalar arguments for the post-op (e.g. the leaky_relu slope).
    # NOTE: this default list is one object shared by every instance that does
    # not pass ``scalars`` explicitly; treat it as read-only.
    scalars: List = []
    # Variant selector; in this file only the GELU entries set it.
    algorithm: str = ""
21
# Number of spatial dimensions -> module class built by the fused-op tests.
CONV_MODULES = {2: nn.Conv2d, 3: nn.Conv3d}
CONV_TRANSPOSE_MODULES = {2: nn.ConvTranspose2d}
24
25@skipIfTorchDynamo("too slow")
26@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled")
27class TestMkldnnFusion(JitTestCase):
28    def assertFused(self, graph, fused_patterns):
29        for pat in fused_patterns:
30            self.assertGraphContainsExactly(graph, pat, 0)
31
32    def _check_model(self, m, x, trace=False):
33        old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
34        torch._C._debug_set_fusion_group_inlining(False)
35
36        old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
37        torch._C._jit_override_can_fuse_on_cpu(True)
38
39        old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
40        torch._C._jit_set_te_must_use_llvm_cpu(False)
41
42        m.eval()
43        with torch.no_grad():
44            if trace:
45                script = torch.jit.trace(m, x)
46            else:
47                script = torch.jit.script(m)
48        script = torch.jit.freeze(script)
49
50        with torch.no_grad():
51            y = warmup_and_run_forward(script, x)
52            y = script(x)
53            y_ref = m(x)
54
55            graph = script.graph_for(*x)
56            self.assertEqual(y, y_ref)
57
58        torch._C._debug_set_fusion_group_inlining(old_fusion_inlining)
59        torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuser_state)
60        torch._C._jit_set_te_must_use_llvm_cpu(old_te_must_use_llvm_cpu)
61        return graph
62
63    def test_single_conv(self):
64        class M(nn.Module):
65            def __init__(self, in_channels, out_channels, bias, **kwargs):
66                super().__init__()
67                self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
68
69            def forward(self, x):
70                res = self.conv(x)
71                return res
72
73        for memory_format, enabled in [
74            [torch.contiguous_format, False],
75            [torch.channels_last, True],
76        ]:
77            for trace in [True, False]:
78                input_size = 224
79                batch_size = 1
80                kernel_size = 3
81                options = itertools.product([True, False], [1, 2], [1, 4])
82                for bias, dilation, groups in options:
83                    iC = 3 * groups
84                    oC = 10 * groups
85                    m = M(iC,
86                          oC,
87                          bias,
88                          kernel_size=(kernel_size, kernel_size),
89                          stride=2,
90                          padding=1,
91                          dilation=dilation,
92                          groups=groups).to(memory_format=memory_format)
93                    x = torch.randn(batch_size, iC, input_size, input_size).to(memory_format=memory_format)
94                    graph = self._check_model(m, x, trace)
95                    conv_node_name = 'aten::_convolution' if trace else 'aten::conv2d'
96                    if enabled:
97                        self.assertFused(graph, [conv_node_name])
98                        self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
99                    else:
100                        self.assertGraphContains(graph, kind=conv_node_name)
101
102    def test_conv_unary_fusion_nnc(self):
103        class M(nn.Module):
104            def __init__(self, unary_fn, in_channels, out_channels, bias, **kwargs):
105                super().__init__()
106                self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
107                self.unary = unary_fn
108
109            def forward(self, x):
110                x = self.conv(x)
111                x = self.unary(x)
112                return x
113
114        for memory_format, enabled in [
115            [torch.contiguous_format, False],
116            [torch.channels_last, True],
117        ]:
118            for unary_fn in [torch.relu]:
119                for bias in [True, False]:
120                    for oC in [1, 10]:
121                        m = M(unary_fn, 3, oC, bias, kernel_size=(3, 3)).to(memory_format=memory_format)
122                        x = torch.randn(1, 3, 224, 224).to(memory_format=memory_format)
123
124                        graph = self._check_model(m, x)
125                        if enabled:
126                            self.assertFused(graph, ['aten::conv2d', 'aten::' + unary_fn.__name__])
127                            self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
128                        else:
129                            self.assertGraphContains(graph, kind='aten::conv2d')
130
131    def test_unsupported_conv(self):
132        class M(nn.Module):
133            def __init__(self, m, in_channels, out_channels, bias, **kwargs):
134                super().__init__()
135                self.conv = m(in_channels, out_channels, bias=bias, **kwargs)
136
137            def forward(self, x):
138                res = self.conv(x)
139                return res
140
141        for module, dim, memory_format in [
142            [nn.Conv3d, 3, torch.contiguous_format],
143            [nn.Conv3d, 3, torch.channels_last_3d],
144            [nn.ConvTranspose2d, 2, torch.contiguous_format],
145            [nn.ConvTranspose2d, 2, torch.channels_last],
146        ]:
147            trace = True
148            input_size = 224
149            batch_size = 1
150            kernel_size = 3
151            groups = 2
152            bias = True
153            iC = 3 * groups
154            oC = 10 * groups
155            dilation = 2
156            m = M(module,
157                  iC,
158                  oC,
159                  bias,
160                  kernel_size=kernel_size,
161                  stride=2,
162                  padding=1,
163                  dilation=dilation,
164                  groups=groups).to(memory_format=memory_format)
165            input_sizes = [batch_size, iC, input_size, input_size]
166            if dim == 3:
167                input_sizes.append(input_size)
168            x = torch.randn(input_sizes).to(memory_format=memory_format)
169            graph = self._check_model(m, x, trace)
170            self.assertGraphContains(graph, kind='aten::_convolution')
171
172    def _unary_list(self):
173        unary_list = {
174            "relu": PointwisePostOp("relu", nn.ReLU()),
175            "sigmoid": PointwisePostOp("sigmoid", nn.Sigmoid()),
176            "tanh": PointwisePostOp("tanh", nn.Tanh()),
177            "hardswish": PointwisePostOp("hardswish", nn.Hardswish()),
178            "leaky_relu": PointwisePostOp("leaky_relu", nn.LeakyReLU(0.1, inplace=False), scalars=[0.1]),
179            "hardtanh": PointwisePostOp("hardtanh", nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False), scalars=[-0.5, 4]),
180            "gelu_none": PointwisePostOp("gelu", nn.GELU(approximate="none"), algorithm="none"),
181            "gelu_tanh": PointwisePostOp("gelu", nn.GELU(approximate="tanh"), algorithm="tanh"),
182        }
183        return unary_list
184
185    def _binary_list(self):
186        binary_list = {
187            "add": torch.add,
188            "sub": torch.sub,
189            "mul": torch.mul,
190            "div": torch.div,
191        }
192        return binary_list
193
194    def test_linear_unary_fusion_ops(self):
195        class M(nn.Module):
196            def __init__(self, unary_fn, in_channels, out_channels, bias, **kwargs):
197                super().__init__()
198                self.linear = torch.nn.Linear(
199                    in_channels, out_channels, bias=bias, **kwargs
200                )
201                self.unary = unary_fn
202
203            def forward(self, x):
204                x = self.linear(x)
205                x = self.unary(x)
206                return x
207
208        for pointwise_info in self._unary_list().values():
209            # Tensor with size = [1, 10] and stride = [0, 1] is contiguous tensor
210            # but it's strides is not default contiguous strides.
211            options = itertools.product([[[2, 3, 10], None], [[2, 10], None], [[1, 10], [0, 1]]], [True, False])
212            for (input_shape, input_stride), bias in options:
213                with torch.no_grad():
214                    mod = M(pointwise_info.pointwise_module, input_shape[-1], 10, bias).eval()
215                    v = torch.randn(input_shape)
216                    if input_stride is not None:
217                        v = v.as_strided(input_shape, input_stride)
218                    ref = mod(v)
219                    attr = pointwise_info.attr
220                    scalars = pointwise_info.scalars
221                    algorithm = pointwise_info.algorithm
222                    fused = torch.ops.mkldnn._linear_pointwise(
223                        v, mod.linear.weight, mod.linear.bias, attr, scalars, algorithm
224                    )
225                    self.assertEqual(ref, fused)
226
227
228    def test_conv_unary_fusion_ops(self):
229        class M(nn.Module):
230            def __init__(self, unary_fn, dim, in_channels, out_channels, dilation, groups, bias, **kwargs):
231                super().__init__()
232                self.conv = CONV_MODULES[dim](in_channels, out_channels, dilation=dilation, groups=groups, bias=bias, **kwargs)
233                self.unary = unary_fn
234
235            def forward(self, x):
236                x = self.conv(x)
237                x = self.unary(x)
238                return x
239
240        input_shapes = {2: (112, 112), 3: (55, 55, 55)}
241        for pointwise_info in self._unary_list().values():
242            for dim in [2, 3]:
243                channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
244                options = itertools.product([True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last])
245                for bias, dilation, groups, memory_format in options:
246                    oC = 32 * groups
247                    iC = 3 * groups
248                    x_shape = (1, iC) + input_shapes[dim]
249                    x = torch.randn(x_shape, dtype=torch.float32).to(memory_format=memory_format)
250                    mod = M(pointwise_info.pointwise_module, dim, iC, oC, dilation, groups, bias, kernel_size=3)
251                    mod = mod.to(memory_format=memory_format).eval()
252                    with torch.no_grad():
253                        ref = mod(x)
254                        attr = pointwise_info.attr
255                        scalars = pointwise_info.scalars
256                        algorithm = pointwise_info.algorithm
257                        fused = torch.ops.mkldnn._convolution_pointwise(
258                            x, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation,
259                            mod.conv.groups, attr, scalars, algorithm
260                        )
261                    self.assertEqual(ref, fused)
262
263
264    def test_conv_binary_fusion_ops(self):
265        class M(nn.Module):
266            def __init__(self, binary_fn, dim, in_channels, out_channels, dilation, groups, bias, **kwargs):
267                super().__init__()
268                self.conv = CONV_MODULES[dim](in_channels, out_channels, dilation=dilation, groups=groups, bias=bias, **kwargs)
269                self.binary = binary_fn
270
271            def forward(self, x, other):
272                x = self.conv(x)
273                x = self.binary(x, other)
274                return x
275
276        input_shapes = {2: (112, 112), 3: (22, 22, 22)}
277        for pointwise_name, pointwise_fn in self._binary_list().items():
278            for dim in [2, 3]:
279                channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
280                options = itertools.product([False, True], [True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last])
281                for fuse_relu, bias, dilation, groups, memory_format in options:
282                    oC = 32 * groups
283                    iC = 3 * groups
284                    x_shape = (1, iC) + input_shapes[dim]
285                    x = torch.randn(x_shape, dtype=torch.float32).to(memory_format=memory_format)
286                    mod = M(pointwise_fn, dim, iC, oC, dilation, groups, bias, kernel_size=3)
287                    mod = mod.to(memory_format=memory_format).eval()
288                    other = torch.randn_like(mod.conv(x))
289                    with torch.no_grad():
290                        ref = mod(x, other)
291                        unary_attr = None
292                        if fuse_relu:
293                            ref.relu_()
294                            unary_attr = "relu"
295                        attr = pointwise_name
296                        fused = torch.ops.mkldnn._convolution_pointwise(
297                            x, other, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation,
298                            mod.conv.groups, attr, None, unary_attr, [], None
299                        )
300                        # for binary add, we support inplace version.
301                        if attr == "add":
302                            fused_inplace = torch.ops.mkldnn._convolution_pointwise_(
303                                other, x, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation,
304                                mod.conv.groups, attr, None, unary_attr, [], None
305                            )
306                            self.assertEqual(ref, other)
307                            self.assertEqual(ref, fused_inplace)
308
309                        self.assertEqual(ref, fused, atol=5e-4, rtol=5e-4)
310
311
312    def test_linear_binary_fusion_ops(self):
313        class M(nn.Module):
314            def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
315                super().__init__()
316                self.linear = torch.nn.Linear(
317                    in_channels, out_channels, bias=bias, **kwargs
318                )
319                self.binary = binary_fn
320
321            def forward(self, x, other):
322                x = self.linear(x)
323                x = self.binary(x, other)
324                return x
325
326        out_feature = 20
327        for pointwise_name, pointwise_fn in self._binary_list().items():
328            # Tensor with size = [1, 10] and stride = [0, 1] is contiguous tensor
329            # but it's strides is not default contiguous strides.
330            options = itertools.product([[[2, 3, 10], None], [[2, 10], None], [[1, 10], [0, 1]]], [True, False])
331            for (input_shape, input_stride), bias in options:
332                with torch.no_grad():
333                    mod = M(pointwise_fn, input_shape[-1], out_feature, bias).eval()
334                    v = torch.randn(input_shape)
335                    if input_stride is not None:
336                        v = v.as_strided(input_shape, input_stride)
337                    other = torch.randn(input_shape[:-1] + [out_feature])
338                    ref = mod(v, other)
339                    attr = pointwise_name
340                    fused = torch.ops.mkldnn._linear_pointwise(
341                        v, other, mod.linear.weight, mod.linear.bias, attr
342                    )
343                    self.assertEqual(ref, fused)
344
345    def test_conv_transpose_unary_fusion_ops(self):
346        class M(nn.Module):
347            def __init__(self, unary_fn, dim, in_channels, out_channels, kernel_size, **kwargs):
348                super().__init__()
349                self.conv_transpose = CONV_TRANSPOSE_MODULES[dim](in_channels, out_channels, kernel_size, **kwargs)
350                self.unary = unary_fn
351
352            def forward(self, x):
353                x = self.conv_transpose(x)
354                x = self.unary(x)
355                return x
356
357        input_shapes = {2: (28, 28)}
358        kernel_size = 3
359        for pointwise_info in self._unary_list().values():
360            for dim in [2]:
361                channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
362                options = itertools.product([True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last], [False, True])
363                for bias, dilation, groups, memory_format, prepack_weight in options:
364                    oC = 32 * groups
365                    iC = 3 * groups
366                    x_shape = (1, iC) + input_shapes[dim]
367                    x = torch.randn(x_shape, dtype=torch.float32).to(memory_format=memory_format)
368                    mod = M(pointwise_info.pointwise_module, dim, iC, oC, kernel_size, dilation=dilation, groups=groups, bias=bias)
369                    mod = mod.to(memory_format=memory_format).eval()
370                    with torch.no_grad():
371                        ref = mod(x)
372                        attr = pointwise_info.attr
373                        scalars = pointwise_info.scalars
374                        algorithm = pointwise_info.algorithm
375
376                        if prepack_weight:
377                            packed_weight = torch.ops.mkldnn._reorder_convolution_transpose_weight(
378                                mod.conv_transpose.weight,
379                                mod.conv_transpose.padding,
380                                mod.conv_transpose.output_padding,
381                                mod.conv_transpose.stride,
382                                mod.conv_transpose.dilation,
383                                mod.conv_transpose.groups,
384                                x.size())
385                            mod.conv_transpose.weight = torch.nn.Parameter(
386                                packed_weight,
387                                requires_grad=mod.conv_transpose.weight.requires_grad,
388                            )
389
390                        fused = torch.ops.mkldnn._convolution_transpose_pointwise(
391                            x,
392                            mod.conv_transpose.weight,
393                            mod.conv_transpose.bias,
394                            mod.conv_transpose.padding,
395                            mod.conv_transpose.output_padding,
396                            mod.conv_transpose.stride,
397                            mod.conv_transpose.dilation,
398                            mod.conv_transpose.groups,
399                            attr,
400                            scalars,
401                            algorithm)
402                    self.assertEqual(ref, fused)
403
# When executed as a script, hand control to the run_tests helper imported
# from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()
406