1# Owner(s): ["module: mkldnn"] 2 3import copy 4import itertools 5import functools 6import unittest 7from contextlib import nullcontext 8 9try: 10 import torchvision 11 HAS_TORCHVISION = True 12except ImportError: 13 HAS_TORCHVISION = False 14 15skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") 16 17import torch 18import torch.nn.functional as F 19import torch.jit 20import torch.backends.mkldnn 21from torch.utils import mkldnn as mkldnn_utils 22from torch.testing._internal.common_utils import TestCase, \ 23 run_tests, TemporaryFileName, gradcheck, gradgradcheck, IS_WINDOWS, \ 24 skipIfTorchDynamo, xfailIfTorchDynamo 25from torch.testing._internal.common_device_type import ( 26 instantiate_device_type_tests, 27 dtypes, 28) 29 30# batched grad doesn't support mkldnn 31gradcheck = functools.partial(gradcheck, check_batched_grad=False) 32gradgradcheck = functools.partial(gradgradcheck, check_batched_grad=False) 33 34 35types = [torch.float, torch.bfloat16, torch.half] 36 37# Comment the line below to find out the CI machines having MKL-DNN build disabled 38@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled") 39class TestMkldnn(TestCase): 40 def test_conversion(self): 41 for cpu_tensor in [torch.randn((1, 2, 3, 4), 42 dtype=torch.float, device=torch.device('cpu')), 43 torch.randn((1, 2, 3, 4, 5), 44 dtype=torch.float, device=torch.device('cpu'))[:, :, :, :, 1]]: 45 cpu_tensor.requires_grad_() 46 convert_dtypes = {torch.half: [torch.half, torch.float], 47 torch.bfloat16: [torch.bfloat16, torch.float], 48 torch.float: [torch.bfloat16, torch.half]} 49 # float/bfloat16/half cpu tensor to mkldnn tensortensor. 50 for dtype1 in types: 51 mkldnn_tensor = cpu_tensor.to_mkldnn(dtype1) 52 self.assertEqual(mkldnn_tensor.dtype, dtype1) 53 cpu_tensor_1 = mkldnn_tensor.to_dense() 54 # not given dtype for to_dense, mkldnn tensor has same dtype with cpu tensor 55 self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype) 56 # mkldnn float/bfloat tensor to cpu float or bfloat tensor 57 for dtype2 in convert_dtypes[dtype1]: 58 cpu_tensor_2 = mkldnn_tensor.to_dense(dtype2) 59 self.assertEqual(cpu_tensor_2.dtype, dtype2) 60 atol = 1e-5 if dtype1 == torch.float and dtype2 == torch.float else 1e-2 61 self.assertEqual(cpu_tensor, cpu_tensor_2.float(), atol=atol, rtol=0) 62 63 self.assertEqual(mkldnn_tensor.device, torch.device('cpu')) 64 self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4])) 65 self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel()) 66 if dtype1 == torch.float: 67 self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size()) 68 else: 69 self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size() / 2) 70 self.assertRaisesRegex(RuntimeError, 71 "Cannot access data pointer of Tensor that doesn't have storage", 72 lambda: mkldnn_tensor.data_ptr() != 0) 73 74 # bfloat cpu tensor to mkldnn float tensor or bfloat tensor. 75 for orig_dtype in [torch.half, torch.bfloat16]: 76 cpu_tensor_lower = cpu_tensor.to(dtype=orig_dtype) 77 for dtype1 in convert_dtypes[orig_dtype]: 78 mkldnn_tensor = cpu_tensor_lower.to_mkldnn(dtype1) 79 self.assertEqual(mkldnn_tensor.dtype, dtype1) 80 cpu_tensor_1 = mkldnn_tensor.to_dense() 81 # not given dtype for to_dense, mkldnn tensor has same dtype with cpu tensor 82 self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype) 83 # mkldnn float/bfloat/half tensor to cpu float/bfloat/half tensor 84 for dtype2 in convert_dtypes[cpu_tensor_lower.dtype]: 85 cpu_tensor_2 = mkldnn_tensor.to_dense(dtype2) 86 self.assertEqual(cpu_tensor_2.dtype, dtype2) 87 self.assertEqual(cpu_tensor_lower, 88 cpu_tensor_2.to(dtype=cpu_tensor_lower.dtype), atol=1e-5, rtol=0) 89 90 self.assertEqual(mkldnn_tensor.device, torch.device('cpu')) 91 self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4])) 92 self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel()) 93 if dtype1 in [torch.bfloat16, torch.half]: 94 self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor_lower.element_size()) 95 else: 96 self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor_lower.element_size() * 2) 97 self.assertRaisesRegex(RuntimeError, 98 "Cannot access data pointer of Tensor that doesn't have storage", 99 lambda: mkldnn_tensor.data_ptr() != 0) 100 101 def test_conversion_byte_char(self): 102 int8_types = [torch.int8, torch.uint8] 103 for int8_type in int8_types: 104 low = -100 if int8_type is torch.int8 else 0 105 high = 100 106 for cpu_tensor in [torch.randint( 107 low=low, 108 high=high, 109 size=(1, 2, 3, 4), 110 dtype=torch.int64, 111 device=torch.device('cpu')), 112 torch.randint( 113 low=low, 114 high=high, 115 size=(1, 2, 3, 4, 5), 116 dtype=torch.int64, 117 device=torch.device('cpu'))[:, :, :, :, :]]: 118 119 cpu_tensor = cpu_tensor.to(dtype=int8_type) 120 mkldnn_tensor = cpu_tensor.to_mkldnn(int8_type) 121 self.assertEqual(mkldnn_tensor.dtype, int8_type) 122 cpu_tensor_1 = mkldnn_tensor.to_dense() 123 self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype) 124 self.assertEqual(cpu_tensor, cpu_tensor_1) 125 self.assertEqual(mkldnn_tensor.device, torch.device('cpu')) 126 self.assertEqual(mkldnn_tensor.size(), cpu_tensor.size()) 127 self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel()) 128 self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size()) 129 self.assertRaisesRegex(RuntimeError, 130 "Cannot access data pointer of Tensor that doesn't have storage", 131 lambda: mkldnn_tensor.data_ptr() != 0) 132 133 def test_copy(self): 134 x = torch.randn(4, 5, dtype=torch.float32) 135 mkldnn_x = x.to_mkldnn() 136 mkldnn_y = torch.randn(4, 5, dtype=torch.float32).to_mkldnn() 137 mkldnn_z = torch.randn(4, 10, dtype=torch.float32).to_mkldnn() 138 mkldnn_y.copy_(mkldnn_x) 139 self.assertEqual(x, mkldnn_y.to_dense()) 140 self.assertRaisesRegex(RuntimeError, 141 "copy_mkldnn_: only support same size tensor.", 142 lambda: mkldnn_z.copy_(mkldnn_x)) 143 self.assertRaisesRegex(RuntimeError, 144 "copy_mkldnn_: between mkldnn layout and dense Tensors is not implemented! " 145 "Found self type = torch.FloatTensor and src type = Mkldnntorch.FloatTensor", 146 lambda: x.copy_(mkldnn_x)) 147 self.assertRaisesRegex(RuntimeError, 148 "copy_mkldnn_: between mkldnn layout and dense Tensors is not implemented! " 149 "Found self type = Mkldnntorch.FloatTensor and src type = torch.FloatTensor", 150 lambda: mkldnn_x.copy_(x)) 151 152 def test_unsupported(self): 153 # unsupported types and unsupported types with gpu 154 for dtype in [torch.double, torch.uint8, torch.int8, 155 torch.short, torch.int, torch.long]: 156 with self.assertRaises(RuntimeError) as context: 157 torch.randn(1, 2, 3, 4, dtype=dtype, device=torch.device('cpu')).to_mkldnn() 158 if torch.cuda.is_available(): 159 with self.assertRaises(RuntimeError) as context: 160 torch.randn(1, 2, 3, 4, dtype=dtype, device=torch.device('cuda')).to_mkldnn() 161 # supported type with gpu 162 if torch.cuda.is_available(): 163 with self.assertRaises(RuntimeError) as context: 164 torch.randn(1, 2, 3, 4, dtype=torch.float, device=torch.device('cuda')).to_mkldnn() 165 # some factory functions 166 for creator in [torch.ones, torch.randn, torch.rand]: 167 with self.assertRaises(RuntimeError) as context: 168 creator(1, 2, 3, 4, dtype=torch.float, device=torch.device('cpu'), layout=torch._mkldnn) 169 170 def test_mkldnn_conv_shapecheck(self): 171 input = torch.full((1, 1, 1, 24,), 1, dtype=torch.float32) 172 w1 = torch.full((1, 1, 1, 24,), 1, dtype=torch.float32) 173 b1 = torch.full((1,), 1, dtype=torch.float32) 174 w2 = torch.full((1, 1, 2, 24,), 1, dtype=torch.float32) 175 b2 = torch.full((2,), 1, dtype=torch.float32) 176 options = zip([-1, 0, 0, 0, 0, 0, 0], # padding 177 [1, 0, 1, 1, 1, 1, 1], # stride 178 [1, 1, 0, 1, 1, 1, 1], # dilation 179 [1, 1, 1, 0, 2, 1, 1], # groups 180 [w1, w1, w1, w1, w1, w1, w2], # weight 181 [b1, b1, b1, b1, b1, b2, b1]) # bias 182 for pad, st, dil, gr, w, b in options: 183 with self.assertRaises(RuntimeError) as _: 184 torch.mkldnn_convolution(input, w, b, [pad] * 2, [st] * 2, [dil] * 2, gr) 185 186 def test_autograd_to_mkldnn(self): 187 # MKLDNN only supports float32 188 root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True) 189 190 def func(root): 191 return root.to_mkldnn().to_dense() 192 193 # because MKLDNN only supports float32, we need to lessen the precision. 194 # these numbers are just empirical results that seem to work. 195 self.assertWarnsRegex(UserWarning, 196 'double precision floating point', 197 lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2)) 198 self.assertWarnsRegex(UserWarning, 199 'double precision floating point', 200 lambda: gradgradcheck(func, [root], atol=4e-2, rtol=1e-2)) 201 202 def test_autograd_from_mkldnn(self): 203 # MKLDNN only supports float32 204 root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_() 205 206 def func(root): 207 return root.to_dense() 208 209 # because MKLDNN only supports float32, we need to lessen the precision. 210 # these numbers are just empirical results that seem to work. 211 self.assertWarnsRegex(UserWarning, 212 'double precision floating point', 213 lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2)) 214 215 def test_detach(self): 216 root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_() 217 218 detach = root.detach() 219 self.assertEqual((4, 5), detach.size()) 220 self.assertFalse(detach.requires_grad) 221 self.assertTrue(root.requires_grad) 222 223 detach_ = root.detach_() 224 self.assertEqual((4, 5), detach_.size()) 225 self.assertFalse(detach_.requires_grad) 226 self.assertFalse(root.requires_grad) 227 228 def test_repr(self): 229 self.assertTrue("layout=torch._mkldnn" in str(torch.randn((1, 2, 3, 4), 230 dtype=torch.float, device=torch.device('cpu')).to_mkldnn())) 231 232 def _test_conv_base(self, dim): 233 conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d} 234 input_shapes = {1: (224,), 2: (224, 224), 3: (55, 55, 55)} 235 options = itertools.product([True, False], [True, False], [1, 2], [1, 4]) 236 for train, bias, dilation, groups in options: 237 N = torch.randint(3, 10, (1,)).item() 238 M = torch.randint(1, 3, (1,)).item() * groups 239 C = torch.randint(1, 3, (1,)).item() * groups 240 x_shape = (N, C) + input_shapes[dim] 241 x = torch.randn(x_shape, dtype=torch.float32) 242 conv = conv_module[dim](in_channels=C, 243 out_channels=M, 244 kernel_size=3, 245 stride=2, 246 padding=1, 247 dilation=dilation, 248 bias=bias, 249 groups=groups).float() 250 x1 = x.clone() 251 x2 = x.clone().to_mkldnn() 252 if not train: 253 mkldnn_conv = mkldnn_utils.to_mkldnn(copy.deepcopy(conv)) 254 elif train and dim != 1: 255 # TODO: enable conv1d training. 256 x1.requires_grad_() 257 x2.requires_grad_() 258 mkldnn_conv = copy.deepcopy(conv) 259 with torch.backends.mkldnn.flags(enabled=False): 260 y_aten = conv(x1) 261 if train and dim != 1: 262 loss1 = y_aten.sum() 263 loss1.backward() 264 if not train or (train and dim != 1): 265 y_mkldnn = mkldnn_conv(x2).to_dense() 266 self.assertEqual(y_aten, y_mkldnn) 267 if not train: 268 self._test_serialization(mkldnn_conv, (x.to_mkldnn(),)) 269 self._test_tracing(mkldnn_conv, (x.to_mkldnn(),)) 270 elif dim != 1: 271 loss2 = y_mkldnn.sum() 272 loss2.backward() 273 self.assertTrue(x2.grad.is_mkldnn) 274 self.assertEqual(x1.grad, x2.grad.to_dense()) 275 self.assertEqual(conv.weight.grad, 276 mkldnn_conv.weight.grad, 277 atol=1e-3, 278 rtol=1e-3) 279 if bias: 280 self.assertEqual(conv.bias.grad, mkldnn_conv.bias.grad) 281 282 def test_conv1d(self): 283 self._test_conv_base(dim=1) 284 285 def test_conv2d(self): 286 self._test_conv_base(dim=2) 287 288 def test_conv3d(self): 289 self._test_conv_base(dim=3) 290 291 def _test_conv_deconv_lower_precision_base(self, dim, conv_module, dtype): 292 input_shapes = {1: (224,), 2: (224, 224), 3: (55, 55, 55)} 293 options = itertools.product([True, False], [1, 2], [1, 4]) 294 for bias, dilation, groups in options: 295 N = torch.randint(1, 3, (1,)).item() 296 M = torch.randint(1, 3, (1,)).item() * groups 297 C = torch.randint(1, 3, (1,)).item() * groups 298 x_shape = (N, C) + input_shapes[dim] 299 x = torch.randn(x_shape, dtype=torch.float32) 300 # TODO: remove this when group depthwise is supported: 301 if conv_module in [torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, 302 torch.nn.ConvTranspose3d] and groups > 1 and C == groups: 303 continue 304 conv = conv_module(in_channels=C, 305 out_channels=M, 306 kernel_size=3, 307 stride=2, 308 padding=1, 309 dilation=dilation, 310 bias=bias, 311 groups=groups).float() 312 x_lower = x.to(dtype=dtype) 313 if (dtype == torch.bfloat16 and torch.ops.mkldnn._is_mkldnn_bf16_supported()) or \ 314 (dtype == torch.half and torch.ops.mkldnn._is_mkldnn_fp16_supported()): 315 mkldnn_conv = mkldnn_utils.to_mkldnn(copy.deepcopy(conv)) 316 mkldnn_conv_lower = mkldnn_utils.to_mkldnn(copy.deepcopy(conv), dtype) 317 y = mkldnn_conv(x.to_mkldnn()).to_dense() 318 y_lower = mkldnn_conv_lower(x_lower.to_mkldnn()).to_dense(torch.float32) 319 self.assertEqual(y, y_lower, atol=1e-1, rtol=1e-3) 320 else: 321 msg = { 322 torch.bfloat16: r"bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq", 323 torch.half: r"fp16 path needs the cpu support avx_ne_convert or avx512_fp16", 324 } 325 with self.assertRaisesRegex(RuntimeError, msg[dtype]): 326 mkldnn_conv_lower = mkldnn_utils.to_mkldnn(copy.deepcopy(conv), dtype) 327 y_lower = mkldnn_conv_lower(x_lower.to_mkldnn()).to_dense(torch.float32) 328 # test thnn impl 329 conv_lower = copy.deepcopy(conv).to(dtype=dtype) 330 conv_ref = copy.deepcopy(conv_lower).float() 331 with torch.backends.mkldnn.flags(enabled=False): 332 x_ref = x_lower.clone().float().detach().requires_grad_() 333 x_lower.requires_grad_() 334 y = conv_ref(x_ref) 335 y_lower = conv_lower(x_lower).float() 336 self.assertEqual(y, y_lower, atol=5e-2, rtol=5e-3) 337 338 @dtypes(torch.float16, torch.bfloat16) 339 def test_conv_deconv_1d_lower_precision(self, dtype): 340 self._test_conv_deconv_lower_precision_base(1, torch.nn.Conv1d, dtype=dtype) 341 self._test_conv_deconv_lower_precision_base(1, torch.nn.ConvTranspose1d, dtype=dtype) 342 343 @dtypes(torch.float16, torch.bfloat16) 344 def test_conv_deconv_2d_lower_precision(self, dtype): 345 self._test_conv_deconv_lower_precision_base(2, torch.nn.Conv2d, dtype=dtype) 346 self._test_conv_deconv_lower_precision_base(2, torch.nn.ConvTranspose2d, dtype=dtype) 347 348 @dtypes(torch.float16, torch.bfloat16) 349 def test_conv_deconv_3d_lower_precision(self, dtype): 350 self._test_conv_deconv_lower_precision_base(3, torch.nn.Conv3d, dtype=dtype) 351 self._test_conv_deconv_lower_precision_base(3, torch.nn.ConvTranspose3d, dtype=dtype) 352 353 def _test_conv_deconv_nhwc_base(self, conv_module, weight_memory_format, dtype, prec=None): 354 input_shapes = {2: (55, 55), 3: (14, 14, 14)} 355 options = itertools.product([True, False], [True, False], [1, 2], [1, 4]) 356 if conv_module in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]: 357 cl_format = torch.channels_last 358 input_shape = input_shapes[2] 359 elif conv_module in [torch.nn.Conv3d, torch.nn.ConvTranspose3d]: 360 cl_format = torch.channels_last_3d 361 input_shape = input_shapes[3] 362 363 for train, bias, dilation, groups in options: 364 N = torch.randint(3, 10, (1,)).item() 365 M = torch.randint(1, 3, (1,)).item() * groups 366 C = torch.randint(1, 3, (1,)).item() * groups 367 x_shape = (N, C) + input_shape 368 x = torch.randn(x_shape, dtype=dtype) 369 370 # conv1: mkldnn conv/deconv in contiguous memory format (nchw) 371 # conv2: mkldnn conv/deconv in channels last memory format (nhwc) 372 conv1 = conv_module(in_channels=C, 373 out_channels=M, 374 kernel_size=3, 375 stride=2, 376 padding=1, 377 dilation=dilation, 378 bias=bias, 379 groups=groups).to(dtype=dtype) 380 conv2 = copy.deepcopy(conv1).to(memory_format=weight_memory_format) 381 x1 = x.clone() 382 x2 = x.clone().to(memory_format=cl_format) 383 if train: 384 x1.requires_grad_() 385 x2.requires_grad_() 386 y1 = conv1(x1) 387 y2 = conv2(x2) 388 self.assertEqual(y1, y2, atol=prec, rtol=prec) 389 390 if train: 391 y1.sum().backward() 392 y2.sum().backward() 393 self.assertTrue(x2.grad.is_contiguous(memory_format=cl_format)) 394 self.assertEqual(conv1.weight.grad, 395 conv2.weight.grad, 396 atol=1e-3, 397 rtol=1e-3) 398 if bias: 399 self.assertEqual(conv1.bias.grad, conv2.bias.grad, atol=prec, rtol=prec) 400 self.assertEqual(x1.grad, x2.grad, atol=prec, rtol=prec) 401 402 def test_conv_nhwc_fp32(self): 403 self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.float32) 404 self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.float32) 405 self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=torch.float32) 406 self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=torch.float32) 407 408 @dtypes(torch.float16, torch.bfloat16) 409 def test_conv_nhwc_lower_precision(self, dtype): 410 # when torch.ops.mkldnn._is_mkldnn_bf16_supported() or torch.ops.mkldnn._is_mkldnn_fp16_supported() 411 # returns false, bf16/fp16 CPU conv will fall back to thnn impl 412 support_checks = { 413 torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported, 414 torch.float16: torch.ops.mkldnn._is_mkldnn_fp16_supported 415 } 416 if support_checks[dtype](): 417 self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=dtype) 418 self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=dtype) 419 self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=dtype) 420 self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=dtype) 421 422 # BF16/FP16 fallback implementations are divided into two parts im2col+gemm, 423 # and the number of data type conversions in the middle is more than that of onednn's direct conv, 424 # resulting in additional accuracy loss. 425 precisions = { 426 torch.bfloat16: 1e-2, 427 torch.float16: 2e-3, 428 } 429 prec = precisions[dtype] 430 with torch.backends.mkldnn.flags(enabled=False): 431 self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=dtype, prec=prec) 432 self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=dtype, prec=prec) 433 self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=dtype, prec=prec) 434 self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=dtype, prec=prec) 435 436 437 def test_conv_transpose_nhwc_fp32(self): 438 self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.float32) 439 self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.float32) 440 self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=torch.float32) 441 self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=torch.float32) 442 443 @dtypes(torch.float16, torch.bfloat16) 444 def test_conv_transpose_nhwc_lower_precision(self, dtype): 445 # when torch.ops.mkldnn._is_mkldnn_bf16_supported() or torch.ops.mkldnn._is_mkldnn_fp16_supported() 446 # returns false, bf16/fp16 CPU conv will fall back to thnn impl 447 support_checks = { 448 torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported, 449 torch.float16: torch.ops.mkldnn._is_mkldnn_fp16_supported 450 } 451 if support_checks[dtype](): 452 self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=dtype) 453 self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=dtype) 454 self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=dtype) 455 self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=dtype) 456 457 # BF16/FP16 fallback implementations are divided into two parts col2im+gemm, 458 # and the number of data type conversions in the middle is more than that of onednn's direct conv, 459 # resulting in additional accuracy loss. 460 precisions = { 461 torch.bfloat16: 2e-2, 462 torch.float16: 3e-3, 463 } 464 prec = precisions[dtype] 465 with torch.backends.mkldnn.flags(enabled=False): 466 self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=dtype, prec=prec) 467 self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=dtype, prec=prec) 468 self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=dtype, prec=prec) 469 self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=dtype, prec=prec) 470 471 def _test_conv_transpose_base(self, dim): 472 conv_module = { 473 1: torch.nn.ConvTranspose1d, 474 2: torch.nn.ConvTranspose2d, 475 3: torch.nn.ConvTranspose3d 476 } 477 input_shapes = {1: (55,), 2: (28, 28), 3: (14, 14, 14)} 478 options = itertools.product([True, False], [True, False], [1, 2], [1, 4]) 479 for train, bias, dilation, groups in options: 480 N = torch.randint(3, 10, (1,)).item() 481 M = torch.randint(1, 3, (1,)).item() * groups 482 C = torch.randint(1, 3, (1,)).item() * groups 483 x_shape = (N, C) + input_shapes[dim] 484 data = torch.randn(x_shape, dtype=torch.float32) 485 # conv: mkldnn tranpose conv fp32 486 # conv_ref: thnn transpose conv fp32 487 conv = conv_module[dim](in_channels=C, 488 out_channels=M, 489 kernel_size=3, 490 stride=1, 491 padding=1, 492 dilation=dilation, 493 bias=bias, 494 groups=groups).to(dtype=torch.float32) 495 x = data.clone() 496 x_ref = x.clone() 497 if train: 498 x.requires_grad_() 499 x_ref.requires_grad_() 500 501 conv_ref = copy.deepcopy(conv) 502 with torch.backends.mkldnn.flags(enabled=False): 503 y_ref = conv_ref(x_ref) 504 if train: 505 y_ref.sum().backward() 506 507 y = conv(x) 508 if train: 509 y.sum().backward() 510 511 self.assertEqual(y, y_ref) 512 if train: 513 self.assertEqual(x.grad, x_ref.grad) 514 self.assertEqual(conv.weight.grad, 515 conv_ref.weight.grad, 516 atol=1e-3, 517 rtol=1e-3) 518 if bias: 519 self.assertEqual(conv.bias.grad, conv_ref.bias.grad) 520 521 def test_conv_transpose1d(self): 522 self._test_conv_transpose_base(dim=1) 523 524 def test_conv_transpose2d(self): 525 self._test_conv_transpose_base(dim=2) 526 527 def test_conv_transpose3d(self): 528 self._test_conv_transpose_base(dim=3) 529 530 def test_conv2d_legacy_jit_model(self): 531 """ 532 MKLDNN integration used to serialize models with 5d weight for grouped 533 convolutions, we'd like to preserve this behavior 534 """ 535 g = 4 536 conv2d = torch.nn.Conv2d(16, 16, 3, groups=g) 537 conv2d_mkldnn = torch.utils.mkldnn.to_mkldnn(conv2d) 538 539 # contrive legacy conv2d module with a 5-d weight 540 o, i, h, w = conv2d.weight.shape 541 weight_5d = conv2d.weight.reshape((g, o // g, i, h, w)) 542 conv2d_mkldnn.weight = weight_5d.to_mkldnn() 543 544 x = torch.randn(1, 16, 8, 8) 545 546 with TemporaryFileName() as fname: 547 torch.jit.save(conv2d_mkldnn, fname) 548 conv2d_loaded = torch.jit.load(fname) 549 550 self.assertEqual(conv2d_mkldnn.weight.ndimension(), 5) 551 self.assertEqual(conv2d_loaded.weight.ndimension(), 4) 552 self.assertEqual( 553 conv2d(x), 554 conv2d_loaded(x.to_mkldnn()).to_dense()) 555 556 # This test is to check whether 1D conv is supported for mkldnn tensor, 557 # which is exposed by Issue https://github.com/pytorch/pytorch/issues/68034. 558 def test_conv1d_functional(self): 559 input = torch.randn(2, 3, 10).to_mkldnn() 560 weight = torch.randn(3, 3, 3).to_mkldnn() 561 bias = torch.randn(3).to_mkldnn() 562 output = torch.nn.functional.conv1d(input, weight, bias) 563 self.assertEqual(output.size(), torch.Size([2, 3, 8])) 564 565 def test_relu(self): 566 x = torch.randn((4, 5), dtype=torch.float32) * 10 567 x1 = x.clone().requires_grad_() 568 x2 = x.clone().to_mkldnn().requires_grad_() 569 y1 = torch.relu(x1) 570 y2 = torch.relu(x2).to_dense() 571 loss1 = y1.sum() 572 loss2 = y2.sum() 573 loss1.backward() 574 loss2.backward() 575 self.assertEqual(y1, y2) 576 self.assertEqual(x1.grad, x2.grad.to_dense()) 577 578 def test_relu_(self): 579 x = torch.randn((4, 5), dtype=torch.float32) * 10 580 x1 = x.clone().requires_grad_() 581 x2 = x.clone().to_mkldnn().requires_grad_() 582 y1 = torch.relu_(x1.clone()) 583 y2 = torch.relu_(x2.clone()).to_dense() 584 loss1 = y1.sum() 585 loss2 = y2.sum() 586 loss1.backward() 587 loss2.backward() 588 self.assertEqual(y1, y2) 589 self.assertEqual(x1.grad, x2.grad.to_dense()) 590 591 @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path") 592 def _test_relu_bf16_base(self, name): 593 x = torch.randn((4, 5), dtype=torch.float32) * 10 594 x_bf16 = x.bfloat16() 595 fn = getattr(torch, name) 596 if torch.ops.mkldnn._is_mkldnn_bf16_supported(): 597 y = fn(x.to_mkldnn()).to_dense() 598 y_bf16 = fn(x_bf16.to_mkldnn()).to_dense(torch.float32) 599 self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3) 600 else: 601 msg = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" 602 self.assertRaisesRegex(RuntimeError, 603 msg, 604 lambda: fn(x_bf16.to_mkldnn())) 605 606 def test_relu_bf16(self): 607 self._test_relu_bf16_base("relu") 608 609 def test_relu_inplace_bf16(self): 610 self._test_relu_bf16_base("relu_") 611 612 def test_gelu(self): 613 m = torch.nn.GELU() 614 x = torch.randn((4, 5), dtype=torch.float32) * 10 615 x1 = x.clone().requires_grad_() 616 x2 = x.clone().to_mkldnn().requires_grad_() 617 y1 = m(x1) 618 y2 = m(x2).to_dense() 619 loss1 = y1.sum() 620 loss2 = y2.sum() 621 loss1.backward() 622 loss2.backward() 623 self.assertEqual(y1, y2) 624 self.assertEqual(x1.grad, x2.grad.to_dense()) 625 626 @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path") 627 def test_gelu_bf16(self): 628 m = torch.nn.GELU() 629 x = torch.randn((4, 5), dtype=torch.float32) * 10 630 x1 = x.clone().to_mkldnn().requires_grad_() 631 x2 = x.clone().to_mkldnn(torch.bfloat16).requires_grad_() 632 if torch.ops.mkldnn._is_mkldnn_bf16_supported(): 633 y1 = m(x1).to_dense() 634 y2 = m(x2).to_dense() 635 loss1 = y1.sum() 636 loss2 = y2.sum() 637 loss1.backward() 638 loss2.backward() 639 self.assertEqual(y1, y2.to(torch.float32), atol=1e-1, rtol=0) 640 self.assertEqual(x1.grad.to_dense(), x2.grad.to_dense(torch.float32), atol=1e-2, rtol=0) 641 else: 642 msg = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" 643 self.assertRaisesRegex(RuntimeError, 644 msg, 645 lambda: m(x2)) 646 647 def _test_prelu_base(self, size, num_channels): 648 x = torch.randn(size, dtype=torch.float32) 649 x1 = x.clone().requires_grad_() 650 x2 = x.clone().to_mkldnn().requires_grad_() 651 x3 = x.clone().to_mkldnn().requires_grad_() 652 m1 = torch.nn.PReLU(num_channels) 653 m2 = mkldnn_utils.to_mkldnn(copy.deepcopy(m1)) 654 m3 = copy.deepcopy(m1) 655 y1 = m1(x1) 656 y2 = m2(x2).to_dense() 657 y3 = m3(x3).to_dense() # Only convert data to mkldnn, weight is Aten tensor 658 loss1 = y1.sum() 659 loss1.backward() 660 loss2 = y2.sum() 661 loss2.backward() 662 loss3 = y3.sum() 663 loss3.backward() 664 self.assertEqual(y1, y2) 665 self.assertEqual(y1, y3) 666 self.assertEqual(x1.grad, x2.grad.to_dense()) 667 self.assertEqual(x1.grad, x3.grad.to_dense()) 668 669 def test_prelu(self): 670 self._test_prelu_base(torch.Size([16]), 1) 671 self._test_prelu_base(torch.Size([16, 64]), 1) 672 self._test_prelu_base(torch.Size([16, 64]), 64) 673 self._test_prelu_base(torch.Size([16, 64, 112]), 1) 674 self._test_prelu_base(torch.Size([16, 64, 112]), 64) 675 self._test_prelu_base(torch.Size([16, 64, 112, 112]), 1) 676 self._test_prelu_base(torch.Size([16, 64, 112, 112]), 64) 677 self._test_prelu_base(torch.Size([16, 64, 112, 112, 1]), 1) 678 self._test_prelu_base(torch.Size([16, 64, 112, 112, 1]), 64) 679 680 @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path") 681 def _test_prelu_bf16_base(self, size, num_channels): 682 if torch.ops.mkldnn._is_mkldnn_bf16_supported(): 683 x = torch.randn(size, dtype=torch.float32) 684 x_fp32 = x.clone().to_mkldnn().requires_grad_() 685 x_bf16 = x.clone().to_mkldnn(torch.bfloat16).requires_grad_() 686 m = mkldnn_utils.to_mkldnn(torch.nn.PReLU()) 687 m_bf16 = mkldnn_utils.to_mkldnn(torch.nn.PReLU(), torch.bfloat16) 688 689 y = m(x_fp32).to_dense() 690 y_bf16 = m_bf16(x_bf16).to_dense() 691 self.assertEqual(y, y_bf16.to(torch.float32), atol=1e-1, rtol=1e-3) 692 693 loss = y.sum() 694 loss.backward() 695 loss_bf16 = y_bf16.sum() 696 loss_bf16.backward() 697 self.assertEqual(x_fp32.grad.to_dense(), x_bf16.grad.to_dense(torch.float32)) 698 else: 699 x_bf16 = torch.randn(size, dtype=torch.bfloat16).requires_grad_() 700 m_bf16 = mkldnn_utils.to_mkldnn(torch.nn.PReLU(), torch.bfloat16) 701 msg = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" 702 self.assertRaisesRegex(RuntimeError, 703 msg, 704 lambda: m_bf16(x_bf16)) 705 706 def test_prelu_bf16(self): 707 self._test_prelu_bf16_base(torch.Size([16]), 1) 708 self._test_prelu_bf16_base(torch.Size([16, 64]), 1) 709 self._test_prelu_bf16_base(torch.Size([16, 64]), 64) 710 self._test_prelu_bf16_base(torch.Size([16, 64, 112]), 1) 711 self._test_prelu_bf16_base(torch.Size([16, 64, 112]), 64) 712 self._test_prelu_bf16_base(torch.Size([16, 64, 112, 112, 1]), 1) 713 self._test_prelu_bf16_base(torch.Size([16, 64, 112, 112, 1]), 64) 714 715 def _test_max_pool_base(self, dim, input): 716 pool_module = {2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d} 717 for stride in [1, 2, 3]: 718 for ceil_mode in [False, True]: 719 max_pool = pool_module[dim]( 720 kernel_size=3 if not ceil_mode else 7, 721 stride=stride, 722 padding=1, 723 ceil_mode=ceil_mode) 724 725 x1 = input.clone().requires_grad_() 726 x2 = input.clone().to_mkldnn().requires_grad_() 727 y1 = max_pool(x1) 728 y2 = max_pool(x2).to_dense() 729 loss1 = y1.sum() 730 loss2 = y2.sum() 731 loss1.backward() 732 loss2.backward() 733 self.assertEqual(y1, y2) 734 self.assertEqual(x1.grad, x2.grad.to_dense()) 735 736 def test_max_pool2d(self): 737 N = torch.randint(3, 10, (1,)).item() 738 C = torch.randint(3, 10, (1,)).item() 739 for H, W in [(64, 64), (35, 39), (16, 19), [7, 8]]: 740 x = torch.randn(N, C, H, W, dtype=torch.float32) * 10 741 self._test_max_pool_base(dim=2, input=x) 742 743 def test_max_pool3d(self): 744 N = torch.randint(3, 10, (1,)).item() 745 C = torch.randint(3, 10, (1,)).item() 746 for D, H, W in [(64, 64, 64), (35, 39, 35), (16, 19, 20), [7, 8, 9]]: 747 x = torch.randn(N, C, D, H, W, dtype=torch.float32) * 10 748 self._test_max_pool_base(dim=3, input=x) 749 750 751 @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path") 752 def _test_max_pool_bf16_base(self, dim, input): 753 pool_module = {2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d} 754 x_bf16 = input.bfloat16() 755 for stride in [1, 2, 3]: 756 for ceil_mode in [False, True]: 757 max_pool = pool_module[dim]( 758 kernel_size=3 if not ceil_mode else 7, 759 stride=stride, 760 padding=1, 761 ceil_mode=ceil_mode) 762 763 if torch.ops.mkldnn._is_mkldnn_bf16_supported(): 764 y = max_pool(input.to_mkldnn()).to_dense() 765 y_bf16 = max_pool(x_bf16.to_mkldnn()).to_dense(torch.float32) 766 self.assertEqual(y, y_bf16, atol=0.1, rtol=1e-3) 767 else: 768 msg = "mkldnn_max_pool%dd: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" % dim 769 self.assertRaisesRegex(RuntimeError, 770 msg, 771 lambda: max_pool(x_bf16.to_mkldnn())) 772 773 def test_max_pool2d_bf16(self): 774 N = torch.randint(3, 10, (1,)).item() 775 C = torch.randint(3, 10, (1,)).item() 776 for H, W in [(64, 64), (35, 39), (16, 19), [7, 8]]: 777 x = torch.randn(N, C, H, W, dtype=torch.float32) * 10 778 self._test_max_pool_bf16_base(dim=2, input=x) 779 780 def test_max_pool3d_bf16(self): 781 N = torch.randint(3, 10, (1,)).item() 782 C = torch.randint(3, 10, (1,)).item() 783 for D, H, W in [(64, 64, 64), (35, 39, 35), (16, 19, 20), [7, 8, 9]]: 784 x = torch.randn(N, C, D, H, W, dtype=torch.float32) * 10 785 self._test_max_pool_bf16_base(dim=3, input=x) 786 787 def test_max_pool2d_stride_none(self): 788 N = torch.randint(3, 10, (1,)).item() 789 C = torch.randint(3, 10, (1,)).item() 790 791 for H, W in [(64, 64), (35, 39), (16, 19), [7, 8]]: 792 x = torch.randn(N, C, H, W, dtype=torch.float32) * 10 793 for ceil_mode in [False, True]: 794 y1 = F.max_pool2d( 795 x, 796 kernel_size=3 if not ceil_mode else 7, 797 stride=None, 798 padding=1, 799 ceil_mode=ceil_mode) 800 801 y2 = F.max_pool2d( 802 x.to_mkldnn(), 803 kernel_size=3 if not ceil_mode else 7, 804 stride=None, 805 padding=1, 806 ceil_mode=ceil_mode) 807 808 self.assertEqual(y1, y2.to_dense()) 809 810 # https://github.com/pytorch/pytorch/issues/127111 811 @xfailIfTorchDynamo 812 def test_max_pool_unsupported(self): 813 # OneDNN not support dilation max_pooling, will be avilabled in v2.0. 814 N = torch.randint(3, 10, (1,)).item() 815 C = torch.randint(3, 10, (1,)).item() 816 817 # 2d dilation case 818 x = torch.randn(N, C, 7, 7, dtype=torch.float32).to_mkldnn() 819 max_pool2d = torch.nn.MaxPool2d( 820 kernel_size=3, 821 stride=3, 822 padding=1, 823 dilation=2) 824 self.assertRaisesRegex(RuntimeError, 825 'mkldnn_max_pool2d does not support dilation case', 826 lambda: max_pool2d(x)) 827 828 # 3d dilation case 829 x = torch.randn(N, C, 7, 7, 7, dtype=torch.float32).to_mkldnn() 830 max_pool3d = torch.nn.MaxPool3d( 831 kernel_size=3, 832 stride=3, 833 padding=1, 834 dilation=2) 835 self.assertRaisesRegex(RuntimeError, 836 'mkldnn_max_pool3d does not support dilation case', 837 lambda: max_pool3d(x)) 838 839 def _test_avg_pool_base(self, dim, input): 840 avg_module = {2: torch.nn.AvgPool2d, 3: torch.nn.AvgPool3d} 841 for count_include_pad in [True, False]: 842 avg_pool = avg_module[dim]( 843 kernel_size=3, 844 stride=2, 845 padding=1, 846 count_include_pad=count_include_pad) 847 848 x1 = input.clone().requires_grad_() 849 x2 = input.clone().to_mkldnn().requires_grad_() 850 y1 = avg_pool(x1) 851 y2 = avg_pool(x2).to_dense() 852 loss1 = y1.sum() 853 loss2 = y2.sum() 854 loss1.backward() 855 loss2.backward() 856 self.assertEqual(y1, y2) 857 self.assertEqual(x1.grad, x2.grad.to_dense()) 858 859 def test_avg_pool2d(self): 860 N = torch.randint(3, 10, (1,)).item() 861 C = torch.randint(3, 10, (1,)).item() 862 x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10 863 self._test_avg_pool_base(dim=2, input=x) 864 865 def test_avg_pool3d(self): 866 N = torch.randint(3, 10, (1,)).item() 867 C = torch.randint(3, 10, (1,)).item() 868 x = torch.randn(N, C, 64, 64, 64, dtype=torch.float32) * 10 869 self._test_avg_pool_base(dim=3, input=x) 870 871 @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path") 872 def _test_avg_pool_bf16_base(self, dim, input): 873 avg_module = {2: torch.nn.AvgPool2d, 3: torch.nn.AvgPool3d} 874 x_bf16 = input.bfloat16() 875 for count_include_pad in [True, False]: 876 avg_pool = avg_module[dim]( 877 kernel_size=3, 878 stride=2, 879 padding=1, 880 count_include_pad=count_include_pad) 881 if torch.ops.mkldnn._is_mkldnn_bf16_supported(): 882 y = avg_pool(input.to_mkldnn()).to_dense() 883 y_bf16 = avg_pool(x_bf16.to_mkldnn()).to_dense(torch.float) 884 self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3) 885 else: 886 msg = "mkldnn_avg_pool%dd: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" % dim 887 self.assertRaisesRegex(RuntimeError, 888 msg, 889 lambda: avg_pool(x_bf16.to_mkldnn())) 890 891 def test_avg_pool2d_bf16(self): 892 N = torch.randint(3, 10, (1,)).item() 893 C = torch.randint(3, 10, (1,)).item() 894 x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10 895 self._test_avg_pool_bf16_base(dim=2, input=x) 896 897 def test_avg_pool3d_bf16(self): 898 N = torch.randint(3, 10, (1,)).item() 899 C = torch.randint(3, 10, (1,)).item() 900 x = torch.randn(N, C, 64, 64, 64, dtype=torch.float32) * 10 901 self._test_avg_pool_bf16_base(dim=3, input=x) 902 903 def test_avg_pool2d_stride_none(self): 904 N = torch.randint(3, 10, (1,)).item() 905 C = torch.randint(3, 10, (1,)).item() 906 x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10 907 908 for count_include_pad in [True, False]: 909 y1 = F.avg_pool2d( 910 x, 911 kernel_size=3, 912 stride=None, 913 padding=1, 914 count_include_pad=count_include_pad) 915 y2 = F.avg_pool2d( 916 x.to_mkldnn(), 917 kernel_size=3, 918 stride=None, 919 padding=1, 920 count_include_pad=count_include_pad) 921 922 self.assertEqual(y1, y2.to_dense()) 923 924 def test_adaptive_avg_pool2d(self): 925 N = torch.randint(3, 10, (1,)).item() 926 C = torch.randint(3, 10, (1,)).item() 927 x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100 928 929 adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(7) 930 x1 = x.clone().requires_grad_() 931 x2 = x.clone().to_mkldnn().requires_grad_() 932 y1 = adaptive_avg_pool2d(x1) 933 y2 = adaptive_avg_pool2d(x2).to_dense() 934 935 loss1 = y1.sum() 936 loss2 = y2.sum() 937 loss1.backward() 938 loss2.backward() 939 940 self.assertEqual(y1, y2) 941 self.assertEqual(x1.grad, x2.grad.to_dense()) 942 943 @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path") 944 def test_adaptive_avg_pool2d_bf16(self): 945 N = torch.randint(3, 10, (1,)).item() 946 C = torch.randint(3, 10, (1,)).item() 947 x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100 948 949 x_bf16 = x.bfloat16() 950 adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(7) 951 952 if torch.ops.mkldnn._is_mkldnn_bf16_supported(): 953 y = adaptive_avg_pool2d(x.to_mkldnn()).to_dense() 954 y_bf16 = adaptive_avg_pool2d(x.to_mkldnn()).to_dense(torch.float32) 955 self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3) 956 else: 957 msg = "mkldnn_adaptive_avg_pool2d: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" 958 self.assertRaisesRegex(RuntimeError, 959 msg, 960 lambda: adaptive_avg_pool2d(x_bf16.to_mkldnn())) 961 962 def _test_batch_norm_base(self, dim, channels, input): 963 bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d} 964 bn = bn_module[dim](channels).float().train(False) 965 mkldnn_bn = mkldnn_utils.to_mkldnn(copy.deepcopy(bn)) 966 self.assertEqual( 967 bn(input), 968 mkldnn_bn(input.to_mkldnn()).to_dense()) 969 970 self._test_serialization(mkldnn_bn, (input.to_mkldnn(),)) 971 self._test_tracing(mkldnn_bn, (input.to_mkldnn(),)) 972 973 def _test_batch_norm_train_base(self, dim, channels, input): 974 # TODO: support 3d batchnorm training. 975 bn_module = {2 : torch.nn.BatchNorm2d} 976 # TODO: support none affine. 977 options = itertools.product([True], [True, False]) 978 for affine, track_running_stats in options: 979 bn = bn_module[dim]( 980 num_features=channels, 981 affine=affine, 982 track_running_stats=track_running_stats).float().train(True) 983 mkldnn_bn = copy.deepcopy(bn) 984 x1 = input.clone().requires_grad_() 985 x2 = input.clone().to_mkldnn().requires_grad_() 986 y1 = bn(x1) 987 y2 = mkldnn_bn(x2).to_dense() 988 loss1 = y1.sum() 989 loss2 = y2.sum() 990 loss1.backward() 991 loss2.backward() 992 self.assertEqual(y1, y2) 993 self.assertEqual(x1.grad, x2.grad.to_dense()) 994 self.assertEqual(bn.weight.grad, mkldnn_bn.weight.grad, rtol=1e-3, atol=1e-3) 995 if track_running_stats: 996 self.assertEqual(bn.running_mean, mkldnn_bn.running_mean) 997 self.assertEqual(bn.running_var, mkldnn_bn.running_var, rtol=1e-5, atol=1e-5) 998 999 def test_batch_norm_2d(self): 1000 N = torch.randint(3, 10, (1,)).item() 1001 C = torch.randint(3, 100, (1,)).item() 1002 x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10 1003 self._test_batch_norm_base(dim=2, channels=C, input=x) 1004 self._test_batch_norm_train_base(dim=2, channels=C, input=x) 1005 1006 def test_batch_norm_3d(self): 1007 N = torch.randint(3, 10, (1,)).item() 1008 C = torch.randint(3, 100, (1,)).item() 1009 x = torch.randn(N, C, 30, 30, 30, dtype=torch.float32) * 10 1010 self._test_batch_norm_base(dim=3, channels=C, input=x) 1011 1012 @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path") 1013 def _test_batch_norm_bf16_base(self, dim, channels, input): 1014 bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d} 1015 x_bf16 = input.bfloat16() 1016 # TODO: support training 1017 for train in [False]: 1018 bn = bn_module[dim](channels).float().train(train) 1019 mkldnn_bn = mkldnn_utils.to_mkldnn(copy.deepcopy(bn)) 1020 if torch.ops.mkldnn._is_mkldnn_bf16_supported(): 1021 y = bn(input.to_mkldnn().to_dense()) 1022 y_bf16 = bn(input.to_mkldnn().to_dense(torch.float)) 1023 self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3) 1024 else: 1025 msg = "mkldnn_batch_norm: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" 1026 self.assertRaisesRegex(RuntimeError, 1027 msg, 1028 lambda: bn(x_bf16.to_mkldnn())) 1029 1030 def test_batch_norm_2d_bf16(self): 1031 N = torch.randint(3, 10, (1,)).item() 1032 C = torch.randint(3, 100, (1,)).item() 1033 x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10 1034 self._test_batch_norm_bf16_base(dim=2, channels=C, input=x) 1035 1036 def test_batch_norm_3d_bf16(self): 1037 N = torch.randint(3, 10, (1,)).item() 1038 C = torch.randint(3, 100, (1,)).item() 1039 x = torch.randn(N, C, 30, 30, 30, dtype=torch.float32) * 10 1040 self._test_batch_norm_bf16_base(dim=3, channels=C, input=x) 1041 1042 def test_add(self): 1043 N = torch.randint(3, 10, (1,)).item() 1044 C = torch.randint(3, 100, (1,)).item() 1045 alpha = torch.randn(1, dtype=torch.float32).item() 1046 1047 x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10 1048 y = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10 1049 mx = x.to_mkldnn() 1050 my = y.to_mkldnn() 1051 1052 # add 1053 self.assertEqual( 1054 x + y, 1055 (mx + my).to_dense()) 1056 1057 self.assertEqual( 1058 torch.add(x, y, alpha=alpha), 1059 torch.add(mx, my, alpha=alpha).to_dense()) 1060 1061 # add_ 1062 x += y 1063 mx += my 1064 self.assertEqual(x, mx.to_dense()) 1065 1066 # add_out 1067 out = x.clone() 1068 mkldnn_out = out.to_mkldnn() 1069 torch.add(x, y, alpha=alpha, out=out) 1070 torch.add(mx, my, alpha=alpha, out=mkldnn_out) 1071 self.assertEqual(out, mkldnn_out.to_dense()) 1072 1073 # add_out inplace case: first input 1074 torch.add(x, y, alpha=alpha, out=x) 1075 torch.add(mx, my, alpha=alpha, out=mx) 1076 self.assertEqual(x, mx.to_dense()) 1077 1078 # add_out inplace case: second input 1079 torch.add(x, y, alpha=alpha, out=y) 1080 torch.add(mx, my, alpha=alpha, out=my) 1081 self.assertEqual(y, my.to_dense()) 1082 1083 def test_mul(self): 1084 N = torch.randint(3, 10, (1,)).item() 1085 C = torch.randint(3, 100, (1,)).item() 1086 value = torch.randn(1, dtype=torch.float32).item() 1087 1088 x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10 1089 y = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10 1090 mx = x.to_mkldnn() 1091 my = y.to_mkldnn() 1092 1093 # mul 1094 self.assertEqual( 1095 x * y, 1096 (mx * my).to_dense()) 1097 1098 self.assertEqual( 1099 x * value, 1100 (mx * value).to_dense()) 1101 1102 self.assertEqual( 1103 torch.mul(x, y), 1104 torch.mul(mx, my).to_dense()) 1105 1106 self.assertEqual( 1107 torch.mul(x, value), 1108 torch.mul(mx, value).to_dense()) 1109 1110 # mul_ 1111 x *= y 1112 mx *= my 1113 self.assertEqual(x, mx.to_dense()) 1114 1115 x *= value 1116 mx *= value 1117 self.assertEqual(x, mx.to_dense()) 1118 1119 # mul_out 1120 out = x.clone() 1121 mkldnn_out = out.to_mkldnn() 1122 torch.mul(x, y, out=out) 1123 torch.mul(mx, my, out=mkldnn_out) 1124 self.assertEqual(out, mkldnn_out.to_dense()) 1125 1126 out = x.clone() 1127 mkldnn_out = out.to_mkldnn() 1128 torch.mul(x, value, out=out) 1129 torch.mul(mx, value, out=mkldnn_out) 1130 self.assertEqual(out, mkldnn_out.to_dense()) 1131 1132 def test_0_dimension_tensor(self): 1133 x = torch.rand([20, 20, 1, 1], dtype=torch.float) 1134 y = torch.rand([20, 20, 0, 1], dtype=torch.float) 1135 1136 # unary ops work without modification 1137 out_relu = torch.relu(y) 1138 out_relu_mkldnn = torch.relu(y.to_mkldnn()).to_dense() 1139 self.assertEqual(out_relu, out_relu_mkldnn) 1140 1141 out_mul = x * y 1142 out_mul_mkldnn = (x.to_mkldnn() * y.to_mkldnn()).to_dense() 1143 self.assertEqual(out_mul, out_mul_mkldnn) 1144 1145 out_add = x + y 1146 out_add_mkldnn = (x.to_mkldnn() + y.to_mkldnn()).to_dense() 1147 self.assertEqual(out_add, out_add_mkldnn) 1148 1149 x.requires_grad_(True) 1150 y.requires_grad_(True) 1151 with self.assertRaisesRegex(RuntimeError, "0-dimension Tensor in training"): 1152 x.to_mkldnn() + y.to_mkldnn() 1153 1154 with self.assertRaisesRegex(RuntimeError, "must match"): 1155 torch.rand([5]).to_mkldnn() + torch.rand([0]).to_mkldnn() 1156 1157 C = 7 1158 m = torch.nn.Conv2d(C, C, 3) 1159 x = torch.randn(0, C, C, 8, dtype=torch.float) 1160 out_eager = m(x) 1161 out_mkldnn = mkldnn_utils.to_mkldnn(m)(x) 1162 self.assertEqual(out_eager, out_mkldnn) 1163 1164 # https://github.com/pytorch/pytorch/issues/127111 1165 @xfailIfTorchDynamo 1166 def test_view(self): 1167 x = torch.randn(3, 4, 5, dtype=torch.float32).to_mkldnn() 1168 self.assertRaisesRegex(RuntimeError, 1169 "Change to use reshape", 1170 lambda: x.view(x.size(0), -1)) 1171 1172 def test_reshape(self): 1173 x = torch.randn(3, 4, 5, dtype=torch.float32) * 10 1174 size = (x.size(0), -1) 1175 1176 self.assertEqual( 1177 x.reshape(size), 1178 x.to_mkldnn().reshape(size).to_dense(), 1179 ) 1180 # test whether share same memory for plain format tensor 1181 y = x.to_mkldnn() 1182 z = y.reshape(size).add_(y.reshape(size)) 1183 self.assertEqual( 1184 y.reshape(size).to_dense(), 1185 z.to_dense(), 1186 ) 1187 1188 def test_reshape_blocked_format(self): 1189 # construct an mkldnn blocked tensor with mkldnn conv2d 1190 C = 7 1191 m = mkldnn_utils.to_mkldnn(torch.nn.Conv2d(C, C, 3)) 1192 x = torch.randn(1, C, 8, 8).to_mkldnn() 1193 1194 # mkldnn tensor w/ blocked format 1195 y_block = m(x) 1196 # aten tensor w/ plain format 1197 y_plain = y_block.to_dense() 1198 1199 y_block_reshape = y_block.reshape(C, -1) 1200 y_plain_reshape = y_plain.reshape(C, -1) 1201 1202 self.assertEqual(y_plain_reshape, y_block_reshape.to_dense()) 1203 1204 def test_reshape_backward(self): 1205 x = torch.randn(3, 4, 5, dtype=torch.float32) * 10 1206 size = (x.size(0), -1) 1207 1208 x1 = x.clone().requires_grad_() 1209 x2 = x.clone().to_mkldnn().requires_grad_() 1210 in_features = 20 1211 out_features = torch.randint(3, 100, (1,)).item() 1212 linear = torch.nn.Linear(in_features, out_features).float() 1213 1214 y1 = linear(x1.reshape(size)).sum() 1215 y2 = linear(x2.reshape(size).to_dense()).sum() 1216 y1.backward() 1217 y2.backward() 1218 self.assertEqual(x1.grad, x2.grad.to_dense()) 1219 1220 def test_clone(self): 1221 x = torch.randn(4, 5, dtype=torch.float32) * 10 1222 self.assertEqual( 1223 x.clone(), 1224 x.to_mkldnn().clone().to_dense(), 1225 ) 1226 # test whether share same memory 1227 y = x.to_mkldnn() 1228 z = y.clone().add_(y) 1229 self.assertNotEqual( 1230 y.to_dense(), 1231 z.to_dense(), 1232 ) 1233 1234 def test_transpose(self): 1235 x = torch.randn(3, 4, 5, dtype=torch.float32) * 10 1236 for dim1 in range(x.ndim): 1237 for dim2 in range(x.ndim): 1238 self.assertEqual( 1239 x.transpose(dim1, dim2), 1240 x.to_mkldnn().transpose(dim1, dim2).to_dense(), 1241 ) 1242 1243 def test_transpose_invalid_dime(self): 1244 x = torch.randn(3, 4, 5, dtype=torch.float32).to_mkldnn() 1245 with self.assertRaisesRegex(IndexError, "Dimension out of range"): 1246 torch._mkldnn_transpose(x, 0, 12) 1247 1248 def test_linear_non_contiguous_weight(self): 1249 in_features = torch.randint(3, 10, (1,)).item() 1250 out_features = torch.randint(3, 100, (1,)).item() 1251 x = torch.randn(3, in_features, dtype=torch.float32) * 10 1252 w = torch.randn(in_features, out_features, dtype=torch.float32) 1253 for bias in [True, False]: 1254 x1 = x.clone().requires_grad_() 1255 x2 = x.clone().to_mkldnn().requires_grad_() 1256 linear = torch.nn.Linear(in_features, out_features).float() 1257 linear.weight = torch.nn.Parameter(w.t()) 1258 mkldnn_linear = copy.deepcopy(linear) 1259 y1 = linear(x1).sum() 1260 y2 = mkldnn_linear(x2).to_dense().sum() 1261 y1.backward() 1262 y2.backward() 1263 self.assertEqual(x1.grad, x2.grad.to_dense()) 1264 self.assertEqual(linear.weight.grad, mkldnn_linear.weight.grad) 1265 if bias: 1266 self.assertEqual(linear.bias.grad, mkldnn_linear.bias.grad) 1267 1268 def test_linear(self): 1269 in_features = torch.randint(3, 10, (1,)).item() 1270 out_features = torch.randint(3, 100, (1,)).item() 1271 x = torch.randn(3, in_features, dtype=torch.float32) * 10 1272 1273 for bias in [True, False]: 1274 linear = torch.nn.Linear(in_features, out_features, bias=bias).float() 1275 mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear)) 1276 self.assertEqual( 1277 linear(x), 1278 mkldnn_linear(x.to_mkldnn()).to_dense()) 1279 1280 self._test_serialization(mkldnn_linear, (x.to_mkldnn(),)) 1281 self._test_tracing(mkldnn_linear, (x.to_mkldnn(),)) 1282 1283 def test_linear_backward(self): 1284 in_features = torch.randint(3, 10, (1,)).item() 1285 out_features = torch.randint(3, 100, (1,)).item() 1286 x = torch.randn(3, in_features, dtype=torch.float32) * 10 1287 for bias in [True, False]: 1288 x1 = x.clone().requires_grad_() 1289 x2 = x.clone().to_mkldnn().requires_grad_() 1290 linear = torch.nn.Linear(in_features, out_features).float() 1291 mkldnn_linear = copy.deepcopy(linear) 1292 y1 = linear(x1).sum() 1293 y2 = mkldnn_linear(x2).to_dense().sum() 1294 y1.backward() 1295 y2.backward() 1296 self.assertEqual(x1.grad, x2.grad.to_dense()) 1297 self.assertEqual(linear.weight.grad, mkldnn_linear.weight.grad) 1298 if bias: 1299 self.assertEqual(linear.bias.grad, mkldnn_linear.bias.grad) 1300 1301 @dtypes(torch.float16, torch.bfloat16) 1302 def test_linear_lowp(self, dtype): 1303 in_features = torch.randint(3, 10, (1,)).item() 1304 out_features = torch.randint(3, 100, (1,)).item() 1305 x = torch.randn(3, in_features, dtype=torch.float32) * 10 1306 x_lowp = x.to(dtype=dtype) 1307 1308 for bias in [True, False]: 1309 linear = torch.nn.Linear(in_features, out_features, bias=bias).float() 1310 mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear)) 1311 mkldnn_linear_lowp = mkldnn_utils.to_mkldnn( 1312 copy.deepcopy(linear), dtype 1313 ) 1314 lowp_support = { 1315 torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported, 1316 torch.half: torch.ops.mkldnn._is_mkldnn_fp16_supported, 1317 } 1318 if lowp_support[dtype](): 1319 y = mkldnn_linear(x.to_mkldnn()).to_dense() 1320 y_lowp = mkldnn_linear_lowp(x_lowp.to_mkldnn()).to_dense( 1321 torch.float32 1322 ) 1323 if dtype == torch.bfloat16: 1324 self.assertEqual(y, y_lowp, atol=1e-1, rtol=1e-3) 1325 else: 1326 self.assertEqual(y, y_lowp, atol=5e-3, rtol=1e-3) 1327 else: 1328 msg = { 1329 torch.bfloat16: r"bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq", 1330 torch.half: r"fp16 path needs the cpu support avx_ne_convert or avx512_fp16", 1331 } 1332 self.assertRaisesRegex( 1333 RuntimeError, 1334 msg[dtype], 1335 lambda: mkldnn_linear_lowp(x_lowp.to_mkldnn()), 1336 ) 1337 1338 def test_softmax(self): 1339 x = torch.randn(3, 4, 5, dtype=torch.float32) * 10 1340 for dim in range(x.ndim): 1341 softmax = torch.nn.Softmax(dim=dim) 1342 self.assertEqual( 1343 softmax(x), 1344 softmax(x.to_mkldnn()).to_dense()) 1345 1346 def test_sigmoid(self): 1347 x = torch.randn(4, 5, dtype=torch.float32) * 10 1348 mkldnn_x = x.to_mkldnn() 1349 self.assertEqual( 1350 torch.sigmoid(x), 1351 torch.sigmoid(mkldnn_x).to_dense(), 1352 ) 1353 # inplace 1354 torch.sigmoid_(x) 1355 torch.sigmoid_(mkldnn_x) 1356 self.assertEqual(x, mkldnn_x.to_dense()) 1357 1358 def test_tanh(self): 1359 x = torch.randn(4, 5, dtype=torch.float32) * 10 1360 mkldnn_x = x.to_mkldnn() 1361 self.assertEqual( 1362 torch.tanh(x), 1363 torch.tanh(mkldnn_x).to_dense(), 1364 ) 1365 # inplace 1366 torch.tanh_(x) 1367 torch.tanh_(mkldnn_x) 1368 self.assertEqual(x, mkldnn_x.to_dense()) 1369 1370 def _test_serialization(self, module, inputs): 1371 with TemporaryFileName() as fname: 1372 torch.jit.save(module, fname) 1373 loaded = torch.jit.load(fname) 1374 self.assertEqual( 1375 module(*inputs).to_dense(), 1376 loaded(*inputs).to_dense()) 1377 1378 def _test_tracing(self, module, inputs): 1379 traced = torch.jit.trace(module, inputs) 1380 self.assertEqual( 1381 module(*inputs).to_dense(), 1382 traced(*inputs).to_dense()) 1383 1384 def test_set_data_tensorimpl_type(self): 1385 # Dense tensor has impl of type `TensorImpl`, while MKL-DNN tensor has impl 1386 # of type `OpaqueTensorImpl<IDeepTensorWrapperPtr>`. 1387 x = torch.randn((1, 2), dtype=torch.float, device=torch.device('cpu')) 1388 x_mkldnn = x.to_mkldnn() 1389 with self.assertRaisesRegex(RuntimeError, 'incompatible tensor type'): 1390 x.data = x_mkldnn 1391 1392 def test_empty(self): 1393 x1 = torch.empty(4, 5, 2, 3, dtype=torch.float32) 1394 x2 = torch.empty(4, 5, 2, 3, dtype=torch.float32, layout=torch._mkldnn) 1395 self.assertEqual(x1.size(), x2.to_dense().size()) 1396 self.assertEqual(x1.dtype, x2.to_dense().dtype) 1397 1398 def test_zero_(self): 1399 x1 = torch.randn(4, 5, dtype=torch.float32) * 10 1400 x2 = x1.clone().to_mkldnn() 1401 self.assertEqual( 1402 x1.zero_(), 1403 x2.zero_().to_dense(), 1404 ) 1405 1406 def test_is_mkldnn(self): 1407 x = torch.randn(1, dtype=torch.float32) 1408 self.assertFalse(x.is_mkldnn) 1409 self.assertTrue(x.to_mkldnn().is_mkldnn) 1410 1411 # legacy constructor/new doesn't support mkldnn tensors 1412 @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1992") 1413 def test_legacy_new_failure(self): 1414 x = torch.randn(1, dtype=torch.float32) 1415 x_mkldnn = x.to_mkldnn() 1416 self.assertRaises(RuntimeError, lambda: x_mkldnn.new(device='cpu')) 1417 self.assertRaises(RuntimeError, lambda: x_mkldnn.new(x.storage())) 1418 self.assertRaises(RuntimeError, lambda: x_mkldnn.new(x)) 1419 self.assertRaises(RuntimeError, lambda: x_mkldnn.new(torch.Size([2, 3]))) 1420 self.assertRaises(RuntimeError, lambda: x_mkldnn.new([6])) 1421 1422 def test_is_mkldnn_jit(self): 1423 class EnsureMkldnn(torch.jit.ScriptModule): 1424 @torch.jit.script_method 1425 def forward(self, x): 1426 if not x.is_mkldnn: 1427 x = x.to_mkldnn() 1428 return x 1429 1430 m = EnsureMkldnn() 1431 x = torch.randn(1, dtype=torch.float32) 1432 self.assertTrue(m(x).is_mkldnn) 1433 self.assertTrue(m(x.to_mkldnn()).is_mkldnn) 1434 1435 def _test_imagenet_model(self, model): 1436 model = model.train(False).float() 1437 mkldnn_model = mkldnn_utils.to_mkldnn(copy.deepcopy(model)) 1438 x = torch.randn(1, 3, 224, 224, dtype=torch.float32) 1439 with torch.no_grad(): 1440 self.assertEqual( 1441 model(x), 1442 mkldnn_model(x.to_mkldnn()).to_dense(), 1443 ) 1444 1445 @skipIfNoTorchVision 1446 def test_resnet18(self): 1447 model = torchvision.models.resnet.resnet18(weights=None) 1448 self._test_imagenet_model(model) 1449 1450 @skipIfNoTorchVision 1451 def test_resnext50_32x4d(self): 1452 model = torchvision.models.resnet.resnext50_32x4d(weights=None) 1453 self._test_imagenet_model(model) 1454 1455 def _lstm_params_list(self): 1456 params_dict = { 1457 "input_size": [1, 5], 1458 "hidden_size": [5, 16], 1459 "num_layers": [1, 3], 1460 "bidirectional": [False, True], 1461 "bias": [False, True], 1462 "batch_first": [False, True], 1463 "dropout": [0, 0.4, 0.7, 1], 1464 "batch_size": [1, 2], 1465 "seq_len": [1, 3], 1466 "training": [False, True] 1467 } 1468 1469 params_list = list(params_dict.values()) 1470 return params_list 1471 1472 def _cast_dtype(self, input, dtype): 1473 if dtype == torch.bfloat16: 1474 input = input.to(torch.bfloat16) 1475 elif dtype == torch.half: 1476 input = input.to(torch.half) 1477 return input 1478 1479 def test_lstm(self): 1480 seed = 2023 1481 torch.manual_seed(seed) 1482 1483 params_list = self._lstm_params_list() 1484 for dtype in types: 1485 bf16 = dtype == torch.bfloat16 1486 fp16 = dtype == torch.half 1487 rtol = 1.3e-6 1488 atol = 1e-5 1489 1490 if bf16: 1491 rtol = 0.02 1492 atol = 0.02 1493 if fp16: 1494 rtol = 1e-3 1495 atol = 1e-3 1496 for input_size, hidden_size, num_layers, bidirectional, bias, batch_first, dropout, batch_size, seq_len, training \ 1497 in itertools.product(*params_list): 1498 num_directions = 2 if bidirectional else 1 1499 if batch_first: 1500 input = torch.randn(batch_size, seq_len, input_size, dtype=torch.float32) 1501 else: 1502 input = torch.randn(seq_len, batch_size, input_size, dtype=torch.float32) 1503 h = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32) 1504 c = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32) 1505 if fp16: 1506 # TODO add traing support when oneDNN support lstm FP16 training 1507 training = False 1508 model = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional, 1509 bias=bias, dropout=dropout, batch_first=batch_first).float() 1510 model.train() if training else model.eval() 1511 input1 = input.clone().requires_grad_(training) 1512 input2 = input.clone().requires_grad_(training) 1513 1514 h1 = h.clone().requires_grad_(training) 1515 h2 = h.clone().requires_grad_(training) 1516 c1 = c.clone().requires_grad_(training) 1517 c2 = c.clone().requires_grad_(training) 1518 1519 model1 = copy.deepcopy(model) 1520 model2 = copy.deepcopy(model) 1521 with torch.no_grad() if not training else nullcontext(): 1522 with torch.backends.mkldnn.flags(enabled=False): 1523 torch.manual_seed(seed) 1524 output1, (hn1, cn1) = self._cast_dtype(model1, dtype)( 1525 self._cast_dtype(input1, dtype), 1526 ( 1527 self._cast_dtype(h1, dtype), 1528 self._cast_dtype(c1, dtype), 1529 ), 1530 ) 1531 1532 torch.manual_seed(seed) 1533 output2, (hn2, cn2) = self._cast_dtype(model2, dtype)( 1534 self._cast_dtype(input2, dtype), 1535 ( 1536 self._cast_dtype(h2, dtype), 1537 self._cast_dtype(c2, dtype), 1538 ), 1539 ) 1540 self.assertEqual(output1, output2, rtol=rtol, atol=atol) 1541 self.assertEqual(hn1, hn2, rtol=rtol, atol=atol) 1542 self.assertEqual(cn1, cn2, rtol=rtol, atol=atol) 1543 1544 if training: 1545 with torch.backends.mkldnn.flags(enabled=False): 1546 torch.manual_seed(seed) 1547 output1.sum().backward(retain_graph=True) 1548 1549 torch.manual_seed(seed) 1550 output2.sum().backward(retain_graph=True) 1551 1552 self.assertEqual(input1.grad, input2.grad, rtol=rtol, atol=atol) 1553 for name, para in model1.named_parameters(): 1554 self.assertEqual(para, getattr(model2, name)) 1555 self.assertEqual( 1556 para.grad, 1557 getattr(model2, name).grad, 1558 rtol=rtol, 1559 atol=atol, 1560 ) 1561 1562 with torch.backends.mkldnn.flags(enabled=False): 1563 torch.manual_seed(seed) 1564 hn1.sum().backward(retain_graph=True) 1565 torch.manual_seed(seed) 1566 hn2.sum().backward(retain_graph=True) 1567 self.assertEqual(h1.grad, h2.grad, rtol=rtol, atol=atol) 1568 1569 with torch.backends.mkldnn.flags(enabled=False): 1570 torch.manual_seed(seed) 1571 cn1.sum().backward(retain_graph=True) 1572 torch.manual_seed(seed) 1573 cn2.sum().backward(retain_graph=True) 1574 self.assertEqual(c1.grad, c2.grad, rtol=rtol, atol=atol) 1575 1576 @dtypes(torch.float16, torch.bfloat16) 1577 def test_matmul_lower_precision(self, dtype): 1578 support_check = { 1579 torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported, 1580 torch.float16: torch.ops.mkldnn._is_mkldnn_fp16_supported, 1581 } 1582 1583 def common(self, shape1, shape2, op, dtype): 1584 a = torch.randn(shape1, dtype=dtype) 1585 a_ref = a.float() 1586 b = torch.randn(shape2, dtype=dtype) 1587 b_ref = b.float() 1588 1589 y = op(a, b) 1590 y_ref = op(a_ref, b_ref) 1591 self.assertEqual(y, y_ref, exact_dtype=False) 1592 1593 if support_check[dtype](): 1594 a1 = torch.randn([64, 1, 33], dtype=dtype) 1595 # a2 is contiguous tensor but it's strides 1596 # is not default contiguous strides. 1597 a2 = torch.as_strided(a1.clone(), [64, 1, 33], [33, 3, 1]) 1598 self.assertTrue(a2.is_contiguous()) 1599 b = torch.randn(64, 33, 256).to(dtype=dtype) 1600 y1 = torch.ops.aten.bmm(a1, b) 1601 y2 = torch.bmm(a2, b) 1602 self.assertEqual(y1, y2) 1603 1604 for shape1, shape2, op in [ 1605 ((33, 77), (77, 22), torch.matmul), 1606 ((128, 256), (256, 10), torch.matmul), 1607 ((7, 300), (300, 3), torch.matmul), 1608 ((1, 100), (100, 60), torch.matmul), 1609 ((100, 1), (1, 100), torch.matmul), 1610 ((20, 54, 78), (20, 78, 10), torch.bmm), 1611 ((1, 300, 1), (1, 1, 300), torch.bmm), 1612 ]: 1613 common(self, shape1, shape2, op, dtype) 1614 1615 1616instantiate_device_type_tests(TestMkldnn, globals(), only_for=('cpu',)) 1617 1618if __name__ == '__main__': 1619 run_tests() 1620