# Owner(s): ["module: mkldnn"]
import itertools
import unittest
from typing import NamedTuple, List

import torch
from torch import nn

from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
from torch.testing._internal.jit_utils import JitTestCase

from test_tensorexpr import warmup_and_run_forward

# Node kind the tests look for when a convolution is expected to be fused.
FUSION_GROUP = 'prim::TensorExprGroup'


class PointwisePostOp(NamedTuple):
    """Describes one unary post-op: the eager module plus the mkldnn op arguments."""

    # Post-op name forwarded as the `attr` argument of the mkldnn pointwise ops.
    attr : str
    # Eager module applied after the linear/conv layer to build the reference.
    pointwise_module : nn.Module
    # Forwarded as the `scalars` argument of the mkldnn pointwise ops.
    # NOTE: this default list is shared by every instance; nothing in this
    # file mutates it.
    scalars : List = []
    # Forwarded as the `algorithm` argument of the mkldnn pointwise ops.
    algorithm : str = ""


# Spatial dimensionality -> module class used by the fusion-op tests.
CONV_MODULES = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
CONV_TRANSPOSE_MODULES = {2: torch.nn.ConvTranspose2d}


@skipIfTorchDynamo("too slow")
@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled")
class TestMkldnnFusion(JitTestCase):
    def assertFused(self, graph, fused_patterns):
        """Assert that none of `fused_patterns` occurs in `graph`."""
        for pat in fused_patterns:
            self.assertGraphContainsExactly(graph, pat, 0)

    def _check_model(self, m, x, trace=False):
        """Compile `m`, compare its output on `x` with eager mode, return the graph.

        The module is traced (``trace=True``) or scripted, then frozen. While
        this runs, fusion-group inlining is disabled, CPU fusion is enabled and
        the "TE must use LLVM on CPU" flag is cleared. The previous values of
        these process-global flags are restored on exit, even if compilation
        or the comparison raises, so one failing case cannot leak its
        overrides into later tests.

        Returns:
            The graph obtained from ``script.graph_for(*x)``.
        """
        old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
        old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
        old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()

        try:
            torch._C._debug_set_fusion_group_inlining(False)
            torch._C._jit_override_can_fuse_on_cpu(True)
            torch._C._jit_set_te_must_use_llvm_cpu(False)

            m.eval()
            with torch.no_grad():
                if trace:
                    script = torch.jit.trace(m, x)
                else:
                    script = torch.jit.script(m)
            script = torch.jit.freeze(script)

            with torch.no_grad():
                # Only the warm-up side effect is needed; the compared output
                # comes from the explicit call below.
                warmup_and_run_forward(script, x)
                y = script(x)
                y_ref = m(x)

                # NOTE(review): `*x` unpacks the tensor along dim 0 instead of
                # passing it whole -- confirm this is intended.
                graph = script.graph_for(*x)
                self.assertEqual(y, y_ref)

            return graph
        finally:
            torch._C._debug_set_fusion_group_inlining(old_fusion_inlining)
            torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuser_state)
            torch._C._jit_set_te_must_use_llvm_cpu(old_te_must_use_llvm_cpu)

    def test_single_conv(self):
        """A lone Conv2d is expected in a fusion group for channels_last only."""
        class M(nn.Module):
            def __init__(self, in_channels, out_channels, bias, **kwargs):
                super().__init__()
                self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)

            def forward(self, x):
                res = self.conv(x)
                return res

        for memory_format, enabled in [
            [torch.contiguous_format, False],
            [torch.channels_last, True],
        ]:
            for trace in [True, False]:
                input_size = 224
                batch_size = 1
                kernel_size = 3
                options = itertools.product([True, False], [1, 2], [1, 4])
                for bias, dilation, groups in options:
                    iC = 3 * groups
                    oC = 10 * groups
                    m = M(iC,
                          oC,
                          bias,
                          kernel_size=(kernel_size, kernel_size),
                          stride=2,
                          padding=1,
                          dilation=dilation,
                          groups=groups).to(memory_format=memory_format)
                    x = torch.randn(batch_size, iC, input_size, input_size).to(memory_format=memory_format)
                    graph = self._check_model(m, x, trace)
                    # The conv node kind checked for depends on how the module was compiled.
                    conv_node_name = 'aten::_convolution' if trace else 'aten::conv2d'
                    if enabled:
                        self.assertFused(graph, [conv_node_name])
                        self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
                    else:
                        self.assertGraphContains(graph, kind=conv_node_name)

    def test_conv_unary_fusion_nnc(self):
        """Conv2d followed by a unary op: one fusion group for channels_last only."""
        class M(nn.Module):
            def __init__(self, unary_fn, in_channels, out_channels, bias, **kwargs):
                super().__init__()
                self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
                self.unary = unary_fn

            def forward(self, x):
                x = self.conv(x)
                x = self.unary(x)
                return x

        for memory_format, enabled in [
            [torch.contiguous_format, False],
            [torch.channels_last, True],
        ]:
            for unary_fn in [torch.relu]:
                for bias in [True, False]:
                    for oC in [1, 10]:
                        m = M(unary_fn, 3, oC, bias, kernel_size=(3, 3)).to(memory_format=memory_format)
                        x = torch.randn(1, 3, 224, 224).to(memory_format=memory_format)

                        graph = self._check_model(m, x)
                        if enabled:
                            self.assertFused(graph, ['aten::conv2d', 'aten::' + unary_fn.__name__])
                            self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
                        else:
                            self.assertGraphContains(graph, kind='aten::conv2d')

    def test_unsupported_conv(self):
        """Conv3d and ConvTranspose2d keep an `aten::_convolution` node in the graph."""
        class M(nn.Module):
            def __init__(self, m, in_channels, out_channels, bias, **kwargs):
                super().__init__()
                self.conv = m(in_channels, out_channels, bias=bias, **kwargs)

            def forward(self, x):
                res = self.conv(x)
                return res

        for module, dim, memory_format in [
            [nn.Conv3d, 3, torch.contiguous_format],
            [nn.Conv3d, 3, torch.channels_last_3d],
            [nn.ConvTranspose2d, 2, torch.contiguous_format],
            [nn.ConvTranspose2d, 2, torch.channels_last],
        ]:
            trace = True
            input_size = 224
            batch_size = 1
            kernel_size = 3
            groups = 2
            bias = True
            iC = 3 * groups
            oC = 10 * groups
            dilation = 2
            m = M(module,
                  iC,
                  oC,
                  bias,
                  kernel_size=kernel_size,
                  stride=2,
                  padding=1,
                  dilation=dilation,
                  groups=groups).to(memory_format=memory_format)
            input_sizes = [batch_size, iC, input_size, input_size]
            if dim == 3:
                # 3-d convolution needs one more spatial dimension.
                input_sizes.append(input_size)
            x = torch.randn(input_sizes).to(memory_format=memory_format)
            graph = self._check_model(m, x, trace)
            self.assertGraphContains(graph, kind='aten::_convolution')

    def _unary_list(self):
        """Return the unary post-ops exercised by the fusion-op tests, keyed by name."""
        unary_list = {
            "relu": PointwisePostOp("relu", nn.ReLU()),
            "sigmoid": PointwisePostOp("sigmoid", nn.Sigmoid()),
            "tanh": PointwisePostOp("tanh", nn.Tanh()),
            "hardswish": PointwisePostOp("hardswish", nn.Hardswish()),
            "leaky_relu": PointwisePostOp("leaky_relu", nn.LeakyReLU(0.1, inplace=False), scalars=[0.1]),
            "hardtanh": PointwisePostOp("hardtanh", nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False), scalars=[-0.5, 4]),
            "gelu_none": PointwisePostOp("gelu", nn.GELU(approximate="none"), algorithm="none"),
            "gelu_tanh": PointwisePostOp("gelu", nn.GELU(approximate="tanh"), algorithm="tanh"),
        }
        return unary_list

    def _binary_list(self):
        """Return the binary post-ops exercised by the fusion-op tests, keyed by attr name."""
        binary_list = {
            "add": torch.add,
            "sub": torch.sub,
            "mul": torch.mul,
            "div": torch.div,
        }
        return binary_list

    def test_linear_unary_fusion_ops(self):
        """`mkldnn._linear_pointwise` must match eager Linear + unary post-op."""
        class M(nn.Module):
            def __init__(self, unary_fn, in_channels, out_channels, bias, **kwargs):
                super().__init__()
                self.linear = torch.nn.Linear(
                    in_channels, out_channels, bias=bias, **kwargs
                )
                self.unary = unary_fn

            def forward(self, x):
                x = self.linear(x)
                x = self.unary(x)
                return x

        for pointwise_info in self._unary_list().values():
            # A tensor with size = [1, 10] and stride = [0, 1] is contiguous,
            # but its strides are not the default contiguous strides.
            options = itertools.product([[[2, 3, 10], None], [[2, 10], None], [[1, 10], [0, 1]]], [True, False])
            for (input_shape, input_stride), bias in options:
                with torch.no_grad():
                    mod = M(pointwise_info.pointwise_module, input_shape[-1], 10, bias).eval()
                    v = torch.randn(input_shape)
                    if input_stride is not None:
                        v = v.as_strided(input_shape, input_stride)
                    ref = mod(v)
                    attr = pointwise_info.attr
                    scalars = pointwise_info.scalars
                    algorithm = pointwise_info.algorithm
                    fused = torch.ops.mkldnn._linear_pointwise(
                        v, mod.linear.weight, mod.linear.bias, attr, scalars, algorithm
                    )
                    self.assertEqual(ref, fused)

    def test_conv_unary_fusion_ops(self):
        """`mkldnn._convolution_pointwise` must match eager conv + unary post-op."""
        class M(nn.Module):
            def __init__(self, unary_fn, dim, in_channels, out_channels, dilation, groups, bias, **kwargs):
                super().__init__()
                self.conv = CONV_MODULES[dim](in_channels, out_channels, dilation=dilation, groups=groups, bias=bias, **kwargs)
                self.unary = unary_fn

            def forward(self, x):
                x = self.conv(x)
                x = self.unary(x)
                return x

        input_shapes = {2: (112, 112), 3: (55, 55, 55)}
        for pointwise_info in self._unary_list().values():
            for dim in [2, 3]:
                channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
                options = itertools.product([True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last])
                for bias, dilation, groups, memory_format in options:
                    oC = 32 * groups
                    iC = 3 * groups
                    x_shape = (1, iC) + input_shapes[dim]
                    x = torch.randn(x_shape, dtype=torch.float32).to(memory_format=memory_format)
                    mod = M(pointwise_info.pointwise_module, dim, iC, oC, dilation, groups, bias, kernel_size=3)
                    mod = mod.to(memory_format=memory_format).eval()
                    with torch.no_grad():
                        ref = mod(x)
                        attr = pointwise_info.attr
                        scalars = pointwise_info.scalars
                        algorithm = pointwise_info.algorithm
                        fused = torch.ops.mkldnn._convolution_pointwise(
                            x, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation,
                            mod.conv.groups, attr, scalars, algorithm
                        )
                    self.assertEqual(ref, fused)

    def test_conv_binary_fusion_ops(self):
        """`mkldnn._convolution_pointwise` (binary form) must match eager conv + binary op.

        Each combination is also run with a trailing relu, and the in-place
        variant is checked for "add".
        """
        class M(nn.Module):
            def __init__(self, binary_fn, dim, in_channels, out_channels, dilation, groups, bias, **kwargs):
                super().__init__()
                self.conv = CONV_MODULES[dim](in_channels, out_channels, dilation=dilation, groups=groups, bias=bias, **kwargs)
                self.binary = binary_fn

            def forward(self, x, other):
                x = self.conv(x)
                x = self.binary(x, other)
                return x

        input_shapes = {2: (112, 112), 3: (22, 22, 22)}
        for pointwise_name, pointwise_fn in self._binary_list().items():
            for dim in [2, 3]:
                channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
                options = itertools.product(
                    [False, True], [True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last]
                )
                for fuse_relu, bias, dilation, groups, memory_format in options:
                    oC = 32 * groups
                    iC = 3 * groups
                    x_shape = (1, iC) + input_shapes[dim]
                    x = torch.randn(x_shape, dtype=torch.float32).to(memory_format=memory_format)
                    mod = M(pointwise_fn, dim, iC, oC, dilation, groups, bias, kernel_size=3)
                    mod = mod.to(memory_format=memory_format).eval()
                    # The second operand has the shape of the convolution output.
                    other = torch.randn_like(mod.conv(x))
                    with torch.no_grad():
                        ref = mod(x, other)
                        unary_attr = None
                        if fuse_relu:
                            ref.relu_()
                            unary_attr = "relu"
                        attr = pointwise_name
                        # Must run before the in-place variant below, which overwrites `other`.
                        fused = torch.ops.mkldnn._convolution_pointwise(
                            x, other, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride,
                            mod.conv.dilation, mod.conv.groups, attr, None, unary_attr, [], None
                        )
                        # For binary add, the in-place version is supported as well.
                        if attr == "add":
                            fused_inplace = torch.ops.mkldnn._convolution_pointwise_(
                                other, x, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride,
                                mod.conv.dilation, mod.conv.groups, attr, None, unary_attr, [], None
                            )
                            self.assertEqual(ref, other)
                            self.assertEqual(ref, fused_inplace)

                    self.assertEqual(ref, fused, atol=5e-4, rtol=5e-4)

    def test_linear_binary_fusion_ops(self):
        """`mkldnn._linear_pointwise` (binary form) must match eager Linear + binary op."""
        class M(nn.Module):
            def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
                super().__init__()
                self.linear = torch.nn.Linear(
                    in_channels, out_channels, bias=bias, **kwargs
                )
                self.binary = binary_fn

            def forward(self, x, other):
                x = self.linear(x)
                x = self.binary(x, other)
                return x

        out_feature = 20
        for pointwise_name, pointwise_fn in self._binary_list().items():
            # A tensor with size = [1, 10] and stride = [0, 1] is contiguous,
            # but its strides are not the default contiguous strides.
            options = itertools.product([[[2, 3, 10], None], [[2, 10], None], [[1, 10], [0, 1]]], [True, False])
            for (input_shape, input_stride), bias in options:
                with torch.no_grad():
                    mod = M(pointwise_fn, input_shape[-1], out_feature, bias).eval()
                    v = torch.randn(input_shape)
                    if input_stride is not None:
                        v = v.as_strided(input_shape, input_stride)
                    # Same leading dims as the input, last dim is the linear output width.
                    other = torch.randn(input_shape[:-1] + [out_feature])
                    ref = mod(v, other)
                    attr = pointwise_name
                    fused = torch.ops.mkldnn._linear_pointwise(
                        v, other, mod.linear.weight, mod.linear.bias, attr
                    )
                    self.assertEqual(ref, fused)

    def test_conv_transpose_unary_fusion_ops(self):
        """`mkldnn._convolution_transpose_pointwise` must match eager ConvTranspose + unary.

        Each combination is run both with the original weight and with a
        weight produced by `_reorder_convolution_transpose_weight`.
        """
        class M(nn.Module):
            def __init__(self, unary_fn, dim, in_channels, out_channels, kernel_size, **kwargs):
                super().__init__()
                self.conv_transpose = CONV_TRANSPOSE_MODULES[dim](in_channels, out_channels, kernel_size, **kwargs)
                self.unary = unary_fn

            def forward(self, x):
                x = self.conv_transpose(x)
                x = self.unary(x)
                return x

        input_shapes = {2: (28, 28)}
        kernel_size = 3
        for pointwise_info in self._unary_list().values():
            for dim in [2]:
                channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
                options = itertools.product(
                    [True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last], [False, True]
                )
                for bias, dilation, groups, memory_format, prepack_weight in options:
                    oC = 32 * groups
                    iC = 3 * groups
                    x_shape = (1, iC) + input_shapes[dim]
                    x = torch.randn(x_shape, dtype=torch.float32).to(memory_format=memory_format)
                    mod = M(pointwise_info.pointwise_module, dim, iC, oC, kernel_size, dilation=dilation, groups=groups, bias=bias)
                    mod = mod.to(memory_format=memory_format).eval()
                    with torch.no_grad():
                        # The reference is computed before the weight is (optionally) replaced.
                        ref = mod(x)
                        attr = pointwise_info.attr
                        scalars = pointwise_info.scalars
                        algorithm = pointwise_info.algorithm

                        if prepack_weight:
                            packed_weight = torch.ops.mkldnn._reorder_convolution_transpose_weight(
                                mod.conv_transpose.weight,
                                mod.conv_transpose.padding,
                                mod.conv_transpose.output_padding,
                                mod.conv_transpose.stride,
                                mod.conv_transpose.dilation,
                                mod.conv_transpose.groups,
                                x.size())
                            mod.conv_transpose.weight = torch.nn.Parameter(
                                packed_weight,
                                requires_grad=mod.conv_transpose.weight.requires_grad,
                            )

                        fused = torch.ops.mkldnn._convolution_transpose_pointwise(
                            x,
                            mod.conv_transpose.weight,
                            mod.conv_transpose.bias,
                            mod.conv_transpose.padding,
                            mod.conv_transpose.output_padding,
                            mod.conv_transpose.stride,
                            mod.conv_transpose.dilation,
                            mod.conv_transpose.groups,
                            attr,
                            scalars,
                            algorithm)
                    self.assertEqual(ref, fused)


if __name__ == "__main__":
    run_tests()