# Owner(s): ["module: mkldnn"]
"""Tests for oneDNN (MKL-DNN) fused convolution / linear post-op kernels.

Two families of tests live here:

* ``test_single_conv``, ``test_conv_unary_fusion_nnc`` and
  ``test_unsupported_conv`` script or trace a small module, freeze it, run it
  with the CPU tensorexpr fuser enabled and inspect the resulting graph.
* The ``test_*_fusion_ops`` tests call the ``torch.ops.mkldnn.*_pointwise``
  kernels directly and compare them with the equivalent eager ``nn.Module``.
"""
import itertools
import unittest
from typing import NamedTuple, List

import torch
from torch import nn

from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
from torch.testing._internal.jit_utils import JitTestCase

from test_tensorexpr import warmup_and_run_forward

FUSION_GROUP = 'prim::TensorExprGroup'


class PointwisePostOp(NamedTuple):
    # One unary post-op under test:
    #   attr             - op name passed as ``attr`` to the mkldnn pointwise kernels
    #   pointwise_module - eager module used to compute the reference result
    #   scalars          - extra scalar arguments forwarded to the kernel
    #   algorithm        - algorithm string forwarded to the kernel (used for gelu)
    attr : str
    pointwise_module : nn.Module
    # NOTE: this default list object is shared by every instance that does not
    # pass ``scalars`` explicitly; it must never be mutated.
    scalars : List = []
    algorithm : str = ""


# Spatial rank -> convolution module class used by the direct-kernel tests.
CONV_MODULES = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
CONV_TRANSPOSE_MODULES = {2: torch.nn.ConvTranspose2d}


