# Owner(s): ["module: mkldnn"]

import copy
import itertools
import functools
import unittest
from contextlib import nullcontext

try:
    import torchvision
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False

skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

import torch
import torch.nn.functional as F
import torch.jit
import torch.backends.mkldnn
from torch.utils import mkldnn as mkldnn_utils
from torch.testing._internal.common_utils import TestCase, \
    run_tests, TemporaryFileName, gradcheck, gradgradcheck, IS_WINDOWS, \
    skipIfTorchDynamo, xfailIfTorchDynamo
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    dtypes,
)

# mkldnn tensors do not support batched grad
gradcheck = functools.partial(gradcheck, check_batched_grad=False)
gradgradcheck = functools.partial(gradgradcheck, check_batched_grad=False)
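# For example, the wrapped gradcheck can be called as
#   gradcheck(lambda t: t.to_mkldnn().to_dense(), [inp])
# (as test_autograd_to_mkldnn does below) without attempting batched gradients.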


types = [torch.float, torch.bfloat16, torch.half]

# Comment out the line below to find CI machines where the MKL-DNN build is disabled
@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled")
class TestMkldnn(TestCase):
    def test_conversion(self):
        for cpu_tensor in [torch.randn((1, 2, 3, 4),
                                       dtype=torch.float, device=torch.device('cpu')),
                           torch.randn((1, 2, 3, 4, 5),
                                       dtype=torch.float, device=torch.device('cpu'))[:, :, :, :, 1]]:
            cpu_tensor.requires_grad_()
            convert_dtypes = {torch.half: [torch.half, torch.float],
                              torch.bfloat16: [torch.bfloat16, torch.float],
                              torch.float: [torch.bfloat16, torch.half]}
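            # A minimal round-trip sketch using the mapping above:
            #   m = cpu_tensor.to_mkldnn(torch.bfloat16)  # dense fp32 -> mkldnn bf16
            #   d = m.to_dense(torch.float)               # mkldnn bf16 -> dense fp32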
            # float/bfloat16/half cpu tensor to mkldnn tensor.
            for dtype1 in types:
                mkldnn_tensor = cpu_tensor.to_mkldnn(dtype1)
                self.assertEqual(mkldnn_tensor.dtype, dtype1)
                cpu_tensor_1 = mkldnn_tensor.to_dense()
                # when to_dense is not given a dtype, the cpu tensor keeps the mkldnn tensor's dtype
                self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype)
                # mkldnn float/bfloat16/half tensor to cpu float/bfloat16/half tensor
                for dtype2 in convert_dtypes[dtype1]:
                    cpu_tensor_2 = mkldnn_tensor.to_dense(dtype2)
                    self.assertEqual(cpu_tensor_2.dtype, dtype2)
                    atol = 1e-5 if dtype1 == torch.float and dtype2 == torch.float else 1e-2
                    self.assertEqual(cpu_tensor, cpu_tensor_2.float(), atol=atol, rtol=0)

                self.assertEqual(mkldnn_tensor.device, torch.device('cpu'))
                self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4]))
                self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
                if dtype1 == torch.float:
                    self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size())
                else:
                    self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size() / 2)
                self.assertRaisesRegex(RuntimeError,
                                       "Cannot access data pointer of Tensor that doesn't have storage",
                                       lambda: mkldnn_tensor.data_ptr() != 0)

            # half/bfloat16 cpu tensor to mkldnn float/bfloat16/half tensor.
            for orig_dtype in [torch.half, torch.bfloat16]:
                cpu_tensor_lower = cpu_tensor.to(dtype=orig_dtype)
                for dtype1 in convert_dtypes[orig_dtype]:
                    mkldnn_tensor = cpu_tensor_lower.to_mkldnn(dtype1)
                    self.assertEqual(mkldnn_tensor.dtype, dtype1)
                    cpu_tensor_1 = mkldnn_tensor.to_dense()
                    # when to_dense is not given a dtype, the cpu tensor keeps the mkldnn tensor's dtype
                    self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype)
                    # mkldnn float/bfloat16/half tensor to cpu float/bfloat16/half tensor
                    for dtype2 in convert_dtypes[cpu_tensor_lower.dtype]:
                        cpu_tensor_2 = mkldnn_tensor.to_dense(dtype2)
                        self.assertEqual(cpu_tensor_2.dtype, dtype2)
                        self.assertEqual(cpu_tensor_lower,
                                         cpu_tensor_2.to(dtype=cpu_tensor_lower.dtype), atol=1e-5, rtol=0)

                    self.assertEqual(mkldnn_tensor.device, torch.device('cpu'))
                    self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4]))
                    self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
                    if dtype1 in [torch.bfloat16, torch.half]:
                        self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor_lower.element_size())
                    else:
                        self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor_lower.element_size() * 2)
                    self.assertRaisesRegex(RuntimeError,
                                           "Cannot access data pointer of Tensor that doesn't have storage",
                                           lambda: mkldnn_tensor.data_ptr() != 0)

    def test_conversion_byte_char(self):
        int8_types = [torch.int8, torch.uint8]
        for int8_type in int8_types:
            low = -100 if int8_type is torch.int8 else 0
            high = 100
            for cpu_tensor in [torch.randint(
                               low=low,
                               high=high,
                               size=(1, 2, 3, 4),
                               dtype=torch.int64,
                               device=torch.device('cpu')),
                               torch.randint(
                               low=low,
                               high=high,
                               size=(1, 2, 3, 4, 5),
                               dtype=torch.int64,
                               device=torch.device('cpu'))[:, :, :, :, :]]:

                cpu_tensor = cpu_tensor.to(dtype=int8_type)
                mkldnn_tensor = cpu_tensor.to_mkldnn(int8_type)
                self.assertEqual(mkldnn_tensor.dtype, int8_type)
                cpu_tensor_1 = mkldnn_tensor.to_dense()
                self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype)
                self.assertEqual(cpu_tensor, cpu_tensor_1)
                self.assertEqual(mkldnn_tensor.device, torch.device('cpu'))
                self.assertEqual(mkldnn_tensor.size(), cpu_tensor.size())
                self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
                self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size())
                self.assertRaisesRegex(RuntimeError,
                                       "Cannot access data pointer of Tensor that doesn't have storage",
                                       lambda: mkldnn_tensor.data_ptr() != 0)

    def test_copy(self):
        x = torch.randn(4, 5, dtype=torch.float32)
        mkldnn_x = x.to_mkldnn()
        mkldnn_y = torch.randn(4, 5, dtype=torch.float32).to_mkldnn()
        mkldnn_z = torch.randn(4, 10, dtype=torch.float32).to_mkldnn()
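        # copy_ requires both tensors to be mkldnn and of the same size;
        # copying between mkldnn and dense layouts is not implemented,
        # as the assertions below check.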
        mkldnn_y.copy_(mkldnn_x)
        self.assertEqual(x, mkldnn_y.to_dense())
        self.assertRaisesRegex(RuntimeError,
                               "copy_mkldnn_: only support same size tensor.",
                               lambda: mkldnn_z.copy_(mkldnn_x))
        self.assertRaisesRegex(RuntimeError,
                               "copy_mkldnn_: between mkldnn layout and dense Tensors is not implemented! "
                               "Found self type = torch.FloatTensor and src type = Mkldnntorch.FloatTensor",
                               lambda: x.copy_(mkldnn_x))
        self.assertRaisesRegex(RuntimeError,
                               "copy_mkldnn_: between mkldnn layout and dense Tensors is not implemented! "
                               "Found self type = Mkldnntorch.FloatTensor and src type = torch.FloatTensor",
                               lambda: mkldnn_x.copy_(x))

    def test_unsupported(self):
        # unsupported types and unsupported types with gpu
        for dtype in [torch.double, torch.uint8, torch.int8,
                      torch.short, torch.int, torch.long]:
            with self.assertRaises(RuntimeError) as context:
                torch.randn(1, 2, 3, 4, dtype=dtype, device=torch.device('cpu')).to_mkldnn()
            if torch.cuda.is_available():
                with self.assertRaises(RuntimeError) as context:
                    torch.randn(1, 2, 3, 4, dtype=dtype, device=torch.device('cuda')).to_mkldnn()
        # supported type with gpu
        if torch.cuda.is_available():
            with self.assertRaises(RuntimeError) as context:
                torch.randn(1, 2, 3, 4, dtype=torch.float, device=torch.device('cuda')).to_mkldnn()
        # some factory functions
        for creator in [torch.ones, torch.randn, torch.rand]:
            with self.assertRaises(RuntimeError) as context:
                creator(1, 2, 3, 4, dtype=torch.float, device=torch.device('cpu'), layout=torch._mkldnn)

    def test_mkldnn_conv_shapecheck(self):
        input = torch.full((1, 1, 1, 24,), 1, dtype=torch.float32)
        w1 = torch.full((1, 1, 1, 24,), 1, dtype=torch.float32)
        b1 = torch.full((1,), 1, dtype=torch.float32)
        w2 = torch.full((1, 1, 2, 24,), 1, dtype=torch.float32)
        b2 = torch.full((2,), 1, dtype=torch.float32)
        options = zip([-1, 0, 0, 0, 0, 0, 0],  # padding
                      [1, 0, 1, 1, 1, 1, 1],  # stride
                      [1, 1, 0, 1, 1, 1, 1],  # dilation
                      [1, 1, 1, 0, 2, 1, 1],  # groups
                      [w1, w1, w1, w1, w1, w1, w2],  # weight
                      [b1, b1, b1, b1, b1, b2, b1])  # bias
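        # Each column above is one invalid configuration: negative padding,
        # zero stride, zero dilation, zero groups, a groups value inconsistent
        # with the channel counts, a mis-shaped bias, and a kernel taller than
        # the input.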
        for pad, st, dil, gr, w, b in options:
            with self.assertRaises(RuntimeError) as _:
                torch.mkldnn_convolution(input, w, b, [pad] * 2, [st] * 2, [dil] * 2, gr)

    def test_autograd_to_mkldnn(self):
        # MKLDNN only supports float32
        root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True)

        def func(root):
            return root.to_mkldnn().to_dense()

        # because MKLDNN only supports float32, we need to lessen the precision.
        # these numbers are just empirical results that seem to work.
        self.assertWarnsRegex(UserWarning,
                              'double precision floating point',
                              lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2))
        self.assertWarnsRegex(UserWarning,
                              'double precision floating point',
                              lambda: gradgradcheck(func, [root], atol=4e-2, rtol=1e-2))

    def test_autograd_from_mkldnn(self):
        # MKLDNN only supports float32
        root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()

        def func(root):
            return root.to_dense()

        # because MKLDNN only supports float32, we need to lessen the precision.
        # these numbers are just empirical results that seem to work.
        self.assertWarnsRegex(UserWarning,
                              'double precision floating point',
                              lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2))

    def test_detach(self):
        root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()

        detach = root.detach()
        self.assertEqual((4, 5), detach.size())
        self.assertFalse(detach.requires_grad)
        self.assertTrue(root.requires_grad)

        detach_ = root.detach_()
        self.assertEqual((4, 5), detach_.size())
        self.assertFalse(detach_.requires_grad)
        self.assertFalse(root.requires_grad)

    def test_repr(self):
        self.assertTrue("layout=torch._mkldnn" in str(torch.randn((1, 2, 3, 4),
                                                                  dtype=torch.float, device=torch.device('cpu')).to_mkldnn()))

    def _test_conv_base(self, dim):
        conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
        input_shapes = {1: (224,), 2: (224, 224), 3: (55, 55, 55)}
        options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
        for train, bias, dilation, groups in options:
            N = torch.randint(3, 10, (1,)).item()
            M = torch.randint(1, 3, (1,)).item() * groups
            C = torch.randint(1, 3, (1,)).item() * groups
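            # M (out_channels) and C (in_channels) are multiples of groups so
            # the channels divide evenly across groups.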
            x_shape = (N, C) + input_shapes[dim]
            x = torch.randn(x_shape, dtype=torch.float32)
            conv = conv_module[dim](in_channels=C,
                                    out_channels=M,
                                    kernel_size=3,
                                    stride=2,
                                    padding=1,
                                    dilation=dilation,
                                    bias=bias,
                                    groups=groups).float()
            x1 = x.clone()
            x2 = x.clone().to_mkldnn()
            if not train:
                mkldnn_conv = mkldnn_utils.to_mkldnn(copy.deepcopy(conv))
            elif train and dim != 1:
                # TODO: enable conv1d training.
                x1.requires_grad_()
                x2.requires_grad_()
                mkldnn_conv = copy.deepcopy(conv)
            with torch.backends.mkldnn.flags(enabled=False):
                y_aten = conv(x1)
                if train and dim != 1:
                    loss1 = y_aten.sum()
                    loss1.backward()
            if not train or (train and dim != 1):
                y_mkldnn = mkldnn_conv(x2).to_dense()
                self.assertEqual(y_aten, y_mkldnn)
            if not train:
                self._test_serialization(mkldnn_conv, (x.to_mkldnn(),))
                self._test_tracing(mkldnn_conv, (x.to_mkldnn(),))
            elif dim != 1:
                loss2 = y_mkldnn.sum()
                loss2.backward()
                self.assertTrue(x2.grad.is_mkldnn)
                self.assertEqual(x1.grad, x2.grad.to_dense())
                self.assertEqual(conv.weight.grad,
                                 mkldnn_conv.weight.grad,
                                 atol=1e-3,
                                 rtol=1e-3)
                if bias:
                    self.assertEqual(conv.bias.grad, mkldnn_conv.bias.grad)

    def test_conv1d(self):
        self._test_conv_base(dim=1)

    def test_conv2d(self):
        self._test_conv_base(dim=2)

    def test_conv3d(self):
        self._test_conv_base(dim=3)

    def _test_conv_deconv_lower_precision_base(self, dim, conv_module, dtype):
        input_shapes = {1: (224,), 2: (224, 224), 3: (55, 55, 55)}
        options = itertools.product([True, False], [1, 2], [1, 4])
        for bias, dilation, groups in options:
            N = torch.randint(1, 3, (1,)).item()
            M = torch.randint(1, 3, (1,)).item() * groups
            C = torch.randint(1, 3, (1,)).item() * groups
            x_shape = (N, C) + input_shapes[dim]
            x = torch.randn(x_shape, dtype=torch.float32)
            # TODO: remove this when group depthwise is supported:
            if conv_module in [torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
                               torch.nn.ConvTranspose3d] and groups > 1 and C == groups:
                continue
            conv = conv_module(in_channels=C,
                               out_channels=M,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               dilation=dilation,
                               bias=bias,
                               groups=groups).float()
            x_lower = x.to(dtype=dtype)
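            # The mkldnn path runs only when the CPU supports the reduced
            # dtype; otherwise to_mkldnn(dtype) on the module raises, which
            # the else branch below asserts.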
            if (dtype == torch.bfloat16 and torch.ops.mkldnn._is_mkldnn_bf16_supported()) or \
               (dtype == torch.half and torch.ops.mkldnn._is_mkldnn_fp16_supported()):
                mkldnn_conv = mkldnn_utils.to_mkldnn(copy.deepcopy(conv))
                mkldnn_conv_lower = mkldnn_utils.to_mkldnn(copy.deepcopy(conv), dtype)
                y = mkldnn_conv(x.to_mkldnn()).to_dense()
                y_lower = mkldnn_conv_lower(x_lower.to_mkldnn()).to_dense(torch.float32)
                self.assertEqual(y, y_lower, atol=1e-1, rtol=1e-3)
            else:
                msg = {
                    torch.bfloat16: r"bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq",
                    torch.half: r"fp16 path needs the cpu support avx_ne_convert or avx512_fp16",
                }
                with self.assertRaisesRegex(RuntimeError, msg[dtype]):
                    mkldnn_conv_lower = mkldnn_utils.to_mkldnn(copy.deepcopy(conv), dtype)
                    y_lower = mkldnn_conv_lower(x_lower.to_mkldnn()).to_dense(torch.float32)
            # test thnn impl
            conv_lower = copy.deepcopy(conv).to(dtype=dtype)
            conv_ref = copy.deepcopy(conv_lower).float()
            with torch.backends.mkldnn.flags(enabled=False):
                x_ref = x_lower.clone().float().detach().requires_grad_()
                x_lower.requires_grad_()
                y = conv_ref(x_ref)
                y_lower = conv_lower(x_lower).float()
                self.assertEqual(y, y_lower, atol=5e-2, rtol=5e-3)

    @dtypes(torch.float16, torch.bfloat16)
    def test_conv_deconv_1d_lower_precision(self, dtype):
        self._test_conv_deconv_lower_precision_base(1, torch.nn.Conv1d, dtype=dtype)
        self._test_conv_deconv_lower_precision_base(1, torch.nn.ConvTranspose1d, dtype=dtype)

    @dtypes(torch.float16, torch.bfloat16)
    def test_conv_deconv_2d_lower_precision(self, dtype):
        self._test_conv_deconv_lower_precision_base(2, torch.nn.Conv2d, dtype=dtype)
        self._test_conv_deconv_lower_precision_base(2, torch.nn.ConvTranspose2d, dtype=dtype)

    @dtypes(torch.float16, torch.bfloat16)
    def test_conv_deconv_3d_lower_precision(self, dtype):
        self._test_conv_deconv_lower_precision_base(3, torch.nn.Conv3d, dtype=dtype)
        self._test_conv_deconv_lower_precision_base(3, torch.nn.ConvTranspose3d, dtype=dtype)

    def _test_conv_deconv_nhwc_base(self, conv_module, weight_memory_format, dtype, prec=None):
        input_shapes = {2: (55, 55), 3: (14, 14, 14)}
        options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
        if conv_module in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
            cl_format = torch.channels_last
            input_shape = input_shapes[2]
        elif conv_module in [torch.nn.Conv3d, torch.nn.ConvTranspose3d]:
            cl_format = torch.channels_last_3d
            input_shape = input_shapes[3]

        for train, bias, dilation, groups in options:
            N = torch.randint(3, 10, (1,)).item()
            M = torch.randint(1, 3, (1,)).item() * groups
            C = torch.randint(1, 3, (1,)).item() * groups
            x_shape = (N, C) + input_shape
            x = torch.randn(x_shape, dtype=dtype)

            # conv1: mkldnn conv/deconv in contiguous memory format (nchw)
            # conv2: mkldnn conv/deconv in channels last memory format (nhwc)
            conv1 = conv_module(in_channels=C,
                                out_channels=M,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                dilation=dilation,
                                bias=bias,
                                groups=groups).to(dtype=dtype)
            conv2 = copy.deepcopy(conv1).to(memory_format=weight_memory_format)
            x1 = x.clone()
            x2 = x.clone().to(memory_format=cl_format)
            if train:
                x1.requires_grad_()
                x2.requires_grad_()
            y1 = conv1(x1)
            y2 = conv2(x2)
            self.assertEqual(y1, y2, atol=prec, rtol=prec)

            if train:
                y1.sum().backward()
                y2.sum().backward()
                self.assertTrue(x2.grad.is_contiguous(memory_format=cl_format))
                self.assertEqual(conv1.weight.grad,
                                 conv2.weight.grad,
                                 atol=1e-3,
                                 rtol=1e-3)
                if bias:
                    self.assertEqual(conv1.bias.grad, conv2.bias.grad, atol=prec, rtol=prec)
                self.assertEqual(x1.grad, x2.grad, atol=prec, rtol=prec)

    def test_conv_nhwc_fp32(self):
        self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.float32)
        self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.float32)
        self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=torch.float32)
        self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=torch.float32)

    @dtypes(torch.float16, torch.bfloat16)
    def test_conv_nhwc_lower_precision(self, dtype):
        # when torch.ops.mkldnn._is_mkldnn_bf16_supported() or torch.ops.mkldnn._is_mkldnn_fp16_supported()
        # returns False, the bf16/fp16 CPU conv falls back to the thnn implementation
        support_checks = {
            torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
            torch.float16: torch.ops.mkldnn._is_mkldnn_fp16_supported
        }
        if support_checks[dtype]():
            self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=dtype)
            self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=dtype)
            self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=dtype)
            self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=dtype)

        # The BF16/FP16 fallback implementation is split into two stages,
        # im2col + gemm, which performs more intermediate dtype conversions
        # than onednn's direct conv, resulting in additional accuracy loss.
        precisions = {
            torch.bfloat16: 1e-2,
            torch.float16: 2e-3,
        }
        prec = precisions[dtype]
        with torch.backends.mkldnn.flags(enabled=False):
            self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=dtype, prec=prec)
            self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=dtype, prec=prec)
            self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=dtype, prec=prec)
            self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=dtype, prec=prec)


    def test_conv_transpose_nhwc_fp32(self):
        self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.float32)
        self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.float32)
        self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=torch.float32)
        self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=torch.float32)

    @dtypes(torch.float16, torch.bfloat16)
    def test_conv_transpose_nhwc_lower_precision(self, dtype):
        # when torch.ops.mkldnn._is_mkldnn_bf16_supported() or torch.ops.mkldnn._is_mkldnn_fp16_supported()
        # returns False, the bf16/fp16 CPU conv falls back to the thnn implementation
        support_checks = {
            torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
            torch.float16: torch.ops.mkldnn._is_mkldnn_fp16_supported
        }
        if support_checks[dtype]():
            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=dtype)
            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=dtype)
            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=dtype)
            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=dtype)

        # The BF16/FP16 fallback implementation is split into two stages,
        # col2im + gemm, which performs more intermediate dtype conversions
        # than onednn's direct conv, resulting in additional accuracy loss.
        precisions = {
            torch.bfloat16: 2e-2,
            torch.float16: 3e-3,
        }
        prec = precisions[dtype]
        with torch.backends.mkldnn.flags(enabled=False):
            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=dtype, prec=prec)
            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=dtype, prec=prec)
            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=dtype, prec=prec)
            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=dtype, prec=prec)

    def _test_conv_transpose_base(self, dim):
        conv_module = {
            1: torch.nn.ConvTranspose1d,
            2: torch.nn.ConvTranspose2d,
            3: torch.nn.ConvTranspose3d
        }
        input_shapes = {1: (55,), 2: (28, 28), 3: (14, 14, 14)}
        options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
        for train, bias, dilation, groups in options:
            N = torch.randint(3, 10, (1,)).item()
            M = torch.randint(1, 3, (1,)).item() * groups
            C = torch.randint(1, 3, (1,)).item() * groups
            x_shape = (N, C) + input_shapes[dim]
            data = torch.randn(x_shape, dtype=torch.float32)
            # conv: mkldnn transpose conv fp32
            # conv_ref: thnn transpose conv fp32
            conv = conv_module[dim](in_channels=C,
                                    out_channels=M,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1,
                                    dilation=dilation,
                                    bias=bias,
                                    groups=groups).to(dtype=torch.float32)
            x = data.clone()
            x_ref = x.clone()
            if train:
                x.requires_grad_()
                x_ref.requires_grad_()

            conv_ref = copy.deepcopy(conv)
            with torch.backends.mkldnn.flags(enabled=False):
                y_ref = conv_ref(x_ref)
                if train:
                    y_ref.sum().backward()

            y = conv(x)
            if train:
                y.sum().backward()

            self.assertEqual(y, y_ref)
            if train:
                self.assertEqual(x.grad, x_ref.grad)
                self.assertEqual(conv.weight.grad,
                                 conv_ref.weight.grad,
                                 atol=1e-3,
                                 rtol=1e-3)
                if bias:
                    self.assertEqual(conv.bias.grad, conv_ref.bias.grad)

    def test_conv_transpose1d(self):
        self._test_conv_transpose_base(dim=1)

    def test_conv_transpose2d(self):
        self._test_conv_transpose_base(dim=2)

    def test_conv_transpose3d(self):
        self._test_conv_transpose_base(dim=3)

    def test_conv2d_legacy_jit_model(self):
        """
        MKLDNN integration used to serialize models with a 5d weight for grouped
        convolutions; we'd like to preserve this behavior.
        """
        g = 4
        conv2d = torch.nn.Conv2d(16, 16, 3, groups=g)
        conv2d_mkldnn = torch.utils.mkldnn.to_mkldnn(conv2d)

        # contrive legacy conv2d module with a 5-d weight
        o, i, h, w = conv2d.weight.shape
        weight_5d = conv2d.weight.reshape((g, o // g, i, h, w))
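        # legacy 5d layout: (groups, out_channels // groups, in_channels // groups, kH, kW),
        # since the 4d weight above is (out_channels, in_channels // groups, kH, kW)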
        conv2d_mkldnn.weight = weight_5d.to_mkldnn()

        x = torch.randn(1, 16, 8, 8)

        with TemporaryFileName() as fname:
            torch.jit.save(conv2d_mkldnn, fname)
            conv2d_loaded = torch.jit.load(fname)

            self.assertEqual(conv2d_mkldnn.weight.ndimension(), 5)
            self.assertEqual(conv2d_loaded.weight.ndimension(), 4)
            self.assertEqual(
                conv2d(x),
                conv2d_loaded(x.to_mkldnn()).to_dense())

    # This test is to check whether 1D conv is supported for mkldnn tensor,
    # which is exposed by Issue https://github.com/pytorch/pytorch/issues/68034.
    def test_conv1d_functional(self):
        input = torch.randn(2, 3, 10).to_mkldnn()
        weight = torch.randn(3, 3, 3).to_mkldnn()
        bias = torch.randn(3).to_mkldnn()
        output = torch.nn.functional.conv1d(input, weight, bias)
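        # with no padding and unit stride, L_out = L_in - kernel_size + 1 = 10 - 3 + 1 = 8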
        self.assertEqual(output.size(), torch.Size([2, 3, 8]))

    def test_relu(self):
        x = torch.randn((4, 5), dtype=torch.float32) * 10
        x1 = x.clone().requires_grad_()
        x2 = x.clone().to_mkldnn().requires_grad_()
        y1 = torch.relu(x1)
        y2 = torch.relu(x2).to_dense()
        loss1 = y1.sum()
        loss2 = y2.sum()
        loss1.backward()
        loss2.backward()
        self.assertEqual(y1, y2)
        self.assertEqual(x1.grad, x2.grad.to_dense())

    def test_relu_(self):
        x = torch.randn((4, 5), dtype=torch.float32) * 10
        x1 = x.clone().requires_grad_()
        x2 = x.clone().to_mkldnn().requires_grad_()
        y1 = torch.relu_(x1.clone())
        y2 = torch.relu_(x2.clone()).to_dense()
        loss1 = y1.sum()
        loss2 = y2.sum()
        loss1.backward()
        loss2.backward()
        self.assertEqual(y1, y2)
        self.assertEqual(x1.grad, x2.grad.to_dense())

    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def _test_relu_bf16_base(self, name):
        x = torch.randn((4, 5), dtype=torch.float32) * 10
        x_bf16 = x.bfloat16()
        fn = getattr(torch, name)
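        # name is "relu" or "relu_" (see the two tests below), so fn resolves
        # to torch.relu or its in-place variant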
596*da0073e9SAndroid Build Coastguard Worker        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
597*da0073e9SAndroid Build Coastguard Worker            y = fn(x.to_mkldnn()).to_dense()
598*da0073e9SAndroid Build Coastguard Worker            y_bf16 = fn(x_bf16.to_mkldnn()).to_dense(torch.float32)
599*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
600*da0073e9SAndroid Build Coastguard Worker        else:
601*da0073e9SAndroid Build Coastguard Worker            msg = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
602*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(RuntimeError,
603*da0073e9SAndroid Build Coastguard Worker                                   msg,
604*da0073e9SAndroid Build Coastguard Worker                                   lambda: fn(x_bf16.to_mkldnn()))
605*da0073e9SAndroid Build Coastguard Worker
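    # Every bf16 test below gates on the same capability check: with the
    # required AVX512 features present, the bf16 result is compared to fp32
    # within a loose tolerance; without them, the op is expected to raise.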
    def test_relu_bf16(self):
        self._test_relu_bf16_base("relu")

    def test_relu_inplace_bf16(self):
        self._test_relu_bf16_base("relu_")

    def test_gelu(self):
        m = torch.nn.GELU()
        x = torch.randn((4, 5), dtype=torch.float32) * 10
        x1 = x.clone().requires_grad_()
        x2 = x.clone().to_mkldnn().requires_grad_()
        y1 = m(x1)
        y2 = m(x2).to_dense()
        loss1 = y1.sum()
        loss2 = y2.sum()
        loss1.backward()
        loss2.backward()
        self.assertEqual(y1, y2)
        self.assertEqual(x1.grad, x2.grad.to_dense())

    @unittest.skipIf(IS_WINDOWS, "bf16 path has limited support on Windows")
    def test_gelu_bf16(self):
        m = torch.nn.GELU()
        x = torch.randn((4, 5), dtype=torch.float32) * 10
        x1 = x.clone().to_mkldnn().requires_grad_()
        x2 = x.clone().to_mkldnn(torch.bfloat16).requires_grad_()
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            y1 = m(x1).to_dense()
            y2 = m(x2).to_dense()
            loss1 = y1.sum()
            loss2 = y2.sum()
            loss1.backward()
            loss2.backward()
            self.assertEqual(y1, y2.to(torch.float32), atol=1e-1, rtol=0)
            self.assertEqual(x1.grad.to_dense(), x2.grad.to_dense(torch.float32), atol=1e-2, rtol=0)
        else:
            msg = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
            self.assertRaisesRegex(RuntimeError,
                                   msg,
                                   lambda: m(x2))

    def _test_prelu_base(self, size, num_channels):
        x = torch.randn(size, dtype=torch.float32)
        x1 = x.clone().requires_grad_()
        x2 = x.clone().to_mkldnn().requires_grad_()
        x3 = x.clone().to_mkldnn().requires_grad_()
        m1 = torch.nn.PReLU(num_channels)
        m2 = mkldnn_utils.to_mkldnn(copy.deepcopy(m1))
        m3 = copy.deepcopy(m1)
        y1 = m1(x1)
        y2 = m2(x2).to_dense()
        y3 = m3(x3).to_dense()  # only the input is MKL-DNN; the weight stays an ATen tensor
        loss1 = y1.sum()
        loss1.backward()
        loss2 = y2.sum()
        loss2.backward()
        loss3 = y3.sum()
        loss3.backward()
        self.assertEqual(y1, y2)
        self.assertEqual(y1, y3)
        self.assertEqual(x1.grad, x2.grad.to_dense())
        self.assertEqual(x1.grad, x3.grad.to_dense())

    def test_prelu(self):
        self._test_prelu_base(torch.Size([16]), 1)
        self._test_prelu_base(torch.Size([16, 64]), 1)
        self._test_prelu_base(torch.Size([16, 64]), 64)
        self._test_prelu_base(torch.Size([16, 64, 112]), 1)
        self._test_prelu_base(torch.Size([16, 64, 112]), 64)
        self._test_prelu_base(torch.Size([16, 64, 112, 112]), 1)
        self._test_prelu_base(torch.Size([16, 64, 112, 112]), 64)
        self._test_prelu_base(torch.Size([16, 64, 112, 112, 1]), 1)
        self._test_prelu_base(torch.Size([16, 64, 112, 112, 1]), 64)

    @unittest.skipIf(IS_WINDOWS, "bf16 path has limited support on Windows")
    def _test_prelu_bf16_base(self, size, num_channels):
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            x = torch.randn(size, dtype=torch.float32)
            x_fp32 = x.clone().to_mkldnn().requires_grad_()
            x_bf16 = x.clone().to_mkldnn(torch.bfloat16).requires_grad_()
            m = mkldnn_utils.to_mkldnn(torch.nn.PReLU())
            m_bf16 = mkldnn_utils.to_mkldnn(torch.nn.PReLU(), torch.bfloat16)

            y = m(x_fp32).to_dense()
            y_bf16 = m_bf16(x_bf16).to_dense()
            self.assertEqual(y, y_bf16.to(torch.float32), atol=1e-1, rtol=1e-3)

            loss = y.sum()
            loss.backward()
            loss_bf16 = y_bf16.sum()
            loss_bf16.backward()
            self.assertEqual(x_fp32.grad.to_dense(), x_bf16.grad.to_dense(torch.float32))
        else:
            x_bf16 = torch.randn(size, dtype=torch.bfloat16).requires_grad_()
            m_bf16 = mkldnn_utils.to_mkldnn(torch.nn.PReLU(), torch.bfloat16)
            msg = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
            self.assertRaisesRegex(RuntimeError,
                                   msg,
                                   lambda: m_bf16(x_bf16))

    def test_prelu_bf16(self):
        self._test_prelu_bf16_base(torch.Size([16]), 1)
        self._test_prelu_bf16_base(torch.Size([16, 64]), 1)
        self._test_prelu_bf16_base(torch.Size([16, 64]), 64)
        self._test_prelu_bf16_base(torch.Size([16, 64, 112]), 1)
        self._test_prelu_bf16_base(torch.Size([16, 64, 112]), 64)
        self._test_prelu_bf16_base(torch.Size([16, 64, 112, 112, 1]), 1)
        self._test_prelu_bf16_base(torch.Size([16, 64, 112, 112, 1]), 64)

    def _test_max_pool_base(self, dim, input):
        pool_module = {2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d}
        for stride in [1, 2, 3]:
            for ceil_mode in [False, True]:
                max_pool = pool_module[dim](
                    kernel_size=3 if not ceil_mode else 7,
                    stride=stride,
                    padding=1,
                    ceil_mode=ceil_mode)

                x1 = input.clone().requires_grad_()
                x2 = input.clone().to_mkldnn().requires_grad_()
                y1 = max_pool(x1)
                y2 = max_pool(x2).to_dense()
                loss1 = y1.sum()
                loss2 = y2.sum()
                loss1.backward()
                loss2.backward()
                self.assertEqual(y1, y2)
                self.assertEqual(x1.grad, x2.grad.to_dense())

    def test_max_pool2d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        for H, W in [(64, 64), (35, 39), (16, 19), (7, 8)]:
            x = torch.randn(N, C, H, W, dtype=torch.float32) * 10
            self._test_max_pool_base(dim=2, input=x)

    def test_max_pool3d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        for D, H, W in [(64, 64, 64), (35, 39, 35), (16, 19, 20), (7, 8, 9)]:
            x = torch.randn(N, C, D, H, W, dtype=torch.float32) * 10
            self._test_max_pool_base(dim=3, input=x)

    @unittest.skipIf(IS_WINDOWS, "bf16 path has limited support on Windows")
    def _test_max_pool_bf16_base(self, dim, input):
        pool_module = {2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d}
        x_bf16 = input.bfloat16()
        for stride in [1, 2, 3]:
            for ceil_mode in [False, True]:
                max_pool = pool_module[dim](
                    kernel_size=3 if not ceil_mode else 7,
                    stride=stride,
                    padding=1,
                    ceil_mode=ceil_mode)

                if torch.ops.mkldnn._is_mkldnn_bf16_supported():
                    y = max_pool(input.to_mkldnn()).to_dense()
                    y_bf16 = max_pool(x_bf16.to_mkldnn()).to_dense(torch.float32)
                    self.assertEqual(y, y_bf16, atol=0.1, rtol=1e-3)
                else:
                    msg = "mkldnn_max_pool%dd: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" % dim
                    self.assertRaisesRegex(RuntimeError,
                                           msg,
                                           lambda: max_pool(x_bf16.to_mkldnn()))

    def test_max_pool2d_bf16(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        for H, W in [(64, 64), (35, 39), (16, 19), (7, 8)]:
            x = torch.randn(N, C, H, W, dtype=torch.float32) * 10
            self._test_max_pool_bf16_base(dim=2, input=x)

    def test_max_pool3d_bf16(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        for D, H, W in [(64, 64, 64), (35, 39, 35), (16, 19, 20), (7, 8, 9)]:
            x = torch.randn(N, C, D, H, W, dtype=torch.float32) * 10
            self._test_max_pool_bf16_base(dim=3, input=x)

    def test_max_pool2d_stride_none(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()

        for H, W in [(64, 64), (35, 39), (16, 19), (7, 8)]:
            x = torch.randn(N, C, H, W, dtype=torch.float32) * 10
            for ceil_mode in [False, True]:
                y1 = F.max_pool2d(
                    x,
                    kernel_size=3 if not ceil_mode else 7,
                    stride=None,
                    padding=1,
                    ceil_mode=ceil_mode)

                y2 = F.max_pool2d(
                    x.to_mkldnn(),
                    kernel_size=3 if not ceil_mode else 7,
                    stride=None,
                    padding=1,
                    ceil_mode=ceil_mode)

                self.assertEqual(y1, y2.to_dense())

    # https://github.com/pytorch/pytorch/issues/127111
    @xfailIfTorchDynamo
    def test_max_pool_unsupported(self):
        # oneDNN does not support dilated max pooling yet; it is expected to be available in oneDNN v2.0.
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()

        # 2d dilation case
        x = torch.randn(N, C, 7, 7, dtype=torch.float32).to_mkldnn()
        max_pool2d = torch.nn.MaxPool2d(
            kernel_size=3,
            stride=3,
            padding=1,
            dilation=2)
        self.assertRaisesRegex(RuntimeError,
                               'mkldnn_max_pool2d does not support dilation case',
                               lambda: max_pool2d(x))

        # 3d dilation case
        x = torch.randn(N, C, 7, 7, 7, dtype=torch.float32).to_mkldnn()
        max_pool3d = torch.nn.MaxPool3d(
            kernel_size=3,
            stride=3,
            padding=1,
            dilation=2)
        self.assertRaisesRegex(RuntimeError,
                               'mkldnn_max_pool3d does not support dilation case',
                               lambda: max_pool3d(x))

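    # Until oneDNN supports dilated pooling, falling back to a dense tensor is
    # a workable path. A minimal sketch of that fallback; this method is an
    # illustrative addition, assuming only that dense max_pool2d supports
    # dilation:
    def test_max_pool_dilation_dense_fallback(self):
        x = torch.randn(2, 3, 7, 7, dtype=torch.float32).to_mkldnn()
        max_pool2d = torch.nn.MaxPool2d(
            kernel_size=3,
            stride=3,
            padding=1,
            dilation=2)
        # run the dilated pooling on a dense copy, then convert back
        y = max_pool2d(x.to_dense()).to_mkldnn()
        self.assertEqual(y.to_dense(), max_pool2d(x.to_dense()))
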
    def _test_avg_pool_base(self, dim, input):
        avg_module = {2: torch.nn.AvgPool2d, 3: torch.nn.AvgPool3d}
        for count_include_pad in [True, False]:
            avg_pool = avg_module[dim](
                kernel_size=3,
                stride=2,
                padding=1,
                count_include_pad=count_include_pad)

            x1 = input.clone().requires_grad_()
            x2 = input.clone().to_mkldnn().requires_grad_()
            y1 = avg_pool(x1)
            y2 = avg_pool(x2).to_dense()
            loss1 = y1.sum()
            loss2 = y2.sum()
            loss1.backward()
            loss2.backward()
            self.assertEqual(y1, y2)
            self.assertEqual(x1.grad, x2.grad.to_dense())

    def test_avg_pool2d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10
        self._test_avg_pool_base(dim=2, input=x)

    def test_avg_pool3d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 64, 64, 64, dtype=torch.float32) * 10
        self._test_avg_pool_base(dim=3, input=x)

    @unittest.skipIf(IS_WINDOWS, "bf16 path has limited support on Windows")
    def _test_avg_pool_bf16_base(self, dim, input):
        avg_module = {2: torch.nn.AvgPool2d, 3: torch.nn.AvgPool3d}
        x_bf16 = input.bfloat16()
        for count_include_pad in [True, False]:
            avg_pool = avg_module[dim](
                kernel_size=3,
                stride=2,
                padding=1,
                count_include_pad=count_include_pad)
            if torch.ops.mkldnn._is_mkldnn_bf16_supported():
                y = avg_pool(input.to_mkldnn()).to_dense()
                y_bf16 = avg_pool(x_bf16.to_mkldnn()).to_dense(torch.float)
                self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
            else:
                msg = "mkldnn_avg_pool%dd: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" % dim
                self.assertRaisesRegex(RuntimeError,
                                       msg,
                                       lambda: avg_pool(x_bf16.to_mkldnn()))

    def test_avg_pool2d_bf16(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10
        self._test_avg_pool_bf16_base(dim=2, input=x)

    def test_avg_pool3d_bf16(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 64, 64, 64, dtype=torch.float32) * 10
        self._test_avg_pool_bf16_base(dim=3, input=x)

    def test_avg_pool2d_stride_none(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10

        for count_include_pad in [True, False]:
            y1 = F.avg_pool2d(
                x,
                kernel_size=3,
                stride=None,
                padding=1,
                count_include_pad=count_include_pad)
            y2 = F.avg_pool2d(
                x.to_mkldnn(),
                kernel_size=3,
                stride=None,
                padding=1,
                count_include_pad=count_include_pad)

            self.assertEqual(y1, y2.to_dense())

    def test_adaptive_avg_pool2d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100

        adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(7)
        x1 = x.clone().requires_grad_()
        x2 = x.clone().to_mkldnn().requires_grad_()
        y1 = adaptive_avg_pool2d(x1)
        y2 = adaptive_avg_pool2d(x2).to_dense()

        loss1 = y1.sum()
        loss2 = y2.sum()
        loss1.backward()
        loss2.backward()

        self.assertEqual(y1, y2)
        self.assertEqual(x1.grad, x2.grad.to_dense())

    @unittest.skipIf(IS_WINDOWS, "bf16 path has limited support on Windows")
    def test_adaptive_avg_pool2d_bf16(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100

        x_bf16 = x.bfloat16()
        adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(7)

        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            y = adaptive_avg_pool2d(x.to_mkldnn()).to_dense()
            y_bf16 = adaptive_avg_pool2d(x_bf16.to_mkldnn()).to_dense(torch.float32)
            self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
        else:
            msg = "mkldnn_adaptive_avg_pool2d: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
            self.assertRaisesRegex(RuntimeError,
                                   msg,
                                   lambda: adaptive_avg_pool2d(x_bf16.to_mkldnn()))

    def _test_batch_norm_base(self, dim, channels, input):
        bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
        bn = bn_module[dim](channels).float().train(False)
        mkldnn_bn = mkldnn_utils.to_mkldnn(copy.deepcopy(bn))
        self.assertEqual(
            bn(input),
            mkldnn_bn(input.to_mkldnn()).to_dense())

        self._test_serialization(mkldnn_bn, (input.to_mkldnn(),))
        self._test_tracing(mkldnn_bn, (input.to_mkldnn(),))

    def _test_batch_norm_train_base(self, dim, channels, input):
        # TODO: support 3d batchnorm training.
        bn_module = {2: torch.nn.BatchNorm2d}
        # TODO: support affine=False.
        options = itertools.product([True], [True, False])
        for affine, track_running_stats in options:
            bn = bn_module[dim](
                num_features=channels,
                affine=affine,
                track_running_stats=track_running_stats).float().train(True)
            mkldnn_bn = copy.deepcopy(bn)
            x1 = input.clone().requires_grad_()
            x2 = input.clone().to_mkldnn().requires_grad_()
            y1 = bn(x1)
            y2 = mkldnn_bn(x2).to_dense()
            loss1 = y1.sum()
            loss2 = y2.sum()
            loss1.backward()
            loss2.backward()
            self.assertEqual(y1, y2)
            self.assertEqual(x1.grad, x2.grad.to_dense())
            self.assertEqual(bn.weight.grad, mkldnn_bn.weight.grad, rtol=1e-3, atol=1e-3)
            if track_running_stats:
                self.assertEqual(bn.running_mean, mkldnn_bn.running_mean)
                self.assertEqual(bn.running_var, mkldnn_bn.running_var, rtol=1e-5, atol=1e-5)

    def test_batch_norm_2d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
        self._test_batch_norm_base(dim=2, channels=C, input=x)
        self._test_batch_norm_train_base(dim=2, channels=C, input=x)

    def test_batch_norm_3d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        x = torch.randn(N, C, 30, 30, 30, dtype=torch.float32) * 10
        self._test_batch_norm_base(dim=3, channels=C, input=x)

    @unittest.skipIf(IS_WINDOWS, "bf16 path has limited support on Windows")
    def _test_batch_norm_bf16_base(self, dim, channels, input):
        bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
        x_bf16 = input.bfloat16()
        # TODO: support training
        for train in [False]:
            bn = bn_module[dim](channels).float().train(train)
            mkldnn_bn = mkldnn_utils.to_mkldnn(copy.deepcopy(bn))
            if torch.ops.mkldnn._is_mkldnn_bf16_supported():
                y = mkldnn_bn(input.to_mkldnn()).to_dense()
                y_bf16 = mkldnn_bn(x_bf16.to_mkldnn()).to_dense(torch.float)
                self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
            else:
                msg = "mkldnn_batch_norm: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
                self.assertRaisesRegex(RuntimeError,
                                       msg,
                                       lambda: bn(x_bf16.to_mkldnn()))

    def test_batch_norm_2d_bf16(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
        self._test_batch_norm_bf16_base(dim=2, channels=C, input=x)

    def test_batch_norm_3d_bf16(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        x = torch.randn(N, C, 30, 30, 30, dtype=torch.float32) * 10
        self._test_batch_norm_bf16_base(dim=3, channels=C, input=x)

    def test_add(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        alpha = torch.randn(1, dtype=torch.float32).item()

        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
        y = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
        mx = x.to_mkldnn()
        my = y.to_mkldnn()

        # add
        self.assertEqual(
            x + y,
            (mx + my).to_dense())

        self.assertEqual(
            torch.add(x, y, alpha=alpha),
            torch.add(mx, my, alpha=alpha).to_dense())

        # add_
        x += y
        mx += my
        self.assertEqual(x, mx.to_dense())

        # add_out
        out = x.clone()
        mkldnn_out = out.to_mkldnn()
        torch.add(x, y, alpha=alpha, out=out)
        torch.add(mx, my, alpha=alpha, out=mkldnn_out)
        self.assertEqual(out, mkldnn_out.to_dense())

        # add_out inplace case: first input
        torch.add(x, y, alpha=alpha, out=x)
        torch.add(mx, my, alpha=alpha, out=mx)
        self.assertEqual(x, mx.to_dense())

        # add_out inplace case: second input
        torch.add(x, y, alpha=alpha, out=y)
        torch.add(mx, my, alpha=alpha, out=my)
        self.assertEqual(y, my.to_dense())

    def test_mul(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        value = torch.randn(1, dtype=torch.float32).item()

        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
        y = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
        mx = x.to_mkldnn()
        my = y.to_mkldnn()

        # mul
        self.assertEqual(
            x * y,
            (mx * my).to_dense())

        self.assertEqual(
            x * value,
            (mx * value).to_dense())

        self.assertEqual(
            torch.mul(x, y),
            torch.mul(mx, my).to_dense())

        self.assertEqual(
            torch.mul(x, value),
            torch.mul(mx, value).to_dense())

        # mul_
        x *= y
        mx *= my
        self.assertEqual(x, mx.to_dense())

        x *= value
        mx *= value
        self.assertEqual(x, mx.to_dense())

        # mul_out
        out = x.clone()
        mkldnn_out = out.to_mkldnn()
        torch.mul(x, y, out=out)
        torch.mul(mx, my, out=mkldnn_out)
        self.assertEqual(out, mkldnn_out.to_dense())

        out = x.clone()
        mkldnn_out = out.to_mkldnn()
        torch.mul(x, value, out=out)
        torch.mul(mx, value, out=mkldnn_out)
        self.assertEqual(out, mkldnn_out.to_dense())

    def test_0_dimension_tensor(self):
        x = torch.rand([20, 20, 1, 1], dtype=torch.float)
        y = torch.rand([20, 20, 0, 1], dtype=torch.float)

        # unary ops work without modification
        out_relu = torch.relu(y)
        out_relu_mkldnn = torch.relu(y.to_mkldnn()).to_dense()
        self.assertEqual(out_relu, out_relu_mkldnn)

        out_mul = x * y
        out_mul_mkldnn = (x.to_mkldnn() * y.to_mkldnn()).to_dense()
        self.assertEqual(out_mul, out_mul_mkldnn)

        out_add = x + y
        out_add_mkldnn = (x.to_mkldnn() + y.to_mkldnn()).to_dense()
        self.assertEqual(out_add, out_add_mkldnn)

        x.requires_grad_(True)
        y.requires_grad_(True)
        with self.assertRaisesRegex(RuntimeError, "0-dimension Tensor in training"):
            x.to_mkldnn() + y.to_mkldnn()

        with self.assertRaisesRegex(RuntimeError, "must match"):
            torch.rand([5]).to_mkldnn() + torch.rand([0]).to_mkldnn()

        C = 7
        m = torch.nn.Conv2d(C, C, 3)
        x = torch.randn(0, C, C, 8, dtype=torch.float)
        out_eager = m(x)
        out_mkldnn = mkldnn_utils.to_mkldnn(m)(x)
        self.assertEqual(out_eager, out_mkldnn)

    # https://github.com/pytorch/pytorch/issues/127111
    @xfailIfTorchDynamo
    def test_view(self):
        x = torch.randn(3, 4, 5, dtype=torch.float32).to_mkldnn()
        self.assertRaisesRegex(RuntimeError,
                               "Change to use reshape",
                               lambda: x.view(x.size(0), -1))

    def test_reshape(self):
        x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
        size = (x.size(0), -1)

        self.assertEqual(
            x.reshape(size),
            x.to_mkldnn().reshape(size).to_dense(),
        )
        # verify reshape shares memory with the source for a plain-format tensor
        y = x.to_mkldnn()
        z = y.reshape(size).add_(y.reshape(size))
        self.assertEqual(
            y.reshape(size).to_dense(),
            z.to_dense(),
        )

    def test_reshape_blocked_format(self):
        # construct an mkldnn blocked tensor with mkldnn conv2d
        C = 7
        m = mkldnn_utils.to_mkldnn(torch.nn.Conv2d(C, C, 3))
        x = torch.randn(1, C, 8, 8).to_mkldnn()

        # mkldnn tensor w/ blocked format
        y_block = m(x)
        # aten tensor w/ plain format
        y_plain = y_block.to_dense()

        y_block_reshape = y_block.reshape(C, -1)
        y_plain_reshape = y_plain.reshape(C, -1)

        self.assertEqual(y_plain_reshape, y_block_reshape.to_dense())

    def test_reshape_backward(self):
        x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
        size = (x.size(0), -1)

        x1 = x.clone().requires_grad_()
        x2 = x.clone().to_mkldnn().requires_grad_()
        in_features = 20
        out_features = torch.randint(3, 100, (1,)).item()
        linear = torch.nn.Linear(in_features, out_features).float()

        y1 = linear(x1.reshape(size)).sum()
        y2 = linear(x2.reshape(size).to_dense()).sum()
        y1.backward()
        y2.backward()
        self.assertEqual(x1.grad, x2.grad.to_dense())

    def test_clone(self):
        x = torch.randn(4, 5, dtype=torch.float32) * 10
        self.assertEqual(
            x.clone(),
            x.to_mkldnn().clone().to_dense(),
        )
        # verify the clone does not share memory with its source
        y = x.to_mkldnn()
        z = y.clone().add_(y)
        self.assertNotEqual(
            y.to_dense(),
            z.to_dense(),
        )
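        # together with test_reshape above, this shows reshape on a
        # plain-format MKL-DNN tensor aliases its source, while clone copies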

    def test_transpose(self):
        x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
        for dim1 in range(x.ndim):
            for dim2 in range(x.ndim):
                self.assertEqual(
                    x.transpose(dim1, dim2),
                    x.to_mkldnn().transpose(dim1, dim2).to_dense(),
                )

    def test_transpose_invalid_dim(self):
        x = torch.randn(3, 4, 5, dtype=torch.float32).to_mkldnn()
        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
            torch._mkldnn_transpose(x, 0, 12)

    def test_linear_non_contiguous_weight(self):
        in_features = torch.randint(3, 10, (1,)).item()
        out_features = torch.randint(3, 100, (1,)).item()
        x = torch.randn(3, in_features, dtype=torch.float32) * 10
        w = torch.randn(in_features, out_features, dtype=torch.float32)
        for bias in [True, False]:
            x1 = x.clone().requires_grad_()
            x2 = x.clone().to_mkldnn().requires_grad_()
            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
            linear.weight = torch.nn.Parameter(w.t())
            mkldnn_linear = copy.deepcopy(linear)
            y1 = linear(x1).sum()
            y2 = mkldnn_linear(x2).to_dense().sum()
            y1.backward()
            y2.backward()
            self.assertEqual(x1.grad, x2.grad.to_dense())
            self.assertEqual(linear.weight.grad, mkldnn_linear.weight.grad)
            if bias:
                self.assertEqual(linear.bias.grad, mkldnn_linear.bias.grad)

    def test_linear(self):
        in_features = torch.randint(3, 10, (1,)).item()
        out_features = torch.randint(3, 100, (1,)).item()
        x = torch.randn(3, in_features, dtype=torch.float32) * 10

        for bias in [True, False]:
            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
            mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear))
            self.assertEqual(
                linear(x),
                mkldnn_linear(x.to_mkldnn()).to_dense())

            self._test_serialization(mkldnn_linear, (x.to_mkldnn(),))
            self._test_tracing(mkldnn_linear, (x.to_mkldnn(),))

    def test_linear_backward(self):
        in_features = torch.randint(3, 10, (1,)).item()
        out_features = torch.randint(3, 100, (1,)).item()
        x = torch.randn(3, in_features, dtype=torch.float32) * 10
        for bias in [True, False]:
            x1 = x.clone().requires_grad_()
            x2 = x.clone().to_mkldnn().requires_grad_()
            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
            mkldnn_linear = copy.deepcopy(linear)
            y1 = linear(x1).sum()
            y2 = mkldnn_linear(x2).to_dense().sum()
            y1.backward()
            y2.backward()
            self.assertEqual(x1.grad, x2.grad.to_dense())
            self.assertEqual(linear.weight.grad, mkldnn_linear.weight.grad)
            if bias:
                self.assertEqual(linear.bias.grad, mkldnn_linear.bias.grad)

    @dtypes(torch.float16, torch.bfloat16)
    def test_linear_lowp(self, dtype):
        in_features = torch.randint(3, 10, (1,)).item()
        out_features = torch.randint(3, 100, (1,)).item()
        x = torch.randn(3, in_features, dtype=torch.float32) * 10
        x_lowp = x.to(dtype=dtype)

        for bias in [True, False]:
            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
            mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear))
            mkldnn_linear_lowp = mkldnn_utils.to_mkldnn(
                copy.deepcopy(linear), dtype
            )
            lowp_support = {
                torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
                torch.half: torch.ops.mkldnn._is_mkldnn_fp16_supported,
            }
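            # Low-precision mkldnn kernels are ISA-dependent, so probe at
            # runtime: compare against the fp32 result where supported,
            # otherwise expect a descriptive RuntimeError.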
            if lowp_support[dtype]():
                y = mkldnn_linear(x.to_mkldnn()).to_dense()
                y_lowp = mkldnn_linear_lowp(x_lowp.to_mkldnn()).to_dense(
                    torch.float32
                )
                if dtype == torch.bfloat16:
                    self.assertEqual(y, y_lowp, atol=1e-1, rtol=1e-3)
                else:
                    self.assertEqual(y, y_lowp, atol=5e-3, rtol=1e-3)
            else:
                msg = {
                    torch.bfloat16: r"bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq",
                    torch.half: r"fp16 path needs the cpu support avx_ne_convert or avx512_fp16",
                }
                self.assertRaisesRegex(
                    RuntimeError,
                    msg[dtype],
                    lambda: mkldnn_linear_lowp(x_lowp.to_mkldnn()),
                )

    def test_softmax(self):
        x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
        for dim in range(x.ndim):
            softmax = torch.nn.Softmax(dim=dim)
            self.assertEqual(
                softmax(x),
                softmax(x.to_mkldnn()).to_dense())

    def test_sigmoid(self):
        x = torch.randn(4, 5, dtype=torch.float32) * 10
        mkldnn_x = x.to_mkldnn()
        self.assertEqual(
            torch.sigmoid(x),
            torch.sigmoid(mkldnn_x).to_dense(),
        )
        # inplace
        torch.sigmoid_(x)
        torch.sigmoid_(mkldnn_x)
        self.assertEqual(x, mkldnn_x.to_dense())

    def test_tanh(self):
        x = torch.randn(4, 5, dtype=torch.float32) * 10
        mkldnn_x = x.to_mkldnn()
        self.assertEqual(
            torch.tanh(x),
            torch.tanh(mkldnn_x).to_dense(),
        )
        # inplace
        torch.tanh_(x)
        torch.tanh_(mkldnn_x)
        self.assertEqual(x, mkldnn_x.to_dense())

    def _test_serialization(self, module, inputs):
        with TemporaryFileName() as fname:
            torch.jit.save(module, fname)
            loaded = torch.jit.load(fname)
            self.assertEqual(
                module(*inputs).to_dense(),
                loaded(*inputs).to_dense())

    def _test_tracing(self, module, inputs):
        traced = torch.jit.trace(module, inputs)
        self.assertEqual(
            module(*inputs).to_dense(),
            traced(*inputs).to_dense())

    def test_set_data_tensorimpl_type(self):
        # A dense tensor has an impl of type `TensorImpl`, while an MKL-DNN
        # tensor has an impl of type `OpaqueTensorImpl<IDeepTensorWrapperPtr>`,
        # so assigning one to the other's `.data` must be rejected.
        x = torch.randn((1, 2), dtype=torch.float, device=torch.device('cpu'))
        x_mkldnn = x.to_mkldnn()
        with self.assertRaisesRegex(RuntimeError, 'incompatible tensor type'):
            x.data = x_mkldnn

    def test_empty(self):
        x1 = torch.empty(4, 5, 2, 3, dtype=torch.float32)
        x2 = torch.empty(4, 5, 2, 3, dtype=torch.float32, layout=torch._mkldnn)
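        # empty() leaves values uninitialized, so only shape and dtype can be
        # meaningfully compared here.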
        self.assertEqual(x1.size(), x2.to_dense().size())
        self.assertEqual(x1.dtype, x2.to_dense().dtype)

    def test_zero_(self):
        x1 = torch.randn(4, 5, dtype=torch.float32) * 10
        x2 = x1.clone().to_mkldnn()
        self.assertEqual(
            x1.zero_(),
            x2.zero_().to_dense(),
        )

    def test_is_mkldnn(self):
        x = torch.randn(1, dtype=torch.float32)
        self.assertFalse(x.is_mkldnn)
        self.assertTrue(x.to_mkldnn().is_mkldnn)

    # legacy constructor/new doesn't support mkldnn tensors
    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1992")
    def test_legacy_new_failure(self):
        x = torch.randn(1, dtype=torch.float32)
        x_mkldnn = x.to_mkldnn()
        self.assertRaises(RuntimeError, lambda: x_mkldnn.new(device='cpu'))
        self.assertRaises(RuntimeError, lambda: x_mkldnn.new(x.storage()))
        self.assertRaises(RuntimeError, lambda: x_mkldnn.new(x))
        self.assertRaises(RuntimeError, lambda: x_mkldnn.new(torch.Size([2, 3])))
        self.assertRaises(RuntimeError, lambda: x_mkldnn.new([6]))

    def test_is_mkldnn_jit(self):
        class EnsureMkldnn(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                if not x.is_mkldnn:
                    x = x.to_mkldnn()
                return x

        m = EnsureMkldnn()
        x = torch.randn(1, dtype=torch.float32)
        self.assertTrue(m(x).is_mkldnn)
        self.assertTrue(m(x.to_mkldnn()).is_mkldnn)

    def _test_imagenet_model(self, model):
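        # Inference-only comparison: eval mode freezes batch-norm statistics
        # and disables dropout, so the dense and mkldnn outputs should agree
        # within the default tolerances.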
        model = model.train(False).float()
        mkldnn_model = mkldnn_utils.to_mkldnn(copy.deepcopy(model))
        x = torch.randn(1, 3, 224, 224, dtype=torch.float32)
        with torch.no_grad():
            self.assertEqual(
                model(x),
                mkldnn_model(x.to_mkldnn()).to_dense(),
            )

    @skipIfNoTorchVision
    def test_resnet18(self):
        model = torchvision.models.resnet.resnet18(weights=None)
        self._test_imagenet_model(model)

    @skipIfNoTorchVision
    def test_resnext50_32x4d(self):
        model = torchvision.models.resnet.resnext50_32x4d(weights=None)
        self._test_imagenet_model(model)

    def _lstm_params_list(self):
        params_dict = {
            "input_size": [1, 5],
            "hidden_size": [5, 16],
            "num_layers": [1, 3],
            "bidirectional": [False, True],
            "bias": [False, True],
            "batch_first": [False, True],
            "dropout": [0, 0.4, 0.7, 1],
            "batch_size": [1, 2],
            "seq_len": [1, 3],
            "training": [False, True]
        }
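        # test_lstm expands this grid with itertools.product; note that
        # nonzero dropout only takes effect between stacked layers, i.e.
        # when num_layers > 1 (PyTorch warns otherwise).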

        params_list = list(params_dict.values())
        return params_list

    def _cast_dtype(self, input, dtype):
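        # Only the low-precision dtypes trigger a cast; float32 inputs pass
        # through unchanged.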
        if dtype == torch.bfloat16:
            input = input.to(torch.bfloat16)
        elif dtype == torch.half:
            input = input.to(torch.half)
        return input

    def test_lstm(self):
        seed = 2023
        torch.manual_seed(seed)

        params_list = self._lstm_params_list()
        for dtype in types:
            bf16 = dtype == torch.bfloat16
            fp16 = dtype == torch.half
            rtol = 1.3e-6
            atol = 1e-5

            if bf16:
                rtol = 0.02
                atol = 0.02
            if fp16:
                rtol = 1e-3
                atol = 1e-3
            for input_size, hidden_size, num_layers, bidirectional, bias, batch_first, dropout, batch_size, seq_len, training \
                    in itertools.product(*params_list):
                num_directions = 2 if bidirectional else 1
                if batch_first:
                    input = torch.randn(batch_size, seq_len, input_size, dtype=torch.float32)
                else:
                    input = torch.randn(seq_len, batch_size, input_size, dtype=torch.float32)
                h = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32)
                c = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32)
                if fp16:
                    # TODO: add training support once oneDNN supports LSTM FP16 training
                    training = False
                model = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional,
                                      bias=bias, dropout=dropout, batch_first=batch_first).float()
                model.train() if training else model.eval()
                input1 = input.clone().requires_grad_(training)
                input2 = input.clone().requires_grad_(training)

                h1 = h.clone().requires_grad_(training)
                h2 = h.clone().requires_grad_(training)
                c1 = c.clone().requires_grad_(training)
                c2 = c.clone().requires_grad_(training)

                model1 = copy.deepcopy(model)
                model2 = copy.deepcopy(model)
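                # Reseed before each forward and backward pass so both paths
                # draw identical dropout masks.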
                with torch.no_grad() if not training else nullcontext():
                    with torch.backends.mkldnn.flags(enabled=False):
                        torch.manual_seed(seed)
                        output1, (hn1, cn1) = self._cast_dtype(model1, dtype)(
                            self._cast_dtype(input1, dtype),
                            (
                                self._cast_dtype(h1, dtype),
                                self._cast_dtype(c1, dtype),
                            ),
                        )

                    torch.manual_seed(seed)
                    output2, (hn2, cn2) = self._cast_dtype(model2, dtype)(
                        self._cast_dtype(input2, dtype),
                        (
                            self._cast_dtype(h2, dtype),
                            self._cast_dtype(c2, dtype),
                        ),
                    )
                    self.assertEqual(output1, output2, rtol=rtol, atol=atol)
                    self.assertEqual(hn1, hn2, rtol=rtol, atol=atol)
                    self.assertEqual(cn1, cn2, rtol=rtol, atol=atol)

                    if training:
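                        # output, hn, and cn are backpropagated separately, so
                        # retain_graph=True keeps the autograd graph alive
                        # across the three backward calls.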
                        with torch.backends.mkldnn.flags(enabled=False):
                            torch.manual_seed(seed)
                            output1.sum().backward(retain_graph=True)

                        torch.manual_seed(seed)
                        output2.sum().backward(retain_graph=True)

                        self.assertEqual(input1.grad, input2.grad, rtol=rtol, atol=atol)
                        for name, para in model1.named_parameters():
                            self.assertEqual(para, getattr(model2, name))
                            self.assertEqual(
                                para.grad,
                                getattr(model2, name).grad,
                                rtol=rtol,
                                atol=atol,
                            )

                        with torch.backends.mkldnn.flags(enabled=False):
                            torch.manual_seed(seed)
                            hn1.sum().backward(retain_graph=True)
                        torch.manual_seed(seed)
                        hn2.sum().backward(retain_graph=True)
                        self.assertEqual(h1.grad, h2.grad, rtol=rtol, atol=atol)

                        with torch.backends.mkldnn.flags(enabled=False):
                            torch.manual_seed(seed)
                            cn1.sum().backward(retain_graph=True)
                        torch.manual_seed(seed)
                        cn2.sum().backward(retain_graph=True)
                        self.assertEqual(c1.grad, c2.grad, rtol=rtol, atol=atol)

    @dtypes(torch.float16, torch.bfloat16)
    def test_matmul_lower_precision(self, dtype):
        support_check = {
            torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
            torch.float16: torch.ops.mkldnn._is_mkldnn_fp16_supported,
        }

        def common(self, shape1, shape2, op, dtype):
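            # Compare the low-precision op against a float32 reference;
            # exact_dtype=False because y stays in the low-precision dtype.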
            a = torch.randn(shape1, dtype=dtype)
            a_ref = a.float()
            b = torch.randn(shape2, dtype=dtype)
            b_ref = b.float()

            y = op(a, b)
            y_ref = op(a_ref, b_ref)
            self.assertEqual(y, y_ref, exact_dtype=False)

        if support_check[dtype]():
            a1 = torch.randn([64, 1, 33], dtype=dtype)
            # a2 is a contiguous tensor, but its strides are not the default
            # contiguous strides (the size-1 dim can carry an arbitrary stride
            # without breaking contiguity).
            a2 = torch.as_strided(a1.clone(), [64, 1, 33], [33, 3, 1])
            self.assertTrue(a2.is_contiguous())
            b = torch.randn(64, 33, 256).to(dtype=dtype)
            y1 = torch.ops.aten.bmm(a1, b)
            y2 = torch.bmm(a2, b)
            self.assertEqual(y1, y2)

            for shape1, shape2, op in [
                ((33, 77), (77, 22), torch.matmul),
                ((128, 256), (256, 10), torch.matmul),
                ((7, 300), (300, 3), torch.matmul),
                ((1, 100), (100, 60), torch.matmul),
                ((100, 1), (1, 100), torch.matmul),
                ((20, 54, 78), (20, 78, 10), torch.bmm),
                ((1, 300, 1), (1, 1, 300), torch.bmm),
            ]:
                common(self, shape1, shape2, op, dtype)


instantiate_device_type_tests(TestMkldnn, globals(), only_for=('cpu',))

if __name__ == '__main__':
    run_tests()