# Owner(s): ["module: mkldnn"]

import copy
import itertools
import functools
import unittest
from contextlib import nullcontext

try:
    import torchvision
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False

skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

import torch
import torch.nn.functional as F
import torch.jit
import torch.backends.mkldnn
from torch.utils import mkldnn as mkldnn_utils
from torch.testing._internal.common_utils import TestCase, \
    run_tests, TemporaryFileName, gradcheck, gradgradcheck, IS_WINDOWS, \
    skipIfTorchDynamo, xfailIfTorchDynamo
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    dtypes,
)

# batched grad doesn't support mkldnn
gradcheck = functools.partial(gradcheck, check_batched_grad=False)
gradgradcheck = functools.partial(gradgradcheck, check_batched_grad=False)


types = [torch.float, torch.bfloat16, torch.half]
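
# A quick sketch of the round trip these tests exercise (illustrative only):
#   t = torch.randn(1, 2, 3, 4)   # dense CPU tensor
#   m = t.to_mkldnn()             # opaque mkldnn-layout tensor
#   d = m.to_dense()              # back to a dense (strided) tensor
# mkldnn tensors have no accessible storage (data_ptr() raises), and both
# to_mkldnn() and to_dense() accept an optional target dtype; the conversion
# tests below verify exactly these properties.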

# Comment the line below to find out the CI machines having MKL-DNN build disabled
@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled")
class TestMkldnn(TestCase):
    def test_conversion(self):
        for cpu_tensor in [torch.randn((1, 2, 3, 4),
                                       dtype=torch.float, device=torch.device('cpu')),
                           torch.randn((1, 2, 3, 4, 5),
                                       dtype=torch.float, device=torch.device('cpu'))[:, :, :, :, 1]]:
            cpu_tensor.requires_grad_()
            convert_dtypes = {torch.half: [torch.half, torch.float],
                              torch.bfloat16: [torch.bfloat16, torch.float],
                              torch.float: [torch.bfloat16, torch.half]}
            # float/bfloat16/half cpu tensor to mkldnn tensor.
            for dtype1 in types:
                mkldnn_tensor = cpu_tensor.to_mkldnn(dtype1)
                self.assertEqual(mkldnn_tensor.dtype, dtype1)
                cpu_tensor_1 = mkldnn_tensor.to_dense()
                # when no dtype is given to to_dense, the result keeps the mkldnn tensor's dtype
                self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype)
                # mkldnn float/bfloat16/half tensor to cpu float/bfloat16/half tensor
                for dtype2 in convert_dtypes[dtype1]:
                    cpu_tensor_2 = mkldnn_tensor.to_dense(dtype2)
                    self.assertEqual(cpu_tensor_2.dtype, dtype2)
                    atol = 1e-5 if dtype1 == torch.float and dtype2 == torch.float else 1e-2
                    self.assertEqual(cpu_tensor, cpu_tensor_2.float(), atol=atol, rtol=0)

                self.assertEqual(mkldnn_tensor.device, torch.device('cpu'))
                self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4]))
                self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
                if dtype1 == torch.float:
                    self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size())
                else:
                    self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size() / 2)
                self.assertRaisesRegex(RuntimeError,
                                       "Cannot access data pointer of Tensor that doesn't have storage",
                                       lambda: mkldnn_tensor.data_ptr() != 0)

            # half/bfloat16 cpu tensor to mkldnn float or half/bfloat16 tensor.
            for orig_dtype in [torch.half, torch.bfloat16]:
                cpu_tensor_lower = cpu_tensor.to(dtype=orig_dtype)
                for dtype1 in convert_dtypes[orig_dtype]:
                    mkldnn_tensor = cpu_tensor_lower.to_mkldnn(dtype1)
                    self.assertEqual(mkldnn_tensor.dtype, dtype1)
                    cpu_tensor_1 = mkldnn_tensor.to_dense()
                    # when no dtype is given to to_dense, the result keeps the mkldnn tensor's dtype
                    self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype)
                    # mkldnn float/bfloat16/half tensor to cpu float/bfloat16/half tensor
                    for dtype2 in convert_dtypes[cpu_tensor_lower.dtype]:
                        cpu_tensor_2 = mkldnn_tensor.to_dense(dtype2)
                        self.assertEqual(cpu_tensor_2.dtype, dtype2)
                        self.assertEqual(cpu_tensor_lower,
                                         cpu_tensor_2.to(dtype=cpu_tensor_lower.dtype), atol=1e-5, rtol=0)

                    self.assertEqual(mkldnn_tensor.device, torch.device('cpu'))
                    self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4]))
                    self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
                    if dtype1 in [torch.bfloat16, torch.half]:
                        self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor_lower.element_size())
                    else:
                        self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor_lower.element_size() * 2)
                    self.assertRaisesRegex(RuntimeError,
                                           "Cannot access data pointer of Tensor that doesn't have storage",
                                           lambda: mkldnn_tensor.data_ptr() != 0)

    def test_conversion_byte_char(self):
        int8_types = [torch.int8, torch.uint8]
        for int8_type in int8_types:
            low = -100 if int8_type is torch.int8 else 0
            high = 100
            for cpu_tensor in [torch.randint(
                               low=low,
                               high=high,
                               size=(1, 2, 3, 4),
                               dtype=torch.int64,
                               device=torch.device('cpu')),
                               torch.randint(
                               low=low,
                               high=high,
                               size=(1, 2, 3, 4, 5),
                               dtype=torch.int64,
                               device=torch.device('cpu'))[:, :, :, :, :]]:

                cpu_tensor = cpu_tensor.to(dtype=int8_type)
                mkldnn_tensor = cpu_tensor.to_mkldnn(int8_type)
                self.assertEqual(mkldnn_tensor.dtype, int8_type)
                cpu_tensor_1 = mkldnn_tensor.to_dense()
                self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype)
                self.assertEqual(cpu_tensor, cpu_tensor_1)
                self.assertEqual(mkldnn_tensor.device, torch.device('cpu'))
                self.assertEqual(mkldnn_tensor.size(), cpu_tensor.size())
                self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
                self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size())
                self.assertRaisesRegex(RuntimeError,
                                       "Cannot access data pointer of Tensor that doesn't have storage",
                                       lambda: mkldnn_tensor.data_ptr() != 0)

    def test_copy(self):
        x = torch.randn(4, 5, dtype=torch.float32)
        mkldnn_x = x.to_mkldnn()
        mkldnn_y = torch.randn(4, 5, dtype=torch.float32).to_mkldnn()
        mkldnn_z = torch.randn(4, 10, dtype=torch.float32).to_mkldnn()
        mkldnn_y.copy_(mkldnn_x)
        self.assertEqual(x, mkldnn_y.to_dense())
        self.assertRaisesRegex(RuntimeError,
                               "copy_mkldnn_: only support same size tensor.",
                               lambda: mkldnn_z.copy_(mkldnn_x))
        self.assertRaisesRegex(RuntimeError,
                               "copy_mkldnn_: between mkldnn layout and dense Tensors is not implemented! "
                               "Found self type = torch.FloatTensor and src type = Mkldnntorch.FloatTensor",
                               lambda: x.copy_(mkldnn_x))
        self.assertRaisesRegex(RuntimeError,
                               "copy_mkldnn_: between mkldnn layout and dense Tensors is not implemented! "
" 149*da0073e9SAndroid Build Coastguard Worker "Found self type = Mkldnntorch.FloatTensor and src type = torch.FloatTensor", 150*da0073e9SAndroid Build Coastguard Worker lambda: mkldnn_x.copy_(x)) 151*da0073e9SAndroid Build Coastguard Worker 152*da0073e9SAndroid Build Coastguard Worker def test_unsupported(self): 153*da0073e9SAndroid Build Coastguard Worker # unsupported types and unsupported types with gpu 154*da0073e9SAndroid Build Coastguard Worker for dtype in [torch.double, torch.uint8, torch.int8, 155*da0073e9SAndroid Build Coastguard Worker torch.short, torch.int, torch.long]: 156*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError) as context: 157*da0073e9SAndroid Build Coastguard Worker torch.randn(1, 2, 3, 4, dtype=dtype, device=torch.device('cpu')).to_mkldnn() 158*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 159*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError) as context: 160*da0073e9SAndroid Build Coastguard Worker torch.randn(1, 2, 3, 4, dtype=dtype, device=torch.device('cuda')).to_mkldnn() 161*da0073e9SAndroid Build Coastguard Worker # supported type with gpu 162*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 163*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError) as context: 164*da0073e9SAndroid Build Coastguard Worker torch.randn(1, 2, 3, 4, dtype=torch.float, device=torch.device('cuda')).to_mkldnn() 165*da0073e9SAndroid Build Coastguard Worker # some factory functions 166*da0073e9SAndroid Build Coastguard Worker for creator in [torch.ones, torch.randn, torch.rand]: 167*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError) as context: 168*da0073e9SAndroid Build Coastguard Worker creator(1, 2, 3, 4, dtype=torch.float, device=torch.device('cpu'), layout=torch._mkldnn) 169*da0073e9SAndroid Build Coastguard Worker 170*da0073e9SAndroid Build Coastguard Worker def test_mkldnn_conv_shapecheck(self): 171*da0073e9SAndroid Build Coastguard Worker input = torch.full((1, 1, 1, 24,), 1, dtype=torch.float32) 172*da0073e9SAndroid Build Coastguard Worker w1 = torch.full((1, 1, 1, 24,), 1, dtype=torch.float32) 173*da0073e9SAndroid Build Coastguard Worker b1 = torch.full((1,), 1, dtype=torch.float32) 174*da0073e9SAndroid Build Coastguard Worker w2 = torch.full((1, 1, 2, 24,), 1, dtype=torch.float32) 175*da0073e9SAndroid Build Coastguard Worker b2 = torch.full((2,), 1, dtype=torch.float32) 176*da0073e9SAndroid Build Coastguard Worker options = zip([-1, 0, 0, 0, 0, 0, 0], # padding 177*da0073e9SAndroid Build Coastguard Worker [1, 0, 1, 1, 1, 1, 1], # stride 178*da0073e9SAndroid Build Coastguard Worker [1, 1, 0, 1, 1, 1, 1], # dilation 179*da0073e9SAndroid Build Coastguard Worker [1, 1, 1, 0, 2, 1, 1], # groups 180*da0073e9SAndroid Build Coastguard Worker [w1, w1, w1, w1, w1, w1, w2], # weight 181*da0073e9SAndroid Build Coastguard Worker [b1, b1, b1, b1, b1, b2, b1]) # bias 182*da0073e9SAndroid Build Coastguard Worker for pad, st, dil, gr, w, b in options: 183*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError) as _: 184*da0073e9SAndroid Build Coastguard Worker torch.mkldnn_convolution(input, w, b, [pad] * 2, [st] * 2, [dil] * 2, gr) 185*da0073e9SAndroid Build Coastguard Worker 186*da0073e9SAndroid Build Coastguard Worker def test_autograd_to_mkldnn(self): 187*da0073e9SAndroid Build Coastguard Worker # MKLDNN only supports float32 188*da0073e9SAndroid Build Coastguard Worker root = torch.randn(4, 5, 

    def test_autograd_to_mkldnn(self):
        # MKLDNN only supports float32
        root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True)

        def func(root):
            return root.to_mkldnn().to_dense()

        # because MKLDNN only supports float32, we need to lessen the precision.
        # these numbers are just empirical results that seem to work.
        self.assertWarnsRegex(UserWarning,
                              'double precision floating point',
                              lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2))
        self.assertWarnsRegex(UserWarning,
                              'double precision floating point',
                              lambda: gradgradcheck(func, [root], atol=4e-2, rtol=1e-2))

    def test_autograd_from_mkldnn(self):
        # MKLDNN only supports float32
        root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()

        def func(root):
            return root.to_dense()

        # because MKLDNN only supports float32, we need to lessen the precision.
        # these numbers are just empirical results that seem to work.
        self.assertWarnsRegex(UserWarning,
                              'double precision floating point',
                              lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2))

    def test_detach(self):
        root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()

        detach = root.detach()
        self.assertEqual((4, 5), detach.size())
        self.assertFalse(detach.requires_grad)
        self.assertTrue(root.requires_grad)

        detach_ = root.detach_()
        self.assertEqual((4, 5), detach_.size())
        self.assertFalse(detach_.requires_grad)
        self.assertFalse(root.requires_grad)

    def test_repr(self):
        self.assertTrue("layout=torch._mkldnn" in str(torch.randn((1, 2, 3, 4),
                        dtype=torch.float, device=torch.device('cpu')).to_mkldnn()))
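
    # _test_conv_base compares a reference run with the mkldnn backend
    # disabled (torch.backends.mkldnn.flags(enabled=False)) against the
    # mkldnn path: in inference mode the module is converted with
    # mkldnn_utils.to_mkldnn() and also exercised through serialization and
    # tracing; in training mode the input/weight/bias gradients are compared.
    # conv1d training is skipped (see the TODO below).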

    def _test_conv_base(self, dim):
        conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
        input_shapes = {1: (224,), 2: (224, 224), 3: (55, 55, 55)}
        options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
        for train, bias, dilation, groups in options:
            N = torch.randint(3, 10, (1,)).item()
            M = torch.randint(1, 3, (1,)).item() * groups
            C = torch.randint(1, 3, (1,)).item() * groups
            x_shape = (N, C) + input_shapes[dim]
            x = torch.randn(x_shape, dtype=torch.float32)
            conv = conv_module[dim](in_channels=C,
                                    out_channels=M,
                                    kernel_size=3,
                                    stride=2,
                                    padding=1,
                                    dilation=dilation,
                                    bias=bias,
                                    groups=groups).float()
            x1 = x.clone()
            x2 = x.clone().to_mkldnn()
            if not train:
                mkldnn_conv = mkldnn_utils.to_mkldnn(copy.deepcopy(conv))
            elif train and dim != 1:
                # TODO: enable conv1d training.
                x1.requires_grad_()
                x2.requires_grad_()
                mkldnn_conv = copy.deepcopy(conv)
            with torch.backends.mkldnn.flags(enabled=False):
                y_aten = conv(x1)
                if train and dim != 1:
                    loss1 = y_aten.sum()
                    loss1.backward()
            if not train or (train and dim != 1):
                y_mkldnn = mkldnn_conv(x2).to_dense()
                self.assertEqual(y_aten, y_mkldnn)
            if not train:
                self._test_serialization(mkldnn_conv, (x.to_mkldnn(),))
                self._test_tracing(mkldnn_conv, (x.to_mkldnn(),))
            elif dim != 1:
                loss2 = y_mkldnn.sum()
                loss2.backward()
                self.assertTrue(x2.grad.is_mkldnn)
                self.assertEqual(x1.grad, x2.grad.to_dense())
                self.assertEqual(conv.weight.grad,
                                 mkldnn_conv.weight.grad,
                                 atol=1e-3,
                                 rtol=1e-3)
                if bias:
                    self.assertEqual(conv.bias.grad, mkldnn_conv.bias.grad)

    def test_conv1d(self):
        self._test_conv_base(dim=1)

    def test_conv2d(self):
        self._test_conv_base(dim=2)

    def test_conv3d(self):
        self._test_conv_base(dim=3)

    def _test_conv_deconv_lower_precision_base(self, dim, conv_module, dtype):
        input_shapes = {1: (224,), 2: (224, 224), 3: (55, 55, 55)}
        options = itertools.product([True, False], [1, 2], [1, 4])
        for bias, dilation, groups in options:
            N = torch.randint(1, 3, (1,)).item()
            M = torch.randint(1, 3, (1,)).item() * groups
            C = torch.randint(1, 3, (1,)).item() * groups
            x_shape = (N, C) + input_shapes[dim]
            x = torch.randn(x_shape, dtype=torch.float32)
            # TODO: remove this when group depthwise is supported:
            if conv_module in [torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
                               torch.nn.ConvTranspose3d] and groups > 1 and C == groups:
                continue
            conv = conv_module(in_channels=C,
                               out_channels=M,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               dilation=dilation,
                               bias=bias,
                               groups=groups).float()
            x_lower = x.to(dtype=dtype)
            if (dtype == torch.bfloat16 and torch.ops.mkldnn._is_mkldnn_bf16_supported()) or \
                    (dtype == torch.half and torch.ops.mkldnn._is_mkldnn_fp16_supported()):
                mkldnn_conv = mkldnn_utils.to_mkldnn(copy.deepcopy(conv))
                mkldnn_conv_lower = mkldnn_utils.to_mkldnn(copy.deepcopy(conv), dtype)
                y = mkldnn_conv(x.to_mkldnn()).to_dense()
                y_lower = mkldnn_conv_lower(x_lower.to_mkldnn()).to_dense(torch.float32)
                self.assertEqual(y, y_lower, atol=1e-1, rtol=1e-3)
            else:
                msg = {
                    torch.bfloat16: r"bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq",
                    torch.half: r"fp16 path needs the cpu support avx_ne_convert or avx512_fp16",
                }
                with self.assertRaisesRegex(RuntimeError, msg[dtype]):
                    mkldnn_conv_lower = mkldnn_utils.to_mkldnn(copy.deepcopy(conv), dtype)
                    y_lower = mkldnn_conv_lower(x_lower.to_mkldnn()).to_dense(torch.float32)
            # test thnn impl
            conv_lower = copy.deepcopy(conv).to(dtype=dtype)
            conv_ref = copy.deepcopy(conv_lower).float()
            with torch.backends.mkldnn.flags(enabled=False):
                x_ref = x_lower.clone().float().detach().requires_grad_()
                x_lower.requires_grad_()
                y = conv_ref(x_ref)
                y_lower = conv_lower(x_lower).float()
                self.assertEqual(y, y_lower, atol=5e-2, rtol=5e-3)

    @dtypes(torch.float16, torch.bfloat16)
    def test_conv_deconv_1d_lower_precision(self, dtype):
        self._test_conv_deconv_lower_precision_base(1, torch.nn.Conv1d, dtype=dtype)
        self._test_conv_deconv_lower_precision_base(1, torch.nn.ConvTranspose1d, dtype=dtype)

    @dtypes(torch.float16, torch.bfloat16)
    def test_conv_deconv_2d_lower_precision(self, dtype):
        self._test_conv_deconv_lower_precision_base(2, torch.nn.Conv2d, dtype=dtype)
        self._test_conv_deconv_lower_precision_base(2, torch.nn.ConvTranspose2d, dtype=dtype)

    @dtypes(torch.float16, torch.bfloat16)
    def test_conv_deconv_3d_lower_precision(self, dtype):
        self._test_conv_deconv_lower_precision_base(3, torch.nn.Conv3d, dtype=dtype)
        self._test_conv_deconv_lower_precision_base(3, torch.nn.ConvTranspose3d, dtype=dtype)
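
    # The nhwc tests below run the same conv/deconv module twice: once on
    # contiguous (NCHW) input and once on channels-last (NHWC) input, with
    # the weight kept in the given memory format, and require matching
    # outputs and gradients. `prec` loosens tolerances for the low-precision
    # fallback runs.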

    def _test_conv_deconv_nhwc_base(self, conv_module, weight_memory_format, dtype, prec=None):
        input_shapes = {2: (55, 55), 3: (14, 14, 14)}
        options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
        if conv_module in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
            cl_format = torch.channels_last
            input_shape = input_shapes[2]
        elif conv_module in [torch.nn.Conv3d, torch.nn.ConvTranspose3d]:
            cl_format = torch.channels_last_3d
            input_shape = input_shapes[3]

        for train, bias, dilation, groups in options:
            N = torch.randint(3, 10, (1,)).item()
            M = torch.randint(1, 3, (1,)).item() * groups
            C = torch.randint(1, 3, (1,)).item() * groups
            x_shape = (N, C) + input_shape
            x = torch.randn(x_shape, dtype=dtype)

            # conv1: mkldnn conv/deconv in contiguous memory format (nchw)
            # conv2: mkldnn conv/deconv in channels last memory format (nhwc)
            conv1 = conv_module(in_channels=C,
                                out_channels=M,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                dilation=dilation,
                                bias=bias,
                                groups=groups).to(dtype=dtype)
            conv2 = copy.deepcopy(conv1).to(memory_format=weight_memory_format)
            x1 = x.clone()
            x2 = x.clone().to(memory_format=cl_format)
            if train:
                x1.requires_grad_()
                x2.requires_grad_()
            y1 = conv1(x1)
            y2 = conv2(x2)
            self.assertEqual(y1, y2, atol=prec, rtol=prec)

            if train:
                y1.sum().backward()
                y2.sum().backward()
                self.assertTrue(x2.grad.is_contiguous(memory_format=cl_format))
                self.assertEqual(conv1.weight.grad,
                                 conv2.weight.grad,
                                 atol=1e-3,
                                 rtol=1e-3)
                if bias:
                    self.assertEqual(conv1.bias.grad, conv2.bias.grad, atol=prec, rtol=prec)
                self.assertEqual(x1.grad, x2.grad, atol=prec, rtol=prec)

    def test_conv_nhwc_fp32(self):
        self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.float32)
        self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.float32)
        self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=torch.float32)
        self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=torch.float32)

    @dtypes(torch.float16, torch.bfloat16)
    def test_conv_nhwc_lower_precision(self, dtype):
        # when torch.ops.mkldnn._is_mkldnn_bf16_supported() or torch.ops.mkldnn._is_mkldnn_fp16_supported()
        # returns False, bf16/fp16 CPU conv falls back to the thnn impl
        support_checks = {
            torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
            torch.float16: torch.ops.mkldnn._is_mkldnn_fp16_supported
        }
        if support_checks[dtype]():
            self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=dtype)
            self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=dtype)
            self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=dtype)
            self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=dtype)

        # The BF16/FP16 fallback is split into two steps (im2col + gemm), which
        # performs more intermediate dtype conversions than onednn's direct
        # convolution and therefore loses additional accuracy.
        precisions = {
            torch.bfloat16: 1e-2,
            torch.float16: 2e-3,
        }
        prec = precisions[dtype]
        with torch.backends.mkldnn.flags(enabled=False):
            self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=dtype, prec=prec)
            self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=dtype, prec=prec)
            self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=dtype, prec=prec)
            self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=dtype, prec=prec)

    def test_conv_transpose_nhwc_fp32(self):
        self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.float32)
        self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.float32)
        self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=torch.float32)
        self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=torch.float32)

    @dtypes(torch.float16, torch.bfloat16)
    def test_conv_transpose_nhwc_lower_precision(self, dtype):
        # when torch.ops.mkldnn._is_mkldnn_bf16_supported() or torch.ops.mkldnn._is_mkldnn_fp16_supported()
        # returns False, bf16/fp16 CPU conv falls back to the thnn impl
        support_checks = {
            torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
            torch.float16: torch.ops.mkldnn._is_mkldnn_fp16_supported
        }
        if support_checks[dtype]():
            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=dtype)
            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=dtype)
            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=dtype)
            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=dtype)

        # The BF16/FP16 fallback is split into two steps (col2im + gemm), which
        # performs more intermediate dtype conversions than onednn's direct
        # convolution and therefore loses additional accuracy.
        precisions = {
            torch.bfloat16: 2e-2,
            torch.float16: 3e-3,
        }
        prec = precisions[dtype]
        with torch.backends.mkldnn.flags(enabled=False):
            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=dtype, prec=prec)
            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=dtype, prec=prec)
            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=dtype, prec=prec)
            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=dtype, prec=prec)
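
    # _test_conv_transpose_base compares mkldnn transposed convolution in fp32
    # against the thnn reference (run with the mkldnn backend disabled), in
    # both inference and training mode.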

    def _test_conv_transpose_base(self, dim):
        conv_module = {
            1: torch.nn.ConvTranspose1d,
            2: torch.nn.ConvTranspose2d,
            3: torch.nn.ConvTranspose3d
        }
        input_shapes = {1: (55,), 2: (28, 28), 3: (14, 14, 14)}
        options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
        for train, bias, dilation, groups in options:
            N = torch.randint(3, 10, (1,)).item()
            M = torch.randint(1, 3, (1,)).item() * groups
            C = torch.randint(1, 3, (1,)).item() * groups
            x_shape = (N, C) + input_shapes[dim]
            data = torch.randn(x_shape, dtype=torch.float32)
            # conv: mkldnn transpose conv fp32
            # conv_ref: thnn transpose conv fp32
            conv = conv_module[dim](in_channels=C,
                                    out_channels=M,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1,
                                    dilation=dilation,
                                    bias=bias,
                                    groups=groups).to(dtype=torch.float32)
            x = data.clone()
            x_ref = x.clone()
            if train:
                x.requires_grad_()
                x_ref.requires_grad_()

            conv_ref = copy.deepcopy(conv)
            with torch.backends.mkldnn.flags(enabled=False):
                y_ref = conv_ref(x_ref)
                if train:
                    y_ref.sum().backward()

            y = conv(x)
            if train:
                y.sum().backward()

            self.assertEqual(y, y_ref)
            if train:
                self.assertEqual(x.grad, x_ref.grad)
                self.assertEqual(conv.weight.grad,
                                 conv_ref.weight.grad,
                                 atol=1e-3,
                                 rtol=1e-3)
                if bias:
                    self.assertEqual(conv.bias.grad, conv_ref.bias.grad)

    def test_conv_transpose1d(self):
        self._test_conv_transpose_base(dim=1)

    def test_conv_transpose2d(self):
        self._test_conv_transpose_base(dim=2)

    def test_conv_transpose3d(self):
        self._test_conv_transpose_base(dim=3)

    def test_conv2d_legacy_jit_model(self):
        """
        MKLDNN integration used to serialize models with a 5-d weight for
        grouped convolutions; we'd like to preserve this behavior.
        """
        g = 4
        conv2d = torch.nn.Conv2d(16, 16, 3, groups=g)
        conv2d_mkldnn = torch.utils.mkldnn.to_mkldnn(conv2d)

        # contrive legacy conv2d module with a 5-d weight
        o, i, h, w = conv2d.weight.shape
        weight_5d = conv2d.weight.reshape((g, o // g, i, h, w))
        conv2d_mkldnn.weight = weight_5d.to_mkldnn()

        x = torch.randn(1, 16, 8, 8)

        with TemporaryFileName() as fname:
            torch.jit.save(conv2d_mkldnn, fname)
            conv2d_loaded = torch.jit.load(fname)

            self.assertEqual(conv2d_mkldnn.weight.ndimension(), 5)
            self.assertEqual(conv2d_loaded.weight.ndimension(), 4)
            self.assertEqual(
                conv2d(x),
                conv2d_loaded(x.to_mkldnn()).to_dense())

    # This test checks that 1D conv is supported for mkldnn tensors, an issue
    # exposed by https://github.com/pytorch/pytorch/issues/68034.
    def test_conv1d_functional(self):
        input = torch.randn(2, 3, 10).to_mkldnn()
        weight = torch.randn(3, 3, 3).to_mkldnn()
        bias = torch.randn(3).to_mkldnn()
        output = torch.nn.functional.conv1d(input, weight, bias)
        self.assertEqual(output.size(), torch.Size([2, 3, 8]))

    def test_relu(self):
        x = torch.randn((4, 5), dtype=torch.float32) * 10
        x1 = x.clone().requires_grad_()
        x2 = x.clone().to_mkldnn().requires_grad_()
        y1 = torch.relu(x1)
        y2 = torch.relu(x2).to_dense()
        loss1 = y1.sum()
        loss2 = y2.sum()
        loss1.backward()
        loss2.backward()
        self.assertEqual(y1, y2)
        self.assertEqual(x1.grad, x2.grad.to_dense())

    def test_relu_(self):
        x = torch.randn((4, 5), dtype=torch.float32) * 10
        x1 = x.clone().requires_grad_()
        x2 = x.clone().to_mkldnn().requires_grad_()
        y1 = torch.relu_(x1.clone())
        y2 = torch.relu_(x2.clone()).to_dense()
        loss1 = y1.sum()
        loss2 = y2.sum()
        loss1.backward()
        loss2.backward()
        self.assertEqual(y1, y2)
        self.assertEqual(x1.grad, x2.grad.to_dense())

    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def _test_relu_bf16_base(self, name):
        x = torch.randn((4, 5), dtype=torch.float32) * 10
        x_bf16 = x.bfloat16()
        fn = getattr(torch, name)
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            y = fn(x.to_mkldnn()).to_dense()
            y_bf16 = fn(x_bf16.to_mkldnn()).to_dense(torch.float32)
            self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
        else:
            msg = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
            self.assertRaisesRegex(RuntimeError,
                                   msg,
                                   lambda: fn(x_bf16.to_mkldnn()))

    def test_relu_bf16(self):
        self._test_relu_bf16_base("relu")

    def test_relu_inplace_bf16(self):
        self._test_relu_bf16_base("relu_")

    def test_gelu(self):
        m = torch.nn.GELU()
        x = torch.randn((4, 5), dtype=torch.float32) * 10
        x1 = x.clone().requires_grad_()
        x2 = x.clone().to_mkldnn().requires_grad_()
        y1 = m(x1)
        y2 = m(x2).to_dense()
        loss1 = y1.sum()
        loss2 = y2.sum()
        loss1.backward()
        loss2.backward()
        self.assertEqual(y1, y2)
        self.assertEqual(x1.grad, x2.grad.to_dense())

    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def test_gelu_bf16(self):
        m = torch.nn.GELU()
        x = torch.randn((4, 5), dtype=torch.float32) * 10
        x1 = x.clone().to_mkldnn().requires_grad_()
        x2 = x.clone().to_mkldnn(torch.bfloat16).requires_grad_()
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            y1 = m(x1).to_dense()
            y2 = m(x2).to_dense()
            loss1 = y1.sum()
            loss2 = y2.sum()
            loss1.backward()
            loss2.backward()
            self.assertEqual(y1, y2.to(torch.float32), atol=1e-1, rtol=0)
            self.assertEqual(x1.grad.to_dense(), x2.grad.to_dense(torch.float32), atol=1e-2, rtol=0)
        else:
            msg = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
            self.assertRaisesRegex(RuntimeError,
                                   msg,
                                   lambda: m(x2))
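
    # _test_prelu_base checks three PReLU paths against each other: m1 is the
    # dense reference, m2 is fully converted to mkldnn (weights included),
    # and m3 keeps its ATen weights while only the input is an mkldnn tensor.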

    def _test_prelu_base(self, size, num_channels):
        x = torch.randn(size, dtype=torch.float32)
        x1 = x.clone().requires_grad_()
        x2 = x.clone().to_mkldnn().requires_grad_()
        x3 = x.clone().to_mkldnn().requires_grad_()
        m1 = torch.nn.PReLU(num_channels)
        m2 = mkldnn_utils.to_mkldnn(copy.deepcopy(m1))
        m3 = copy.deepcopy(m1)
        y1 = m1(x1)
        y2 = m2(x2).to_dense()
        y3 = m3(x3).to_dense()  # Only convert data to mkldnn, weight is Aten tensor
        loss1 = y1.sum()
        loss1.backward()
        loss2 = y2.sum()
        loss2.backward()
        loss3 = y3.sum()
        loss3.backward()
        self.assertEqual(y1, y2)
        self.assertEqual(y1, y3)
        self.assertEqual(x1.grad, x2.grad.to_dense())
        self.assertEqual(x1.grad, x3.grad.to_dense())

    def test_prelu(self):
        self._test_prelu_base(torch.Size([16]), 1)
        self._test_prelu_base(torch.Size([16, 64]), 1)
        self._test_prelu_base(torch.Size([16, 64]), 64)
        self._test_prelu_base(torch.Size([16, 64, 112]), 1)
        self._test_prelu_base(torch.Size([16, 64, 112]), 64)
        self._test_prelu_base(torch.Size([16, 64, 112, 112]), 1)
        self._test_prelu_base(torch.Size([16, 64, 112, 112]), 64)
        self._test_prelu_base(torch.Size([16, 64, 112, 112, 1]), 1)
        self._test_prelu_base(torch.Size([16, 64, 112, 112, 1]), 64)

    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def _test_prelu_bf16_base(self, size, num_channels):
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            x = torch.randn(size, dtype=torch.float32)
            x_fp32 = x.clone().to_mkldnn().requires_grad_()
            x_bf16 = x.clone().to_mkldnn(torch.bfloat16).requires_grad_()
            m = mkldnn_utils.to_mkldnn(torch.nn.PReLU())
            m_bf16 = mkldnn_utils.to_mkldnn(torch.nn.PReLU(), torch.bfloat16)

            y = m(x_fp32).to_dense()
            y_bf16 = m_bf16(x_bf16).to_dense()
            self.assertEqual(y, y_bf16.to(torch.float32), atol=1e-1, rtol=1e-3)

            loss = y.sum()
            loss.backward()
            loss_bf16 = y_bf16.sum()
            loss_bf16.backward()
            self.assertEqual(x_fp32.grad.to_dense(), x_bf16.grad.to_dense(torch.float32))
        else:
            x_bf16 = torch.randn(size, dtype=torch.bfloat16).requires_grad_()
            m_bf16 = mkldnn_utils.to_mkldnn(torch.nn.PReLU(), torch.bfloat16)
            msg = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
            self.assertRaisesRegex(RuntimeError,
                                   msg,
                                   lambda: m_bf16(x_bf16))

    def test_prelu_bf16(self):
        self._test_prelu_bf16_base(torch.Size([16]), 1)
        self._test_prelu_bf16_base(torch.Size([16, 64]), 1)
        self._test_prelu_bf16_base(torch.Size([16, 64]), 64)
        self._test_prelu_bf16_base(torch.Size([16, 64, 112]), 1)
        self._test_prelu_bf16_base(torch.Size([16, 64, 112]), 64)
        self._test_prelu_bf16_base(torch.Size([16, 64, 112, 112, 1]), 1)
        self._test_prelu_bf16_base(torch.Size([16, 64, 112, 112, 1]), 64)

    def _test_max_pool_base(self, dim, input):
        pool_module = {2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d}
        for stride in [1, 2, 3]:
            for ceil_mode in [False, True]:
                max_pool = pool_module[dim](
                    kernel_size=3 if not ceil_mode else 7,
                    stride=stride,
                    padding=1,
                    ceil_mode=ceil_mode)

                x1 = input.clone().requires_grad_()
                x2 = input.clone().to_mkldnn().requires_grad_()
                y1 = max_pool(x1)
                y2 = max_pool(x2).to_dense()
                loss1 = y1.sum()
                loss2 = y2.sum()
                loss1.backward()
                loss2.backward()
                self.assertEqual(y1, y2)
                self.assertEqual(x1.grad, x2.grad.to_dense())
    def test_max_pool2d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        for H, W in [(64, 64), (35, 39), (16, 19), [7, 8]]:
            x = torch.randn(N, C, H, W, dtype=torch.float32) * 10
            self._test_max_pool_base(dim=2, input=x)

    def test_max_pool3d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        for D, H, W in [(64, 64, 64), (35, 39, 35), (16, 19, 20), [7, 8, 9]]:
            x = torch.randn(N, C, D, H, W, dtype=torch.float32) * 10
            self._test_max_pool_base(dim=3, input=x)

    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def _test_max_pool_bf16_base(self, dim, input):
        pool_module = {2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d}
        x_bf16 = input.bfloat16()
        for stride in [1, 2, 3]:
            for ceil_mode in [False, True]:
                max_pool = pool_module[dim](
                    kernel_size=3 if not ceil_mode else 7,
                    stride=stride,
                    padding=1,
                    ceil_mode=ceil_mode)

                if torch.ops.mkldnn._is_mkldnn_bf16_supported():
                    y = max_pool(input.to_mkldnn()).to_dense()
                    y_bf16 = max_pool(x_bf16.to_mkldnn()).to_dense(torch.float32)
                    self.assertEqual(y, y_bf16, atol=0.1, rtol=1e-3)
                else:
                    msg = "mkldnn_max_pool%dd: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" % dim
                    self.assertRaisesRegex(RuntimeError,
                                           msg,
                                           lambda: max_pool(x_bf16.to_mkldnn()))

    def test_max_pool2d_bf16(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        for H, W in [(64, 64), (35, 39), (16, 19), [7, 8]]:
            x = torch.randn(N, C, H, W, dtype=torch.float32) * 10
            self._test_max_pool_bf16_base(dim=2, input=x)

    def test_max_pool3d_bf16(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        for D, H, W in [(64, 64, 64), (35, 39, 35), (16, 19, 20), [7, 8, 9]]:
            x = torch.randn(N, C, D, H, W, dtype=torch.float32) * 10
            self._test_max_pool_bf16_base(dim=3, input=x)

    def test_max_pool2d_stride_none(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()

        for H, W in [(64, 64), (35, 39), (16, 19), [7, 8]]:
            x = torch.randn(N, C, H, W, dtype=torch.float32) * 10
            for ceil_mode in [False, True]:
                y1 = F.max_pool2d(
                    x,
                    kernel_size=3 if not ceil_mode else 7,
                    stride=None,
                    padding=1,
                    ceil_mode=ceil_mode)

                y2 = F.max_pool2d(
                    x.to_mkldnn(),
                    kernel_size=3 if not ceil_mode else 7,
                    stride=None,
                    padding=1,
                    ceil_mode=ceil_mode)

                self.assertEqual(y1, y2.to_dense())

    # https://github.com/pytorch/pytorch/issues/127111
    @xfailIfTorchDynamo
    def test_max_pool_unsupported(self):
        # OneDNN does not support dilated max pooling; it is expected in v2.0.
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()

        # 2d dilation case
        x = torch.randn(N, C, 7, 7, dtype=torch.float32).to_mkldnn()
        max_pool2d = torch.nn.MaxPool2d(
            kernel_size=3,
            stride=3,
            padding=1,
            dilation=2)
        self.assertRaisesRegex(RuntimeError,
                               'mkldnn_max_pool2d does not support dilation case',
                               lambda: max_pool2d(x))

        # 3d dilation case
        x = torch.randn(N, C, 7, 7, 7, dtype=torch.float32).to_mkldnn()
        max_pool3d = torch.nn.MaxPool3d(
            kernel_size=3,
            stride=3,
            padding=1,
            dilation=2)
        self.assertRaisesRegex(RuntimeError,
                               'mkldnn_max_pool3d does not support dilation case',
                               lambda: max_pool3d(x))

    def _test_avg_pool_base(self, dim, input):
        avg_module = {2: torch.nn.AvgPool2d, 3: torch.nn.AvgPool3d}
        for count_include_pad in [True, False]:
            avg_pool = avg_module[dim](
                kernel_size=3,
                stride=2,
                padding=1,
                count_include_pad=count_include_pad)

            x1 = input.clone().requires_grad_()
            x2 = input.clone().to_mkldnn().requires_grad_()
            y1 = avg_pool(x1)
            y2 = avg_pool(x2).to_dense()
            loss1 = y1.sum()
            loss2 = y2.sum()
            loss1.backward()
            loss2.backward()
            self.assertEqual(y1, y2)
            self.assertEqual(x1.grad, x2.grad.to_dense())

    def test_avg_pool2d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10
        self._test_avg_pool_base(dim=2, input=x)

    def test_avg_pool3d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 64, 64, 64, dtype=torch.float32) * 10
        self._test_avg_pool_base(dim=3, input=x)

    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def _test_avg_pool_bf16_base(self, dim, input):
        avg_module = {2: torch.nn.AvgPool2d, 3: torch.nn.AvgPool3d}
        x_bf16 = input.bfloat16()
        for count_include_pad in [True, False]:
            avg_pool = avg_module[dim](
                kernel_size=3,
                stride=2,
                padding=1,
                count_include_pad=count_include_pad)
            if torch.ops.mkldnn._is_mkldnn_bf16_supported():
                y = avg_pool(input.to_mkldnn()).to_dense()
                y_bf16 = avg_pool(x_bf16.to_mkldnn()).to_dense(torch.float)
                self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
            else:
                msg = "mkldnn_avg_pool%dd: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" % dim
                self.assertRaisesRegex(RuntimeError,
                                       msg,
                                       lambda: avg_pool(x_bf16.to_mkldnn()))

    def test_avg_pool2d_bf16(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10
        self._test_avg_pool_bf16_base(dim=2, input=x)

    def test_avg_pool3d_bf16(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 64, 64, 64, dtype=torch.float32) * 10
        self._test_avg_pool_bf16_base(dim=3, input=x)

    def test_avg_pool2d_stride_none(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10
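
        # stride=None defaults the stride to the kernel size; the dense and
        # mkldnn paths must resolve it the same way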
        for count_include_pad in [True, False]:
            y1 = F.avg_pool2d(
                x,
                kernel_size=3,
                stride=None,
                padding=1,
                count_include_pad=count_include_pad)
            y2 = F.avg_pool2d(
                x.to_mkldnn(),
                kernel_size=3,
                stride=None,
                padding=1,
                count_include_pad=count_include_pad)

            self.assertEqual(y1, y2.to_dense())

    def test_adaptive_avg_pool2d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100

        adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(7)
        x1 = x.clone().requires_grad_()
        x2 = x.clone().to_mkldnn().requires_grad_()
        y1 = adaptive_avg_pool2d(x1)
        y2 = adaptive_avg_pool2d(x2).to_dense()

        loss1 = y1.sum()
        loss2 = y2.sum()
        loss1.backward()
        loss2.backward()

        self.assertEqual(y1, y2)
        self.assertEqual(x1.grad, x2.grad.to_dense())

    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def test_adaptive_avg_pool2d_bf16(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100

        x_bf16 = x.bfloat16()
        adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(7)

        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            y = adaptive_avg_pool2d(x.to_mkldnn()).to_dense()
            y_bf16 = adaptive_avg_pool2d(x_bf16.to_mkldnn()).to_dense(torch.float32)
            self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
        else:
            msg = "mkldnn_adaptive_avg_pool2d: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
            self.assertRaisesRegex(RuntimeError,
                                   msg,
                                   lambda: adaptive_avg_pool2d(x_bf16.to_mkldnn()))

    def _test_batch_norm_base(self, dim, channels, input):
        bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
        bn = bn_module[dim](channels).float().train(False)
        mkldnn_bn = mkldnn_utils.to_mkldnn(copy.deepcopy(bn))
        self.assertEqual(
            bn(input),
            mkldnn_bn(input.to_mkldnn()).to_dense())

        self._test_serialization(mkldnn_bn, (input.to_mkldnn(),))
        self._test_tracing(mkldnn_bn, (input.to_mkldnn(),))

    def _test_batch_norm_train_base(self, dim, channels, input):
        # TODO: support 3d batchnorm training.
        bn_module = {2: torch.nn.BatchNorm2d}
        # TODO: support affine=False.
        options = itertools.product([True], [True, False])
        for affine, track_running_stats in options:
            bn = bn_module[dim](
                num_features=channels,
                affine=affine,
                track_running_stats=track_running_stats).float().train(True)
            mkldnn_bn = copy.deepcopy(bn)
            x1 = input.clone().requires_grad_()
            x2 = input.clone().to_mkldnn().requires_grad_()
            y1 = bn(x1)
            y2 = mkldnn_bn(x2).to_dense()
            loss1 = y1.sum()
            loss2 = y2.sum()
            loss1.backward()
            loss2.backward()
            self.assertEqual(y1, y2)
            self.assertEqual(x1.grad, x2.grad.to_dense())
            self.assertEqual(bn.weight.grad, mkldnn_bn.weight.grad, rtol=1e-3, atol=1e-3)
            if track_running_stats:
                self.assertEqual(bn.running_mean, mkldnn_bn.running_mean)
                self.assertEqual(bn.running_var, mkldnn_bn.running_var, rtol=1e-5, atol=1e-5)

    def test_batch_norm_2d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
        self._test_batch_norm_base(dim=2, channels=C, input=x)
        self._test_batch_norm_train_base(dim=2, channels=C, input=x)

    def test_batch_norm_3d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        x = torch.randn(N, C, 30, 30, 30, dtype=torch.float32) * 10
        self._test_batch_norm_base(dim=3, channels=C, input=x)

    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def _test_batch_norm_bf16_base(self, dim, channels, input):
        bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
        x_bf16 = input.bfloat16()
        # TODO: support training
        for train in [False]:
            bn = bn_module[dim](channels).float().train(train)
            mkldnn_bn = mkldnn_utils.to_mkldnn(copy.deepcopy(bn))
            if torch.ops.mkldnn._is_mkldnn_bf16_supported():
                y = bn(input.to_mkldnn().to_dense())
                y_bf16 = bn(x_bf16.to_mkldnn().to_dense(torch.float))
                self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
            else:
                msg = "mkldnn_batch_norm: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
                self.assertRaisesRegex(RuntimeError,
                                       msg,
                                       lambda: bn(x_bf16.to_mkldnn()))

    def test_batch_norm_2d_bf16(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
        self._test_batch_norm_bf16_base(dim=2, channels=C, input=x)

    def test_batch_norm_3d_bf16(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        x = torch.randn(N, C, 30, 30, 30, dtype=torch.float32) * 10
        self._test_batch_norm_bf16_base(dim=3, channels=C, input=x)
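
    # binary ops: each dense result is checked against the mkldnn result,
    # including the out= and in-place variants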
    def test_add(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        alpha = torch.randn(1, dtype=torch.float32).item()

        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
        y = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
        mx = x.to_mkldnn()
        my = y.to_mkldnn()

        # add
        self.assertEqual(
            x + y,
            (mx + my).to_dense())

        self.assertEqual(
            torch.add(x, y, alpha=alpha),
            torch.add(mx, my, alpha=alpha).to_dense())

        # add_
        x += y
        mx += my
        self.assertEqual(x, mx.to_dense())

        # add_out
        out = x.clone()
        mkldnn_out = out.to_mkldnn()
        torch.add(x, y, alpha=alpha, out=out)
        torch.add(mx, my, alpha=alpha, out=mkldnn_out)
        self.assertEqual(out, mkldnn_out.to_dense())

        # add_out inplace case: first input
        torch.add(x, y, alpha=alpha, out=x)
        torch.add(mx, my, alpha=alpha, out=mx)
        self.assertEqual(x, mx.to_dense())

        # add_out inplace case: second input
        torch.add(x, y, alpha=alpha, out=y)
        torch.add(mx, my, alpha=alpha, out=my)
        self.assertEqual(y, my.to_dense())

    def test_mul(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        value = torch.randn(1, dtype=torch.float32).item()

        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
        y = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
        mx = x.to_mkldnn()
        my = y.to_mkldnn()
        # mul
        self.assertEqual(
            x * y,
            (mx * my).to_dense())

        self.assertEqual(
            x * value,
            (mx * value).to_dense())

        self.assertEqual(
            torch.mul(x, y),
            torch.mul(mx, my).to_dense())

        self.assertEqual(
            torch.mul(x, value),
            torch.mul(mx, value).to_dense())

        # mul_
        x *= y
        mx *= my
        self.assertEqual(x, mx.to_dense())

        x *= value
        mx *= value
        self.assertEqual(x, mx.to_dense())

        # mul_out
        out = x.clone()
        mkldnn_out = out.to_mkldnn()
        torch.mul(x, y, out=out)
        torch.mul(mx, my, out=mkldnn_out)
        self.assertEqual(out, mkldnn_out.to_dense())

        out = x.clone()
        mkldnn_out = out.to_mkldnn()
        torch.mul(x, value, out=out)
        torch.mul(mx, value, out=mkldnn_out)
        self.assertEqual(out, mkldnn_out.to_dense())

    def test_0_dimension_tensor(self):
        x = torch.rand([20, 20, 1, 1], dtype=torch.float)
        y = torch.rand([20, 20, 0, 1], dtype=torch.float)

        # unary ops work without modification
        out_relu = torch.relu(y)
        out_relu_mkldnn = torch.relu(y.to_mkldnn()).to_dense()
        self.assertEqual(out_relu, out_relu_mkldnn)

        out_mul = x * y
        out_mul_mkldnn = (x.to_mkldnn() * y.to_mkldnn()).to_dense()
        self.assertEqual(out_mul, out_mul_mkldnn)
        out_add = x + y
        out_add_mkldnn = (x.to_mkldnn() + y.to_mkldnn()).to_dense()
        self.assertEqual(out_add, out_add_mkldnn)

        x.requires_grad_(True)
        y.requires_grad_(True)
        with self.assertRaisesRegex(RuntimeError, "0-dimension Tensor in training"):
            x.to_mkldnn() + y.to_mkldnn()

        with self.assertRaisesRegex(RuntimeError, "must match"):
            torch.rand([5]).to_mkldnn() + torch.rand([0]).to_mkldnn()

        C = 7
        m = torch.nn.Conv2d(C, C, 3)
        x = torch.randn(0, C, C, 8, dtype=torch.float)
        out_eager = m(x)
        out_mkldnn = mkldnn_utils.to_mkldnn(m)(x)
        self.assertEqual(out_eager, out_mkldnn)

    # https://github.com/pytorch/pytorch/issues/127111
    @xfailIfTorchDynamo
    def test_view(self):
        x = torch.randn(3, 4, 5, dtype=torch.float32).to_mkldnn()
        self.assertRaisesRegex(RuntimeError,
                               "Change to use reshape",
                               lambda: x.view(x.size(0), -1))

    def test_reshape(self):
        x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
        size = (x.size(0), -1)

        self.assertEqual(
            x.reshape(size),
            x.to_mkldnn().reshape(size).to_dense(),
        )
        # test whether reshape shares the same memory for a plain-format tensor
        y = x.to_mkldnn()
        z = y.reshape(size).add_(y.reshape(size))
        self.assertEqual(
            y.reshape(size).to_dense(),
            z.to_dense(),
        )

    def test_reshape_blocked_format(self):
        # construct an mkldnn blocked tensor with mkldnn conv2d
        C = 7
        m = mkldnn_utils.to_mkldnn(torch.nn.Conv2d(C, C, 3))
        x = torch.randn(1, C, 8, 8).to_mkldnn()
        # mkldnn tensor w/ blocked format
        y_block = m(x)
        # aten tensor w/ plain format
        y_plain = y_block.to_dense()

        y_block_reshape = y_block.reshape(C, -1)
        y_plain_reshape = y_plain.reshape(C, -1)

        self.assertEqual(y_plain_reshape, y_block_reshape.to_dense())

    def test_reshape_backward(self):
        x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
        size = (x.size(0), -1)

        x1 = x.clone().requires_grad_()
        x2 = x.clone().to_mkldnn().requires_grad_()
        in_features = 20
        out_features = torch.randint(3, 100, (1,)).item()
        linear = torch.nn.Linear(in_features, out_features).float()

        y1 = linear(x1.reshape(size)).sum()
        y2 = linear(x2.reshape(size).to_dense()).sum()
        y1.backward()
        y2.backward()
        self.assertEqual(x1.grad, x2.grad.to_dense())

    def test_clone(self):
        x = torch.randn(4, 5, dtype=torch.float32) * 10
        self.assertEqual(
            x.clone(),
            x.to_mkldnn().clone().to_dense(),
        )
        # the clone must not share memory with its source
        y = x.to_mkldnn()
        z = y.clone().add_(y)
        self.assertNotEqual(
            y.to_dense(),
            z.to_dense(),
        )

    def test_transpose(self):
        x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
        for dim1 in range(x.ndim):
            for dim2 in range(x.ndim):
                self.assertEqual(
                    x.transpose(dim1, dim2),
                    x.to_mkldnn().transpose(dim1, dim2).to_dense(),
                )

    def test_transpose_invalid_dim(self):
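        # dim 12 is out of range for a 3-d tensor, so the transpose must raise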
        x = torch.randn(3, 4, 5, dtype=torch.float32).to_mkldnn()
        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
            torch._mkldnn_transpose(x, 0, 12)

    def test_linear_non_contiguous_weight(self):
        in_features = torch.randint(3, 10, (1,)).item()
        out_features = torch.randint(3, 100, (1,)).item()
        x = torch.randn(3, in_features, dtype=torch.float32) * 10
        w = torch.randn(in_features, out_features, dtype=torch.float32)
        for bias in [True, False]:
            x1 = x.clone().requires_grad_()
            x2 = x.clone().to_mkldnn().requires_grad_()
            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
            linear.weight = torch.nn.Parameter(w.t())
            mkldnn_linear = copy.deepcopy(linear)
            y1 = linear(x1).sum()
            y2 = mkldnn_linear(x2).to_dense().sum()
            y1.backward()
            y2.backward()
            self.assertEqual(x1.grad, x2.grad.to_dense())
            self.assertEqual(linear.weight.grad, mkldnn_linear.weight.grad)
            if bias:
                self.assertEqual(linear.bias.grad, mkldnn_linear.bias.grad)

    def test_linear(self):
        in_features = torch.randint(3, 10, (1,)).item()
        out_features = torch.randint(3, 100, (1,)).item()
        x = torch.randn(3, in_features, dtype=torch.float32) * 10

        for bias in [True, False]:
            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
            mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear))
            self.assertEqual(
                linear(x),
                mkldnn_linear(x.to_mkldnn()).to_dense())

            self._test_serialization(mkldnn_linear, (x.to_mkldnn(),))
            self._test_tracing(mkldnn_linear, (x.to_mkldnn(),))

    def test_linear_backward(self):
        in_features = torch.randint(3, 10, (1,)).item()
        out_features = torch.randint(3, 100, (1,)).item()
        x = torch.randn(3, in_features, dtype=torch.float32) * 10
        for bias in [True, False]:
            x1 = x.clone().requires_grad_()
            x2 = x.clone().to_mkldnn().requires_grad_()
            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
            mkldnn_linear = copy.deepcopy(linear)
            y1 = linear(x1).sum()
            y2 = mkldnn_linear(x2).to_dense().sum()
            y1.backward()
            y2.backward()
            self.assertEqual(x1.grad, x2.grad.to_dense())
            self.assertEqual(linear.weight.grad, mkldnn_linear.weight.grad)
            if bias:
                self.assertEqual(linear.bias.grad, mkldnn_linear.bias.grad)

    @dtypes(torch.float16, torch.bfloat16)
    def test_linear_lowp(self, dtype):
        in_features = torch.randint(3, 10, (1,)).item()
        out_features = torch.randint(3, 100, (1,)).item()
        x = torch.randn(3, in_features, dtype=torch.float32) * 10
        x_lowp = x.to(dtype=dtype)

        for bias in [True, False]:
            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
            mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear))
            mkldnn_linear_lowp = mkldnn_utils.to_mkldnn(
                copy.deepcopy(linear), dtype
            )
            lowp_support = {
                torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
                torch.half: torch.ops.mkldnn._is_mkldnn_fp16_supported,
            }
            if lowp_support[dtype]():
                y = mkldnn_linear(x.to_mkldnn()).to_dense()
                y_lowp = mkldnn_linear_lowp(x_lowp.to_mkldnn()).to_dense(
                    torch.float32
                )
                if dtype == torch.bfloat16:
                    self.assertEqual(y, y_lowp, atol=1e-1, rtol=1e-3)
                else:
                    self.assertEqual(y, y_lowp, atol=5e-3, rtol=1e-3)
            else:
                msg = {
                    torch.bfloat16: r"bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq",
                    torch.half: r"fp16 path needs the cpu support avx_ne_convert or avx512_fp16",
                }
                self.assertRaisesRegex(
                    RuntimeError,
                    msg[dtype],
                    lambda: mkldnn_linear_lowp(x_lowp.to_mkldnn()),
                )

    def test_softmax(self):
        x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
        for dim in range(x.ndim):
            softmax = torch.nn.Softmax(dim=dim)
            self.assertEqual(
                softmax(x),
                softmax(x.to_mkldnn()).to_dense())

    def test_sigmoid(self):
        x = torch.randn(4, 5, dtype=torch.float32) * 10
        mkldnn_x = x.to_mkldnn()
        self.assertEqual(
            torch.sigmoid(x),
            torch.sigmoid(mkldnn_x).to_dense(),
        )
        # inplace
        torch.sigmoid_(x)
        torch.sigmoid_(mkldnn_x)
        self.assertEqual(x, mkldnn_x.to_dense())

    def test_tanh(self):
        x = torch.randn(4, 5, dtype=torch.float32) * 10
        mkldnn_x = x.to_mkldnn()
        self.assertEqual(
            torch.tanh(x),
            torch.tanh(mkldnn_x).to_dense(),
        )
        # inplace
        torch.tanh_(x)
        torch.tanh_(mkldnn_x)
        self.assertEqual(x, mkldnn_x.to_dense())

    def _test_serialization(self, module, inputs):
        with TemporaryFileName() as fname:
            torch.jit.save(module, fname)
            loaded = torch.jit.load(fname)
            self.assertEqual(
                module(*inputs).to_dense(),
                loaded(*inputs).to_dense())

    def _test_tracing(self, module, inputs):
        traced = torch.jit.trace(module, inputs)
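        # the traced module must reproduce the eager module's output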
        self.assertEqual(
            module(*inputs).to_dense(),
            traced(*inputs).to_dense())

    def test_set_data_tensorimpl_type(self):
        # Dense tensor has impl of type `TensorImpl`, while MKL-DNN tensor has impl
        # of type `OpaqueTensorImpl<IDeepTensorWrapperPtr>`.
        x = torch.randn((1, 2), dtype=torch.float, device=torch.device('cpu'))
        x_mkldnn = x.to_mkldnn()
        with self.assertRaisesRegex(RuntimeError, 'incompatible tensor type'):
            x.data = x_mkldnn

    def test_empty(self):
        x1 = torch.empty(4, 5, 2, 3, dtype=torch.float32)
        x2 = torch.empty(4, 5, 2, 3, dtype=torch.float32, layout=torch._mkldnn)
        self.assertEqual(x1.size(), x2.to_dense().size())
        self.assertEqual(x1.dtype, x2.to_dense().dtype)

    def test_zero_(self):
        x1 = torch.randn(4, 5, dtype=torch.float32) * 10
        x2 = x1.clone().to_mkldnn()
        self.assertEqual(
            x1.zero_(),
            x2.zero_().to_dense(),
        )

    def test_is_mkldnn(self):
        x = torch.randn(1, dtype=torch.float32)
        self.assertFalse(x.is_mkldnn)
        self.assertTrue(x.to_mkldnn().is_mkldnn)

    # legacy constructor/new doesn't support mkldnn tensors
    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1992")
    def test_legacy_new_failure(self):
        x = torch.randn(1, dtype=torch.float32)
        x_mkldnn = x.to_mkldnn()
        self.assertRaises(RuntimeError, lambda: x_mkldnn.new(device='cpu'))
        self.assertRaises(RuntimeError, lambda: x_mkldnn.new(x.storage()))
        self.assertRaises(RuntimeError, lambda: x_mkldnn.new(x))
        self.assertRaises(RuntimeError, lambda: x_mkldnn.new(torch.Size([2, 3])))
        self.assertRaises(RuntimeError, lambda: x_mkldnn.new([6]))

    def test_is_mkldnn_jit(self):
        class EnsureMkldnn(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                if not x.is_mkldnn:
                    x = x.to_mkldnn()
                return x

        m = EnsureMkldnn()
        x = torch.randn(1, dtype=torch.float32)
        self.assertTrue(m(x).is_mkldnn)
        self.assertTrue(m(x.to_mkldnn()).is_mkldnn)

    def _test_imagenet_model(self, model):
        model = model.train(False).float()
        mkldnn_model = mkldnn_utils.to_mkldnn(copy.deepcopy(model))
        x = torch.randn(1, 3, 224, 224, dtype=torch.float32)
        with torch.no_grad():
            self.assertEqual(
                model(x),
                mkldnn_model(x.to_mkldnn()).to_dense(),
            )

    @skipIfNoTorchVision
    def test_resnet18(self):
        model = torchvision.models.resnet.resnet18(weights=None)
        self._test_imagenet_model(model)

    @skipIfNoTorchVision
    def test_resnext50_32x4d(self):
        model = torchvision.models.resnet.resnext50_32x4d(weights=None)
        self._test_imagenet_model(model)

    def _lstm_params_list(self):
        params_dict = {
            "input_size": [1, 5],
            "hidden_size": [5, 16],
            "num_layers": [1, 3],
            "bidirectional": [False, True],
            "bias": [False, True],
            "batch_first": [False, True],
            "dropout": [0, 0.4, 0.7, 1],
            "batch_size": [1, 2],
            "seq_len": [1, 3],
            "training": [False, True]
        }

        params_list = list(params_dict.values())
        return params_list

    def _cast_dtype(self, input, dtype):
        if dtype == torch.bfloat16:
            input = input.to(torch.bfloat16)
        elif dtype == torch.half:
            input = input.to(torch.half)
        return input

    def test_lstm(self):
        seed = 2023
        torch.manual_seed(seed)

        params_list = self._lstm_params_list()
        for dtype in types:
            bf16 = dtype == torch.bfloat16
            fp16 = dtype == torch.half
            rtol = 1.3e-6
            atol = 1e-5

            if bf16:
                rtol = 0.02
                atol = 0.02
            if fp16:
                rtol = 1e-3
                atol = 1e-3
            for input_size, hidden_size, num_layers, bidirectional, bias, batch_first, dropout, batch_size, seq_len, training \
                    in itertools.product(*params_list):
                num_directions = 2 if bidirectional else 1
                if batch_first:
                    input = torch.randn(batch_size, seq_len, input_size, dtype=torch.float32)
                else:
                    input = torch.randn(seq_len, batch_size, input_size, dtype=torch.float32)
                h = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32)
                c = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32)
                if fp16:
                    # TODO: add training support once oneDNN supports LSTM FP16 training
                    training = False
                model = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional,
                                      bias=bias, dropout=dropout, batch_first=batch_first).float()
                model.train(training)
                input1 = input.clone().requires_grad_(training)
                input2 = input.clone().requires_grad_(training)

                h1 = h.clone().requires_grad_(training)
                h2 = h.clone().requires_grad_(training)
                c1 = c.clone().requires_grad_(training)
                c2 = c.clone().requires_grad_(training)
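
                # model1 is the reference run with oneDNN disabled;
                # model2 runs with oneDNN enabled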
    def test_lstm(self):
        seed = 2023
        torch.manual_seed(seed)

        params_list = self._lstm_params_list()
        for dtype in types:
            bf16 = dtype == torch.bfloat16
            fp16 = dtype == torch.half
            rtol = 1.3e-6
            atol = 1e-5

            if bf16:
                rtol = 0.02
                atol = 0.02
            if fp16:
                rtol = 1e-3
                atol = 1e-3
            for input_size, hidden_size, num_layers, bidirectional, bias, batch_first, dropout, batch_size, seq_len, training \
                    in itertools.product(*params_list):
                num_directions = 2 if bidirectional else 1
                if batch_first:
                    input = torch.randn(batch_size, seq_len, input_size, dtype=torch.float32)
                else:
                    input = torch.randn(seq_len, batch_size, input_size, dtype=torch.float32)
                h = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32)
                c = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32)
                if fp16:
                    # TODO: add training support once oneDNN supports LSTM FP16 training
                    training = False
                model = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional,
                                      bias=bias, dropout=dropout, batch_first=batch_first).float()
                model.train() if training else model.eval()
                input1 = input.clone().requires_grad_(training)
                input2 = input.clone().requires_grad_(training)

                h1 = h.clone().requires_grad_(training)
                h2 = h.clone().requires_grad_(training)
                c1 = c.clone().requires_grad_(training)
                c2 = c.clone().requires_grad_(training)

                model1 = copy.deepcopy(model)
                model2 = copy.deepcopy(model)
                with torch.no_grad() if not training else nullcontext():
                    with torch.backends.mkldnn.flags(enabled=False):
                        torch.manual_seed(seed)
                        output1, (hn1, cn1) = self._cast_dtype(model1, dtype)(
                            self._cast_dtype(input1, dtype),
                            (
                                self._cast_dtype(h1, dtype),
                                self._cast_dtype(c1, dtype),
                            ),
                        )

                    torch.manual_seed(seed)
                    output2, (hn2, cn2) = self._cast_dtype(model2, dtype)(
                        self._cast_dtype(input2, dtype),
                        (
                            self._cast_dtype(h2, dtype),
                            self._cast_dtype(c2, dtype),
                        ),
                    )
                    self.assertEqual(output1, output2, rtol=rtol, atol=atol)
                    self.assertEqual(hn1, hn2, rtol=rtol, atol=atol)
                    self.assertEqual(cn1, cn2, rtol=rtol, atol=atol)

                    if training:
                        with torch.backends.mkldnn.flags(enabled=False):
                            torch.manual_seed(seed)
                            output1.sum().backward(retain_graph=True)

                        torch.manual_seed(seed)
                        output2.sum().backward(retain_graph=True)

                        self.assertEqual(input1.grad, input2.grad, rtol=rtol, atol=atol)
                        for name, para in model1.named_parameters():
                            self.assertEqual(para, getattr(model2, name))
                            self.assertEqual(
                                para.grad,
                                getattr(model2, name).grad,
                                rtol=rtol,
                                atol=atol,
                            )

                        with torch.backends.mkldnn.flags(enabled=False):
                            torch.manual_seed(seed)
                            hn1.sum().backward(retain_graph=True)
                        torch.manual_seed(seed)
                        hn2.sum().backward(retain_graph=True)
                        self.assertEqual(h1.grad, h2.grad, rtol=rtol, atol=atol)

                        with torch.backends.mkldnn.flags(enabled=False):
                            torch.manual_seed(seed)
                            cn1.sum().backward(retain_graph=True)
                        torch.manual_seed(seed)
                        cn2.sum().backward(retain_graph=True)
                        self.assertEqual(c1.grad, c2.grad, rtol=rtol, atol=atol)
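    # Sketch of the comparison pattern used above (comments only): the seeded
    # reference pass runs with mkldnn disabled, the second pass with mkldnn
    # enabled, and both forward results and gradients must agree within
    # dtype-dependent tolerances:
    #
    #   with torch.backends.mkldnn.flags(enabled=False):
    #       ref = model(x)          # reference (non-mkldnn) path
    #   out = model(x)              # mkldnn-enabled path
    #   self.assertEqual(ref, out, rtol=rtol, atol=atol)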
    @dtypes(torch.float16, torch.bfloat16)
    def test_matmul_lower_precision(self, dtype):
        support_check = {
            torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
            torch.float16: torch.ops.mkldnn._is_mkldnn_fp16_supported,
        }

        def common(self, shape1, shape2, op, dtype):
            a = torch.randn(shape1, dtype=dtype)
            a_ref = a.float()
            b = torch.randn(shape2, dtype=dtype)
            b_ref = b.float()

            y = op(a, b)
            y_ref = op(a_ref, b_ref)
            self.assertEqual(y, y_ref, exact_dtype=False)
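        # Aside (a sketch, not one of the assertions below): for shape
        # [64, 1, 33] the default contiguous strides are [33, 33, 1], yet
        # strides like [33, 3, 1] still pass is_contiguous() because the
        # stride of a size-1 dimension never affects the memory layout:
        #
        #   t = torch.as_strided(torch.randn(64, 1, 33), [64, 1, 33], [33, 3, 1])
        #   assert t.is_contiguous()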
        if support_check[dtype]():
            a1 = torch.randn([64, 1, 33], dtype=dtype)
            # a2 is a contiguous tensor, but its strides
            # are not the default contiguous strides.
            a2 = torch.as_strided(a1.clone(), [64, 1, 33], [33, 3, 1])
            self.assertTrue(a2.is_contiguous())
            b = torch.randn(64, 33, 256).to(dtype=dtype)
            y1 = torch.ops.aten.bmm(a1, b)
            y2 = torch.bmm(a2, b)
            self.assertEqual(y1, y2)

        for shape1, shape2, op in [
            ((33, 77), (77, 22), torch.matmul),
            ((128, 256), (256, 10), torch.matmul),
            ((7, 300), (300, 3), torch.matmul),
            ((1, 100), (100, 60), torch.matmul),
            ((100, 1), (1, 100), torch.matmul),
            ((20, 54, 78), (20, 78, 10), torch.bmm),
            ((1, 300, 1), (1, 1, 300), torch.bmm),
        ]:
            common(self, shape1, shape2, op, dtype)


instantiate_device_type_tests(TestMkldnn, globals(), only_for=('cpu',))

if __name__ == '__main__':
    run_tests()