@skipIfTorchDynamo("too slow")
@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled")
class TestMkldnnFusion(JitTestCase):
    """Correctness and graph-shape tests for mkldnn-backed fusion."""

    def assertFused(self, graph, fused_patterns):
        """Assert that no node of any kind listed in ``fused_patterns`` remains in ``graph``.

        Callers pair this with a ``FUSION_GROUP`` count check, so an absent
        node is taken to mean it was absorbed into the fusion group.
        """
        for pat in fused_patterns:
            self.assertGraphContainsExactly(graph, pat, 0)

    def _check_model(self, m, x, trace=False):
        """Compile ``m`` with TorchScript, compare it with eager, and return its graph.

        ``m`` is put in eval mode. It is then traced with ``x`` (``trace=True``)
        or scripted, and frozen. The frozen module is warmed up and run on ``x``,
        and its output is compared with eager ``m(x)``.

        Args:
            m: the ``nn.Module`` under test.
            x: the single input tensor passed to the module.
            trace: use ``torch.jit.trace`` instead of ``torch.jit.script``.

        Returns:
            The optimized graph reported by ``graph_for`` for input ``x``.

        The three global JIT/fuser flags changed here are restored in a
        ``finally`` block, so a failing comparison cannot leak the modified
        fuser configuration into later tests.
        """
        # Presumably keeps fusion groups as ``prim::TensorExprGroup`` nodes
        # (instead of inlining them back) so callers can count them.
        old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
        torch._C._debug_set_fusion_group_inlining(False)

        old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
        torch._C._jit_override_can_fuse_on_cpu(True)

        old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
        torch._C._jit_set_te_must_use_llvm_cpu(False)

        try:
            m.eval()
            with torch.no_grad():
                if trace:
                    script = torch.jit.trace(m, x)
                else:
                    script = torch.jit.script(m)
            script = torch.jit.freeze(script)

            with torch.no_grad():
                y = warmup_and_run_forward(script, x)
                y = script(x)
                y_ref = m(x)

                # ``x`` is one input tensor, passed exactly as in ``script(x)``.
                # Star-unpacking it would split it along the batch dimension.
                graph = script.graph_for(x)
                self.assertEqual(y, y_ref)
        finally:
            torch._C._debug_set_fusion_group_inlining(old_fusion_inlining)
            torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuser_state)
            torch._C._jit_set_te_must_use_llvm_cpu(old_te_must_use_llvm_cpu)
        return graph

    def test_single_conv(self):
        """A lone Conv2d forms one fusion group only for channels_last inputs."""
        class M(nn.Module):
            def __init__(self, in_channels, out_channels, bias, **kwargs):
                super().__init__()
                self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)

            def forward(self, x):
                res = self.conv(x)
                return res

        # (memory format, whether the conv is expected to be fused)
        for memory_format, enabled in [
            [torch.contiguous_format, False],
            [torch.channels_last, True],
        ]:
            for trace in [True, False]:
                input_size = 224
                batch_size = 1
                kernel_size = 3
                options = itertools.product([True, False], [1, 2], [1, 4])
                for bias, dilation, groups in options:
                    iC = 3 * groups
                    oC = 10 * groups
                    m = M(iC,
                          oC,
                          bias,
                          kernel_size=(kernel_size, kernel_size),
                          stride=2,
                          padding=1,
                          dilation=dilation,
                          groups=groups).to(memory_format=memory_format)
                    x = torch.randn(batch_size, iC, input_size, input_size).to(memory_format=memory_format)
                    graph = self._check_model(m, x, trace)
                    # Tracing records the lower-level aten::_convolution node.
                    conv_node_name = 'aten::_convolution' if trace else 'aten::conv2d'
                    if enabled:
                        self.assertFused(graph, [conv_node_name])
                        self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
                    else:
                        self.assertGraphContains(graph, kind=conv_node_name)

    def test_conv_unary_fusion_nnc(self):
        """Conv2d followed by relu forms one fusion group only for channels_last inputs."""
        class M(nn.Module):
            def __init__(self, unary_fn, in_channels, out_channels, bias, **kwargs):
                super().__init__()
                self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
                self.unary = unary_fn

            def forward(self, x):
                x = self.conv(x)
                x = self.unary(x)
                return x

        for memory_format, enabled in [
            [torch.contiguous_format, False],
            [torch.channels_last, True],
        ]:
            for unary_fn in [torch.relu]:
                for bias in [True, False]:
                    for oC in [1, 10]:
                        m = M(unary_fn, 3, oC, bias, kernel_size=(3, 3)).to(memory_format=memory_format)
                        x = torch.randn(1, 3, 224, 224).to(memory_format=memory_format)

                        graph = self._check_model(m, x)
                        if enabled:
                            self.assertFused(graph, ['aten::conv2d', 'aten::' + unary_fn.__name__])
                            self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
                        else:
                            self.assertGraphContains(graph, kind='aten::conv2d')

    def test_unsupported_conv(self):
        """Traced Conv3d / ConvTranspose2d keep an aten::_convolution node in every memory format."""
        class M(nn.Module):
            def __init__(self, m, in_channels, out_channels, bias, **kwargs):
                super().__init__()
                self.conv = m(in_channels, out_channels, bias=bias, **kwargs)

            def forward(self, x):
                res = self.conv(x)
                return res

        for module, dim, memory_format in [
            [nn.Conv3d, 3, torch.contiguous_format],
            [nn.Conv3d, 3, torch.channels_last_3d],
            [nn.ConvTranspose2d, 2, torch.contiguous_format],
            [nn.ConvTranspose2d, 2, torch.channels_last],
        ]:
            trace = True
            input_size = 224
            batch_size = 1
            kernel_size = 3
            groups = 2
            bias = True
            iC = 3 * groups
            oC = 10 * groups
            dilation = 2
            m = M(module,
                  iC,
                  oC,
                  bias,
                  kernel_size=kernel_size,
                  stride=2,
                  padding=1,
                  dilation=dilation,
                  groups=groups).to(memory_format=memory_format)
            input_sizes = [batch_size, iC, input_size, input_size]
            if dim == 3:
                input_sizes.append(input_size)
            x = torch.randn(input_sizes).to(memory_format=memory_format)
            graph = self._check_model(m, x, trace)
            self.assertGraphContains(graph, kind='aten::_convolution')

    def _unary_list(self):
        """Return the unary post-ops exercised by the direct-kernel tests, keyed by a label."""
        unary_list = {
            "relu": PointwisePostOp("relu", nn.ReLU()),
            "sigmoid": PointwisePostOp("sigmoid", nn.Sigmoid()),
            "tanh": PointwisePostOp("tanh", nn.Tanh()),
            "hardswish": PointwisePostOp("hardswish", nn.Hardswish()),
            "leaky_relu": PointwisePostOp("leaky_relu", nn.LeakyReLU(0.1, inplace=False), scalars=[0.1]),
            "hardtanh": PointwisePostOp("hardtanh", nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False), scalars=[-0.5, 4]),
            "gelu_none": PointwisePostOp("gelu", nn.GELU(approximate="none"), algorithm="none"),
            "gelu_tanh": PointwisePostOp("gelu", nn.GELU(approximate="tanh"), algorithm="tanh"),
        }
        return unary_list

    def _binary_list(self):
        """Map each binary post-op ``attr`` name to its eager torch function."""
        binary_list = {
            "add": torch.add,
            "sub": torch.sub,
            "mul": torch.mul,
            "div": torch.div,
        }
        return binary_list

    def test_linear_unary_fusion_ops(self):
        """``mkldnn._linear_pointwise`` with a unary post-op matches Linear followed by the eager op."""
        class M(nn.Module):
            def __init__(self, unary_fn, in_channels, out_channels, bias, **kwargs):
                super().__init__()
                self.linear = torch.nn.Linear(
                    in_channels, out_channels, bias=bias, **kwargs
                )
                self.unary = unary_fn

            def forward(self, x):
                x = self.linear(x)
                x = self.unary(x)
                return x

        for pointwise_info in self._unary_list().values():
            # A tensor with size = [1, 10] and stride = [0, 1] is contiguous,
            # but its strides are not the default contiguous strides.
            options = itertools.product([[[2, 3, 10], None], [[2, 10], None], [[1, 10], [0, 1]]], [True, False])
            for (input_shape, input_stride), bias in options:
                with torch.no_grad():
                    mod = M(pointwise_info.pointwise_module, input_shape[-1], 10, bias).eval()
                    v = torch.randn(input_shape)
                    if input_stride is not None:
                        v = v.as_strided(input_shape, input_stride)
                    ref = mod(v)
                    attr = pointwise_info.attr
                    scalars = pointwise_info.scalars
                    algorithm = pointwise_info.algorithm
                    fused = torch.ops.mkldnn._linear_pointwise(
                        v, mod.linear.weight, mod.linear.bias, attr, scalars, algorithm
                    )
                    self.assertEqual(ref, fused)

    def test_conv_unary_fusion_ops(self):
        """``mkldnn._convolution_pointwise`` with a unary post-op matches Conv2d/Conv3d followed by the eager op."""
        class M(nn.Module):
            def __init__(self, unary_fn, dim, in_channels, out_channels, dilation, groups, bias, **kwargs):
                super().__init__()
                self.conv = CONV_MODULES[dim](in_channels, out_channels, dilation=dilation, groups=groups, bias=bias, **kwargs)
                self.unary = unary_fn

            def forward(self, x):
                x = self.conv(x)
                x = self.unary(x)
                return x

        input_shapes = {2: (112, 112), 3: (55, 55, 55)}
        for pointwise_info in self._unary_list().values():
            for dim in [2, 3]:
                channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
                options = itertools.product([True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last])
                for bias, dilation, groups, memory_format in options:
                    oC = 32 * groups
                    iC = 3 * groups
                    x_shape = (1, iC) + input_shapes[dim]
                    x = torch.randn(x_shape, dtype=torch.float32).to(memory_format=memory_format)
                    mod = M(pointwise_info.pointwise_module, dim, iC, oC, dilation, groups, bias, kernel_size=3)
                    mod = mod.to(memory_format=memory_format).eval()
                    with torch.no_grad():
                        ref = mod(x)
                        attr = pointwise_info.attr
                        scalars = pointwise_info.scalars
                        algorithm = pointwise_info.algorithm
                        fused = torch.ops.mkldnn._convolution_pointwise(
                            x, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation,
                            mod.conv.groups, attr, scalars, algorithm
                        )
                    self.assertEqual(ref, fused)

    def test_conv_binary_fusion_ops(self):
        """``mkldnn._convolution_pointwise`` with a binary post-op (and optional relu) matches eager."""
        class M(nn.Module):
            def __init__(self, binary_fn, dim, in_channels, out_channels, dilation, groups, bias, **kwargs):
                super().__init__()
                self.conv = CONV_MODULES[dim](in_channels, out_channels, dilation=dilation, groups=groups, bias=bias, **kwargs)
                self.binary = binary_fn

            def forward(self, x, other):
                x = self.conv(x)
                x = self.binary(x, other)
                return x

        input_shapes = {2: (112, 112), 3: (22, 22, 22)}
        for pointwise_name, pointwise_fn in self._binary_list().items():
            for dim in [2, 3]:
                channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
                options = itertools.product([False, True], [True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last])
                for fuse_relu, bias, dilation, groups, memory_format in options:
                    oC = 32 * groups
                    iC = 3 * groups
                    x_shape = (1, iC) + input_shapes[dim]
                    x = torch.randn(x_shape, dtype=torch.float32).to(memory_format=memory_format)
                    mod = M(pointwise_fn, dim, iC, oC, dilation, groups, bias, kernel_size=3)
                    mod = mod.to(memory_format=memory_format).eval()
                    # The second binary operand has the same shape as the conv output.
                    other = torch.randn_like(mod.conv(x))
                    with torch.no_grad():
                        ref = mod(x, other)
                        unary_attr = None
                        if fuse_relu:
                            ref.relu_()
                            unary_attr = "relu"
                        attr = pointwise_name
                        # NOTE(review): the trailing None / [] / None arguments are
                        # presumably binary alpha, unary scalars and unary algorithm;
                        # confirm against the op schema.
                        fused = torch.ops.mkldnn._convolution_pointwise(
                            x, other, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation,
                            mod.conv.groups, attr, None, unary_attr, [], None
                        )
                        # For binary add, the in-place variant is supported too. It
                        # writes the result into ``other``, so it must run after
                        # ``fused`` has been computed from the original ``other``.
                        if attr == "add":
                            fused_inplace = torch.ops.mkldnn._convolution_pointwise_(
                                other, x, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation,
                                mod.conv.groups, attr, None, unary_attr, [], None
                            )
                            self.assertEqual(ref, other)
                            self.assertEqual(ref, fused_inplace)

                        self.assertEqual(ref, fused, atol=5e-4, rtol=5e-4)

    def test_linear_binary_fusion_ops(self):
        """``mkldnn._linear_pointwise`` with a binary post-op matches Linear followed by the eager op."""
        class M(nn.Module):
            def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
                super().__init__()
                self.linear = torch.nn.Linear(
                    in_channels, out_channels, bias=bias, **kwargs
                )
                self.binary = binary_fn

            def forward(self, x, other):
                x = self.linear(x)
                x = self.binary(x, other)
                return x

        out_feature = 20
        for pointwise_name, pointwise_fn in self._binary_list().items():
            # A tensor with size = [1, 10] and stride = [0, 1] is contiguous,
            # but its strides are not the default contiguous strides.
            options = itertools.product([[[2, 3, 10], None], [[2, 10], None], [[1, 10], [0, 1]]], [True, False])
            for (input_shape, input_stride), bias in options:
                with torch.no_grad():
                    mod = M(pointwise_fn, input_shape[-1], out_feature, bias).eval()
                    v = torch.randn(input_shape)
                    if input_stride is not None:
                        v = v.as_strided(input_shape, input_stride)
                    other = torch.randn(input_shape[:-1] + [out_feature])
                    ref = mod(v, other)
                    attr = pointwise_name
                    fused = torch.ops.mkldnn._linear_pointwise(
                        v, other, mod.linear.weight, mod.linear.bias, attr
                    )
                    self.assertEqual(ref, fused)

    def test_conv_transpose_unary_fusion_ops(self):
        """``mkldnn._convolution_transpose_pointwise`` matches ConvTranspose2d followed by the eager op.

        Each case is run both with the module's original weight and with a
        weight pre-reordered by ``mkldnn._reorder_convolution_transpose_weight``.
        """
        class M(nn.Module):
            def __init__(self, unary_fn, dim, in_channels, out_channels, kernel_size, **kwargs):
                super().__init__()
                self.conv_transpose = CONV_TRANSPOSE_MODULES[dim](in_channels, out_channels, kernel_size, **kwargs)
                self.unary = unary_fn

            def forward(self, x):
                x = self.conv_transpose(x)
                x = self.unary(x)
                return x

        input_shapes = {2: (28, 28)}
        kernel_size = 3
        for pointwise_info in self._unary_list().values():
            for dim in [2]:
                channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
                options = itertools.product([True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last], [False, True])
                for bias, dilation, groups, memory_format, prepack_weight in options:
                    oC = 32 * groups
                    iC = 3 * groups
                    x_shape = (1, iC) + input_shapes[dim]
                    x = torch.randn(x_shape, dtype=torch.float32).to(memory_format=memory_format)
                    mod = M(pointwise_info.pointwise_module, dim, iC, oC, kernel_size, dilation=dilation, groups=groups, bias=bias)
                    mod = mod.to(memory_format=memory_format).eval()
                    with torch.no_grad():
                        ref = mod(x)
                        attr = pointwise_info.attr
                        scalars = pointwise_info.scalars
                        algorithm = pointwise_info.algorithm

                        if prepack_weight:
                            # Replace the weight only after ``ref`` has been computed,
                            # so the eager reference uses the original weight.
                            packed_weight = torch.ops.mkldnn._reorder_convolution_transpose_weight(
                                mod.conv_transpose.weight,
                                mod.conv_transpose.padding,
                                mod.conv_transpose.output_padding,
                                mod.conv_transpose.stride,
                                mod.conv_transpose.dilation,
                                mod.conv_transpose.groups,
                                x.size())
                            mod.conv_transpose.weight = torch.nn.Parameter(
                                packed_weight,
                                requires_grad=mod.conv_transpose.weight.requires_grad,
                            )

                        fused = torch.ops.mkldnn._convolution_transpose_pointwise(
                            x,
                            mod.conv_transpose.weight,
                            mod.conv_transpose.bias,
                            mod.conv_transpose.padding,
                            mod.conv_transpose.output_padding,
                            mod.conv_transpose.stride,
                            mod.conv_transpose.dilation,
                            mod.conv_transpose.groups,
                            attr,
                            scalars,
                            algorithm)
                    self.assertEqual(ref, fused)


if __name__ == "__main__":
    run_tests()