1# Owner(s): ["module: mkldnn"]
2
3import copy
4import itertools
5import functools
6import unittest
7from contextlib import nullcontext
8
9try:
10    import torchvision
11    HAS_TORCHVISION = True
12except ImportError:
13    HAS_TORCHVISION = False
14
15skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
16
17import torch
18import torch.nn.functional as F
19import torch.jit
20import torch.backends.mkldnn
21from torch.utils import mkldnn as mkldnn_utils
22from torch.testing._internal.common_utils import TestCase, \
23    run_tests, TemporaryFileName, gradcheck, gradgradcheck, IS_WINDOWS, \
24    skipIfTorchDynamo, xfailIfTorchDynamo
25from torch.testing._internal.common_device_type import (
26    instantiate_device_type_tests,
27    dtypes,
28)
29
30# batched grad doesn't support mkldnn
31gradcheck = functools.partial(gradcheck, check_batched_grad=False)
32gradgradcheck = functools.partial(gradgradcheck, check_batched_grad=False)
33
34
35types = [torch.float, torch.bfloat16, torch.half]
36
37# Comment out the line below to find out which CI machines have the MKL-DNN build disabled
38@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled")
39class TestMkldnn(TestCase):
40    def test_conversion(self):
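        """Round-trip fp32/bf16/fp16 cpu tensors through to_mkldnn()/to_dense(),
        covering cross-dtype conversions and checking dtype, device, size, numel,
        element_size, and that data_ptr() raises for opaque mkldnn tensors."""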
41        for cpu_tensor in [torch.randn((1, 2, 3, 4),
42                                       dtype=torch.float, device=torch.device('cpu')),
43                           torch.randn((1, 2, 3, 4, 5),
44                                       dtype=torch.float, device=torch.device('cpu'))[:, :, :, :, 1]]:
45            cpu_tensor.requires_grad_()
46            convert_dtypes = {torch.half: [torch.half, torch.float],
47                              torch.bfloat16: [torch.bfloat16, torch.float],
48                              torch.float: [torch.bfloat16, torch.half]}
49            # float/bfloat16/half cpu tensor to mkldnn tensor.
50            for dtype1 in types:
51                mkldnn_tensor = cpu_tensor.to_mkldnn(dtype1)
52                self.assertEqual(mkldnn_tensor.dtype, dtype1)
53                cpu_tensor_1 = mkldnn_tensor.to_dense()
54                # when no dtype is given to to_dense, the returned cpu tensor has the same dtype as the mkldnn tensor
55                self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype)
56                # mkldnn float/bfloat16/half tensor to cpu float/bfloat16/half tensor
57                for dtype2 in convert_dtypes[dtype1]:
58                    cpu_tensor_2 = mkldnn_tensor.to_dense(dtype2)
59                    self.assertEqual(cpu_tensor_2.dtype, dtype2)
60                    atol = 1e-5 if dtype1 == torch.float and dtype2 == torch.float else 1e-2
61                    self.assertEqual(cpu_tensor, cpu_tensor_2.float(), atol=atol, rtol=0)
62
63                self.assertEqual(mkldnn_tensor.device, torch.device('cpu'))
64                self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4]))
65                self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
66                if dtype1 == torch.float:
67                    self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size())
68                else:
69                    self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size() / 2)
70                self.assertRaisesRegex(RuntimeError,
71                                       "Cannot access data pointer of Tensor that doesn't have storage",
72                                       lambda: mkldnn_tensor.data_ptr() != 0)
73
74            # half/bfloat16 cpu tensor to mkldnn half/bfloat16/float tensor.
75            for orig_dtype in [torch.half, torch.bfloat16]:
76                cpu_tensor_lower = cpu_tensor.to(dtype=orig_dtype)
77                for dtype1 in convert_dtypes[orig_dtype]:
78                    mkldnn_tensor = cpu_tensor_lower.to_mkldnn(dtype1)
79                    self.assertEqual(mkldnn_tensor.dtype, dtype1)
80                    cpu_tensor_1 = mkldnn_tensor.to_dense()
81                    # when no dtype is given to to_dense, the returned cpu tensor has the same dtype as the mkldnn tensor
82                    self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype)
83                    # mkldnn float/bfloat16/half tensor to cpu float/bfloat16/half tensor
84                    for dtype2 in convert_dtypes[cpu_tensor_lower.dtype]:
85                        cpu_tensor_2 = mkldnn_tensor.to_dense(dtype2)
86                        self.assertEqual(cpu_tensor_2.dtype, dtype2)
87                        self.assertEqual(cpu_tensor_lower,
88                                         cpu_tensor_2.to(dtype=cpu_tensor_lower.dtype), atol=1e-5, rtol=0)
89
90                    self.assertEqual(mkldnn_tensor.device, torch.device('cpu'))
91                    self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4]))
92                    self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
93                    if dtype1 in [torch.bfloat16, torch.half]:
94                        self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor_lower.element_size())
95                    else:
96                        self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor_lower.element_size() * 2)
97                    self.assertRaisesRegex(RuntimeError,
98                                           "Cannot access data pointer of Tensor that doesn't have storage",
99                                           lambda: mkldnn_tensor.data_ptr() != 0)
100
101    def test_conversion_byte_char(self):
102        int8_types = [torch.int8, torch.uint8]
103        for int8_type in int8_types:
104            low = -100 if int8_type is torch.int8 else 0
105            high = 100
106            for cpu_tensor in [torch.randint(
107                               low=low,
108                               high=high,
109                               size=(1, 2, 3, 4),
110                               dtype=torch.int64,
111                               device=torch.device('cpu')),
112                               torch.randint(
113                               low=low,
114                               high=high,
115                               size=(1, 2, 3, 4, 5),
116                               dtype=torch.int64,
117                               device=torch.device('cpu'))[:, :, :, :, :]]:
118
119                cpu_tensor = cpu_tensor.to(dtype=int8_type)
120                mkldnn_tensor = cpu_tensor.to_mkldnn(int8_type)
121                self.assertEqual(mkldnn_tensor.dtype, int8_type)
122                cpu_tensor_1 = mkldnn_tensor.to_dense()
123                self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype)
124                self.assertEqual(cpu_tensor, cpu_tensor_1)
125                self.assertEqual(mkldnn_tensor.device, torch.device('cpu'))
126                self.assertEqual(mkldnn_tensor.size(), cpu_tensor.size())
127                self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
128                self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size())
129                self.assertRaisesRegex(RuntimeError,
130                                       "Cannot access data pointer of Tensor that doesn't have storage",
131                                       lambda: mkldnn_tensor.data_ptr() != 0)
132
133    def test_copy(self):
134        x = torch.randn(4, 5, dtype=torch.float32)
135        mkldnn_x = x.to_mkldnn()
136        mkldnn_y = torch.randn(4, 5, dtype=torch.float32).to_mkldnn()
137        mkldnn_z = torch.randn(4, 10, dtype=torch.float32).to_mkldnn()
138        mkldnn_y.copy_(mkldnn_x)
139        self.assertEqual(x, mkldnn_y.to_dense())
140        self.assertRaisesRegex(RuntimeError,
141                               "copy_mkldnn_: only support same size tensor.",
142                               lambda: mkldnn_z.copy_(mkldnn_x))
143        self.assertRaisesRegex(RuntimeError,
144                               "copy_mkldnn_: between mkldnn layout and dense Tensors is not implemented! "
145                               "Found self type = torch.FloatTensor and src type = Mkldnntorch.FloatTensor",
146                               lambda: x.copy_(mkldnn_x))
147        self.assertRaisesRegex(RuntimeError,
148                               "copy_mkldnn_: between mkldnn layout and dense Tensors is not implemented! "
149                               "Found self type = Mkldnntorch.FloatTensor and src type = torch.FloatTensor",
150                               lambda: mkldnn_x.copy_(x))
151
152    def test_unsupported(self):
153        # unsupported dtypes on cpu, and the same unsupported dtypes on gpu
154        for dtype in [torch.double, torch.uint8, torch.int8,
155                      torch.short, torch.int, torch.long]:
156            with self.assertRaises(RuntimeError) as context:
157                torch.randn(1, 2, 3, 4, dtype=dtype, device=torch.device('cpu')).to_mkldnn()
158            if torch.cuda.is_available():
159                with self.assertRaises(RuntimeError) as context:
160                    torch.randn(1, 2, 3, 4, dtype=dtype, device=torch.device('cuda')).to_mkldnn()
161        # supported type with gpu
162        if torch.cuda.is_available():
163            with self.assertRaises(RuntimeError) as context:
164                torch.randn(1, 2, 3, 4, dtype=torch.float, device=torch.device('cuda')).to_mkldnn()
165        # some factory functions
166        for creator in [torch.ones, torch.randn, torch.rand]:
167            with self.assertRaises(RuntimeError) as context:
168                creator(1, 2, 3, 4, dtype=torch.float, device=torch.device('cpu'), layout=torch._mkldnn)
169
170    def test_mkldnn_conv_shapecheck(self):
171        input = torch.full((1, 1, 1, 24,), 1, dtype=torch.float32)
172        w1 = torch.full((1, 1, 1, 24,), 1, dtype=torch.float32)
173        b1 = torch.full((1,), 1, dtype=torch.float32)
174        w2 = torch.full((1, 1, 2, 24,), 1, dtype=torch.float32)
175        b2 = torch.full((2,), 1, dtype=torch.float32)
176        options = zip([-1, 0, 0, 0, 0, 0, 0],  # padding
177                      [1, 0, 1, 1, 1, 1, 1],  # stride
178                      [1, 1, 0, 1, 1, 1, 1],  # dilation
179                      [1, 1, 1, 0, 2, 1, 1],  # groups
180                      [w1, w1, w1, w1, w1, w1, w2],  # weight
181                      [b1, b1, b1, b1, b1, b2, b1])  # bias
182        for pad, st, dil, gr, w, b in options:
183            with self.assertRaises(RuntimeError) as _:
184                torch.mkldnn_convolution(input, w, b, [pad] * 2, [st] * 2, [dil] * 2, gr)
185
186    def test_autograd_to_mkldnn(self):
187        # MKLDNN only supports float32
188        root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True)
189
190        def func(root):
191            return root.to_mkldnn().to_dense()
192
193        # because MKLDNN only supports float32, we need to lessen the precision.
194        # these numbers are just empirical results that seem to work.
195        self.assertWarnsRegex(UserWarning,
196                              'double precision floating point',
197                              lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2))
198        self.assertWarnsRegex(UserWarning,
199                              'double precision floating point',
200                              lambda: gradgradcheck(func, [root], atol=4e-2, rtol=1e-2))
201
202    def test_autograd_from_mkldnn(self):
203        # MKLDNN only supports float32
204        root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()
205
206        def func(root):
207            return root.to_dense()
208
209        # because MKLDNN only supports float32, we need to lessen the precision.
210        # these numbers are just empirical results that seem to work.
211        self.assertWarnsRegex(UserWarning,
212                              'double precision floating point',
213                              lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2))
214
215    def test_detach(self):
216        root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()
217
218        detach = root.detach()
219        self.assertEqual((4, 5), detach.size())
220        self.assertFalse(detach.requires_grad)
221        self.assertTrue(root.requires_grad)
222
223        detach_ = root.detach_()
224        self.assertEqual((4, 5), detach_.size())
225        self.assertFalse(detach_.requires_grad)
226        self.assertFalse(root.requires_grad)
227
228    def test_repr(self):
229        self.assertTrue("layout=torch._mkldnn" in str(torch.randn((1, 2, 3, 4),
230                                                                  dtype=torch.float, device=torch.device('cpu')).to_mkldnn()))
231
232    def _test_conv_base(self, dim):
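        """Compare aten ConvNd (mkldnn disabled) against the mkldnn path over
        train/bias/dilation/groups combinations; inference mode also exercises
        serialization and tracing, training mode (dim > 1) also compares gradients."""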
233        conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
234        input_shapes = {1: (224,), 2: (224, 224), 3: (55, 55, 55)}
235        options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
236        for train, bias, dilation, groups in options:
237            N = torch.randint(3, 10, (1,)).item()
238            M = torch.randint(1, 3, (1,)).item() * groups
239            C = torch.randint(1, 3, (1,)).item() * groups
240            x_shape = (N, C) + input_shapes[dim]
241            x = torch.randn(x_shape, dtype=torch.float32)
242            conv = conv_module[dim](in_channels=C,
243                                    out_channels=M,
244                                    kernel_size=3,
245                                    stride=2,
246                                    padding=1,
247                                    dilation=dilation,
248                                    bias=bias,
249                                    groups=groups).float()
250            x1 = x.clone()
251            x2 = x.clone().to_mkldnn()
252            if not train:
253                mkldnn_conv = mkldnn_utils.to_mkldnn(copy.deepcopy(conv))
254            elif train and dim != 1:
255                # TODO: enable conv1d training.
256                x1.requires_grad_()
257                x2.requires_grad_()
258                mkldnn_conv = copy.deepcopy(conv)
259            with torch.backends.mkldnn.flags(enabled=False):
260                y_aten = conv(x1)
261                if train and dim != 1:
262                    loss1 = y_aten.sum()
263                    loss1.backward()
264            if not train or dim != 1:
265                y_mkldnn = mkldnn_conv(x2).to_dense()
266                self.assertEqual(y_aten, y_mkldnn)
267            if not train:
268                self._test_serialization(mkldnn_conv, (x.to_mkldnn(),))
269                self._test_tracing(mkldnn_conv, (x.to_mkldnn(),))
270            elif dim != 1:
271                loss2 = y_mkldnn.sum()
272                loss2.backward()
273                self.assertTrue(x2.grad.is_mkldnn)
274                self.assertEqual(x1.grad, x2.grad.to_dense())
275                self.assertEqual(conv.weight.grad,
276                                 mkldnn_conv.weight.grad,
277                                 atol=1e-3,
278                                 rtol=1e-3)
279                if bias:
280                    self.assertEqual(conv.bias.grad, mkldnn_conv.bias.grad)
281
282    def test_conv1d(self):
283        self._test_conv_base(dim=1)
284
285    def test_conv2d(self):
286        self._test_conv_base(dim=2)
287
288    def test_conv3d(self):
289        self._test_conv_base(dim=3)
290
291    def _test_conv_deconv_lower_precision_base(self, dim, conv_module, dtype):
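        """Run conv/deconv with bf16/fp16 weights and inputs through mkldnn and compare
        against the fp32 mkldnn result when the CPU supports that dtype, otherwise expect
        a RuntimeError; also compare the thnn lower-precision path against an fp32 reference."""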
292        input_shapes = {1: (224,), 2: (224, 224), 3: (55, 55, 55)}
293        options = itertools.product([True, False], [1, 2], [1, 4])
294        for bias, dilation, groups in options:
295            N = torch.randint(1, 3, (1,)).item()
296            M = torch.randint(1, 3, (1,)).item() * groups
297            C = torch.randint(1, 3, (1,)).item() * groups
298            x_shape = (N, C) + input_shapes[dim]
299            x = torch.randn(x_shape, dtype=torch.float32)
300            # TODO: remove this when group depthwise is supported:
301            if conv_module in [torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
302                               torch.nn.ConvTranspose3d] and groups > 1 and C == groups:
303                continue
304            conv = conv_module(in_channels=C,
305                               out_channels=M,
306                               kernel_size=3,
307                               stride=2,
308                               padding=1,
309                               dilation=dilation,
310                               bias=bias,
311                               groups=groups).float()
312            x_lower = x.to(dtype=dtype)
313            if (dtype == torch.bfloat16 and torch.ops.mkldnn._is_mkldnn_bf16_supported()) or \
314               (dtype == torch.half and torch.ops.mkldnn._is_mkldnn_fp16_supported()):
315                mkldnn_conv = mkldnn_utils.to_mkldnn(copy.deepcopy(conv))
316                mkldnn_conv_lower = mkldnn_utils.to_mkldnn(copy.deepcopy(conv), dtype)
317                y = mkldnn_conv(x.to_mkldnn()).to_dense()
318                y_lower = mkldnn_conv_lower(x_lower.to_mkldnn()).to_dense(torch.float32)
319                self.assertEqual(y, y_lower, atol=1e-1, rtol=1e-3)
320            else:
321                msg = {
322                    torch.bfloat16: r"bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq",
323                    torch.half: r"fp16 path needs the cpu support avx_ne_convert or avx512_fp16",
324                }
325                with self.assertRaisesRegex(RuntimeError, msg[dtype]):
326                    mkldnn_conv_lower = mkldnn_utils.to_mkldnn(copy.deepcopy(conv), dtype)
327                    y_lower = mkldnn_conv_lower(x_lower.to_mkldnn()).to_dense(torch.float32)
328            # test thnn impl
329            conv_lower = copy.deepcopy(conv).to(dtype=dtype)
330            conv_ref = copy.deepcopy(conv_lower).float()
331            with torch.backends.mkldnn.flags(enabled=False):
332                x_ref = x_lower.clone().float().detach().requires_grad_()
333                x_lower.requires_grad_()
334                y = conv_ref(x_ref)
335                y_lower = conv_lower(x_lower).float()
336                self.assertEqual(y, y_lower, atol=5e-2, rtol=5e-3)
337
338    @dtypes(torch.float16, torch.bfloat16)
339    def test_conv_deconv_1d_lower_precision(self, dtype):
340        self._test_conv_deconv_lower_precision_base(1, torch.nn.Conv1d, dtype=dtype)
341        self._test_conv_deconv_lower_precision_base(1, torch.nn.ConvTranspose1d, dtype=dtype)
342
343    @dtypes(torch.float16, torch.bfloat16)
344    def test_conv_deconv_2d_lower_precision(self, dtype):
345        self._test_conv_deconv_lower_precision_base(2, torch.nn.Conv2d, dtype=dtype)
346        self._test_conv_deconv_lower_precision_base(2, torch.nn.ConvTranspose2d, dtype=dtype)
347
348    @dtypes(torch.float16, torch.bfloat16)
349    def test_conv_deconv_3d_lower_precision(self, dtype):
350        self._test_conv_deconv_lower_precision_base(3, torch.nn.Conv3d, dtype=dtype)
351        self._test_conv_deconv_lower_precision_base(3, torch.nn.ConvTranspose3d, dtype=dtype)
352
353    def _test_conv_deconv_nhwc_base(self, conv_module, weight_memory_format, dtype, prec=None):
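        """Compare conv/deconv on contiguous (NCHW) inputs against channels-last inputs
        with the given weight memory format, checking forward results and, in training
        mode, input/weight/bias gradients."""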
354        input_shapes = {2: (55, 55), 3: (14, 14, 14)}
355        options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
356        if conv_module in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
357            cl_format = torch.channels_last
358            input_shape = input_shapes[2]
359        elif conv_module in [torch.nn.Conv3d, torch.nn.ConvTranspose3d]:
360            cl_format = torch.channels_last_3d
361            input_shape = input_shapes[3]
362
363        for train, bias, dilation, groups in options:
364            N = torch.randint(3, 10, (1,)).item()
365            M = torch.randint(1, 3, (1,)).item() * groups
366            C = torch.randint(1, 3, (1,)).item() * groups
367            x_shape = (N, C) + input_shape
368            x = torch.randn(x_shape, dtype=dtype)
369
370            # conv1: mkldnn conv/deconv in contiguous memory format (nchw)
371            # conv2: mkldnn conv/deconv in channels last memory format (nhwc)
372            conv1 = conv_module(in_channels=C,
373                                out_channels=M,
374                                kernel_size=3,
375                                stride=2,
376                                padding=1,
377                                dilation=dilation,
378                                bias=bias,
379                                groups=groups).to(dtype=dtype)
380            conv2 = copy.deepcopy(conv1).to(memory_format=weight_memory_format)
381            x1 = x.clone()
382            x2 = x.clone().to(memory_format=cl_format)
383            if train:
384                x1.requires_grad_()
385                x2.requires_grad_()
386            y1 = conv1(x1)
387            y2 = conv2(x2)
388            self.assertEqual(y1, y2, atol=prec, rtol=prec)
389
390            if train:
391                y1.sum().backward()
392                y2.sum().backward()
393                self.assertTrue(x2.grad.is_contiguous(memory_format=cl_format))
394                self.assertEqual(conv1.weight.grad,
395                                 conv2.weight.grad,
396                                 atol=1e-3,
397                                 rtol=1e-3)
398                if bias:
399                    self.assertEqual(conv1.bias.grad, conv2.bias.grad, atol=prec, rtol=prec)
400                self.assertEqual(x1.grad, x2.grad, atol=prec, rtol=prec)
401
402    def test_conv_nhwc_fp32(self):
403        self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.float32)
404        self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.float32)
405        self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=torch.float32)
406        self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=torch.float32)
407
408    @dtypes(torch.float16, torch.bfloat16)
409    def test_conv_nhwc_lower_precision(self, dtype):
410        # when torch.ops.mkldnn._is_mkldnn_bf16_supported() or torch.ops.mkldnn._is_mkldnn_fp16_supported()
411        # returns false, bf16/fp16 CPU conv will fall back to thnn impl
412        support_checks = {
413            torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
414            torch.float16: torch.ops.mkldnn._is_mkldnn_fp16_supported
415        }
416        if support_checks[dtype]():
417            self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=dtype)
418            self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=dtype)
419            self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=dtype)
420            self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=dtype)
421
422        # The BF16/FP16 fallback implementation is split into two parts (im2col + gemm)
423        # and performs more intermediate data type conversions than oneDNN's direct conv,
424        # resulting in additional accuracy loss.
425        precisions = {
426            torch.bfloat16: 1e-2,
427            torch.float16: 2e-3,
428        }
429        prec = precisions[dtype]
430        with torch.backends.mkldnn.flags(enabled=False):
431            self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=dtype, prec=prec)
432            self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=dtype, prec=prec)
433            self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=dtype, prec=prec)
434            self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=dtype, prec=prec)
435
436
437    def test_conv_transpose_nhwc_fp32(self):
438        self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.float32)
439        self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.float32)
440        self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=torch.float32)
441        self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=torch.float32)
442
443    @dtypes(torch.float16, torch.bfloat16)
444    def test_conv_transpose_nhwc_lower_precision(self, dtype):
445        # when torch.ops.mkldnn._is_mkldnn_bf16_supported() or torch.ops.mkldnn._is_mkldnn_fp16_supported()
446        # returns false, bf16/fp16 CPU conv will fall back to thnn impl
447        support_checks = {
448            torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
449            torch.float16: torch.ops.mkldnn._is_mkldnn_fp16_supported
450        }
451        if support_checks[dtype]():
452            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=dtype)
453            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=dtype)
454            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=dtype)
455            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=dtype)
456
457        # The BF16/FP16 fallback implementation is split into two parts (col2im + gemm)
458        # and performs more intermediate data type conversions than oneDNN's direct conv,
459        # resulting in additional accuracy loss.
460        precisions = {
461            torch.bfloat16: 2e-2,
462            torch.float16: 3e-3,
463        }
464        prec = precisions[dtype]
465        with torch.backends.mkldnn.flags(enabled=False):
466            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=dtype, prec=prec)
467            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=dtype, prec=prec)
468            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=dtype, prec=prec)
469            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=dtype, prec=prec)
470
471    def _test_conv_transpose_base(self, dim):
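        """Compare mkldnn ConvTransposeNd against the thnn reference (mkldnn disabled),
        checking forward results and, in training mode, gradients."""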
472        conv_module = {
473            1: torch.nn.ConvTranspose1d,
474            2: torch.nn.ConvTranspose2d,
475            3: torch.nn.ConvTranspose3d
476        }
477        input_shapes = {1: (55,), 2: (28, 28), 3: (14, 14, 14)}
478        options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
479        for train, bias, dilation, groups in options:
480            N = torch.randint(3, 10, (1,)).item()
481            M = torch.randint(1, 3, (1,)).item() * groups
482            C = torch.randint(1, 3, (1,)).item() * groups
483            x_shape = (N, C) + input_shapes[dim]
484            data = torch.randn(x_shape, dtype=torch.float32)
485            # conv: mkldnn transpose conv fp32
486            # conv_ref: thnn transpose conv fp32
487            conv = conv_module[dim](in_channels=C,
488                                    out_channels=M,
489                                    kernel_size=3,
490                                    stride=1,
491                                    padding=1,
492                                    dilation=dilation,
493                                    bias=bias,
494                                    groups=groups).to(dtype=torch.float32)
495            x = data.clone()
496            x_ref = x.clone()
497            if train:
498                x.requires_grad_()
499                x_ref.requires_grad_()
500
501            conv_ref = copy.deepcopy(conv)
502            with torch.backends.mkldnn.flags(enabled=False):
503                y_ref = conv_ref(x_ref)
504                if train:
505                    y_ref.sum().backward()
506
507            y = conv(x)
508            if train:
509                y.sum().backward()
510
511            self.assertEqual(y, y_ref)
512            if train:
513                self.assertEqual(x.grad, x_ref.grad)
514                self.assertEqual(conv.weight.grad,
515                                 conv_ref.weight.grad,
516                                 atol=1e-3,
517                                 rtol=1e-3)
518                if bias:
519                    self.assertEqual(conv.bias.grad, conv_ref.bias.grad)
520
521    def test_conv_transpose1d(self):
522        self._test_conv_transpose_base(dim=1)
523
524    def test_conv_transpose2d(self):
525        self._test_conv_transpose_base(dim=2)
526
527    def test_conv_transpose3d(self):
528        self._test_conv_transpose_base(dim=3)
529
530    def test_conv2d_legacy_jit_model(self):
531        """
532        MKLDNN integration used to serialize models with a 5-d weight for grouped
533        convolutions; we'd like to preserve this behavior.
534        """
535        g = 4
536        conv2d = torch.nn.Conv2d(16, 16, 3, groups=g)
537        conv2d_mkldnn = torch.utils.mkldnn.to_mkldnn(conv2d)
538
539        # contrive legacy conv2d module with a 5-d weight
540        o, i, h, w = conv2d.weight.shape
541        weight_5d = conv2d.weight.reshape((g, o // g, i, h, w))
542        conv2d_mkldnn.weight = weight_5d.to_mkldnn()
543
544        x = torch.randn(1, 16, 8, 8)
545
546        with TemporaryFileName() as fname:
547            torch.jit.save(conv2d_mkldnn, fname)
548            conv2d_loaded = torch.jit.load(fname)
549
550            self.assertEqual(conv2d_mkldnn.weight.ndimension(), 5)
551            self.assertEqual(conv2d_loaded.weight.ndimension(), 4)
552            self.assertEqual(
553                conv2d(x),
554                conv2d_loaded(x.to_mkldnn()).to_dense())
555
556    # This test checks that 1D conv is supported for mkldnn tensors,
557    # an issue exposed by https://github.com/pytorch/pytorch/issues/68034.
558    def test_conv1d_functional(self):
559        input = torch.randn(2, 3, 10).to_mkldnn()
560        weight = torch.randn(3, 3, 3).to_mkldnn()
561        bias = torch.randn(3).to_mkldnn()
562        output = torch.nn.functional.conv1d(input, weight, bias)
563        self.assertEqual(output.size(), torch.Size([2, 3, 8]))
564
565    def test_relu(self):
566        x = torch.randn((4, 5), dtype=torch.float32) * 10
567        x1 = x.clone().requires_grad_()
568        x2 = x.clone().to_mkldnn().requires_grad_()
569        y1 = torch.relu(x1)
570        y2 = torch.relu(x2).to_dense()
571        loss1 = y1.sum()
572        loss2 = y2.sum()
573        loss1.backward()
574        loss2.backward()
575        self.assertEqual(y1, y2)
576        self.assertEqual(x1.grad, x2.grad.to_dense())
577
578    def test_relu_(self):
579        x = torch.randn((4, 5), dtype=torch.float32) * 10
580        x1 = x.clone().requires_grad_()
581        x2 = x.clone().to_mkldnn().requires_grad_()
582        y1 = torch.relu_(x1.clone())
583        y2 = torch.relu_(x2.clone()).to_dense()
584        loss1 = y1.sum()
585        loss2 = y2.sum()
586        loss1.backward()
587        loss2.backward()
588        self.assertEqual(y1, y2)
589        self.assertEqual(x1.grad, x2.grad.to_dense())
590
591    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
592    def _test_relu_bf16_base(self, name):
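        """Compare fp32 and bf16 mkldnn relu/relu_ results when bf16 is supported,
        otherwise expect the bf16 path to raise a RuntimeError."""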
593        x = torch.randn((4, 5), dtype=torch.float32) * 10
594        x_bf16 = x.bfloat16()
595        fn = getattr(torch, name)
596        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
597            y = fn(x.to_mkldnn()).to_dense()
598            y_bf16 = fn(x_bf16.to_mkldnn()).to_dense(torch.float32)
599            self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
600        else:
601            msg = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
602            self.assertRaisesRegex(RuntimeError,
603                                   msg,
604                                   lambda: fn(x_bf16.to_mkldnn()))
605
606    def test_relu_bf16(self):
607        self._test_relu_bf16_base("relu")
608
609    def test_relu_inplace_bf16(self):
610        self._test_relu_bf16_base("relu_")
611
612    def test_gelu(self):
613        m = torch.nn.GELU()
614        x = torch.randn((4, 5), dtype=torch.float32) * 10
615        x1 = x.clone().requires_grad_()
616        x2 = x.clone().to_mkldnn().requires_grad_()
617        y1 = m(x1)
618        y2 = m(x2).to_dense()
619        loss1 = y1.sum()
620        loss2 = y2.sum()
621        loss1.backward()
622        loss2.backward()
623        self.assertEqual(y1, y2)
624        self.assertEqual(x1.grad, x2.grad.to_dense())
625
626    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
627    def test_gelu_bf16(self):
628        m = torch.nn.GELU()
629        x = torch.randn((4, 5), dtype=torch.float32) * 10
630        x1 = x.clone().to_mkldnn().requires_grad_()
631        x2 = x.clone().to_mkldnn(torch.bfloat16).requires_grad_()
632        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
633            y1 = m(x1).to_dense()
634            y2 = m(x2).to_dense()
635            loss1 = y1.sum()
636            loss2 = y2.sum()
637            loss1.backward()
638            loss2.backward()
639            self.assertEqual(y1, y2.to(torch.float32), atol=1e-1, rtol=0)
640            self.assertEqual(x1.grad.to_dense(), x2.grad.to_dense(torch.float32), atol=1e-2, rtol=0)
641        else:
642            msg = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
643            self.assertRaisesRegex(RuntimeError,
644                                   msg,
645                                   lambda: m(x2))
646
647    def _test_prelu_base(self, size, num_channels):
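        """Compare aten PReLU against both a fully converted mkldnn module and an
        aten module fed mkldnn input, checking outputs and input gradients."""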
648        x = torch.randn(size, dtype=torch.float32)
649        x1 = x.clone().requires_grad_()
650        x2 = x.clone().to_mkldnn().requires_grad_()
651        x3 = x.clone().to_mkldnn().requires_grad_()
652        m1 = torch.nn.PReLU(num_channels)
653        m2 = mkldnn_utils.to_mkldnn(copy.deepcopy(m1))
654        m3 = copy.deepcopy(m1)
655        y1 = m1(x1)
656        y2 = m2(x2).to_dense()
657        y3 = m3(x3).to_dense()  # Only the input is converted to mkldnn; the weight stays an ATen tensor
658        loss1 = y1.sum()
659        loss1.backward()
660        loss2 = y2.sum()
661        loss2.backward()
662        loss3 = y3.sum()
663        loss3.backward()
664        self.assertEqual(y1, y2)
665        self.assertEqual(y1, y3)
666        self.assertEqual(x1.grad, x2.grad.to_dense())
667        self.assertEqual(x1.grad, x3.grad.to_dense())
668
669    def test_prelu(self):
670        self._test_prelu_base(torch.Size([16]), 1)
671        self._test_prelu_base(torch.Size([16, 64]), 1)
672        self._test_prelu_base(torch.Size([16, 64]), 64)
673        self._test_prelu_base(torch.Size([16, 64, 112]), 1)
674        self._test_prelu_base(torch.Size([16, 64, 112]), 64)
675        self._test_prelu_base(torch.Size([16, 64, 112, 112]), 1)
676        self._test_prelu_base(torch.Size([16, 64, 112, 112]), 64)
677        self._test_prelu_base(torch.Size([16, 64, 112, 112, 1]), 1)
678        self._test_prelu_base(torch.Size([16, 64, 112, 112, 1]), 64)
679
680    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
681    def _test_prelu_bf16_base(self, size, num_channels):
682        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
683            x = torch.randn(size, dtype=torch.float32)
684            x_fp32 = x.clone().to_mkldnn().requires_grad_()
685            x_bf16 = x.clone().to_mkldnn(torch.bfloat16).requires_grad_()
686            m = mkldnn_utils.to_mkldnn(torch.nn.PReLU())
687            m_bf16 = mkldnn_utils.to_mkldnn(torch.nn.PReLU(), torch.bfloat16)
688
689            y = m(x_fp32).to_dense()
690            y_bf16 = m_bf16(x_bf16).to_dense()
691            self.assertEqual(y, y_bf16.to(torch.float32), atol=1e-1, rtol=1e-3)
692
693            loss = y.sum()
694            loss.backward()
695            loss_bf16 = y_bf16.sum()
696            loss_bf16.backward()
697            self.assertEqual(x_fp32.grad.to_dense(), x_bf16.grad.to_dense(torch.float32))
698        else:
699            x_bf16 = torch.randn(size, dtype=torch.bfloat16).requires_grad_()
700            m_bf16 = mkldnn_utils.to_mkldnn(torch.nn.PReLU(), torch.bfloat16)
701            msg = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
702            self.assertRaisesRegex(RuntimeError,
703                                   msg,
704                                   lambda: m_bf16(x_bf16))
705
706    def test_prelu_bf16(self):
707        self._test_prelu_bf16_base(torch.Size([16]), 1)
708        self._test_prelu_bf16_base(torch.Size([16, 64]), 1)
709        self._test_prelu_bf16_base(torch.Size([16, 64]), 64)
710        self._test_prelu_bf16_base(torch.Size([16, 64, 112]), 1)
711        self._test_prelu_bf16_base(torch.Size([16, 64, 112]), 64)
712        self._test_prelu_bf16_base(torch.Size([16, 64, 112, 112, 1]), 1)
713        self._test_prelu_bf16_base(torch.Size([16, 64, 112, 112, 1]), 64)
714
715    def _test_max_pool_base(self, dim, input):
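        """Compare aten and mkldnn MaxPool2d/3d over several strides and ceil_mode
        settings, checking outputs and input gradients."""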
716        pool_module = {2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d}
717        for stride in [1, 2, 3]:
718            for ceil_mode in [False, True]:
719                max_pool = pool_module[dim](
720                    kernel_size=3 if not ceil_mode else 7,
721                    stride=stride,
722                    padding=1,
723                    ceil_mode=ceil_mode)
724
725                x1 = input.clone().requires_grad_()
726                x2 = input.clone().to_mkldnn().requires_grad_()
727                y1 = max_pool(x1)
728                y2 = max_pool(x2).to_dense()
729                loss1 = y1.sum()
730                loss2 = y2.sum()
731                loss1.backward()
732                loss2.backward()
733                self.assertEqual(y1, y2)
734                self.assertEqual(x1.grad, x2.grad.to_dense())
735
736    def test_max_pool2d(self):
737        N = torch.randint(3, 10, (1,)).item()
738        C = torch.randint(3, 10, (1,)).item()
739        for H, W in [(64, 64), (35, 39), (16, 19), [7, 8]]:
740            x = torch.randn(N, C, H, W, dtype=torch.float32) * 10
741            self._test_max_pool_base(dim=2, input=x)
742
743    def test_max_pool3d(self):
744        N = torch.randint(3, 10, (1,)).item()
745        C = torch.randint(3, 10, (1,)).item()
746        for D, H, W in [(64, 64, 64), (35, 39, 35), (16, 19, 20), [7, 8, 9]]:
747            x = torch.randn(N, C, D, H, W, dtype=torch.float32) * 10
748            self._test_max_pool_base(dim=3, input=x)
749
750
751    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
752    def _test_max_pool_bf16_base(self, dim, input):
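        """Compare fp32 and bf16 mkldnn max pooling when bf16 is supported,
        otherwise expect the bf16 path to raise a RuntimeError."""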
753        pool_module = {2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d}
754        x_bf16 = input.bfloat16()
755        for stride in [1, 2, 3]:
756            for ceil_mode in [False, True]:
757                max_pool = pool_module[dim](
758                    kernel_size=3 if not ceil_mode else 7,
759                    stride=stride,
760                    padding=1,
761                    ceil_mode=ceil_mode)
762
763                if torch.ops.mkldnn._is_mkldnn_bf16_supported():
764                    y = max_pool(input.to_mkldnn()).to_dense()
765                    y_bf16 = max_pool(x_bf16.to_mkldnn()).to_dense(torch.float32)
766                    self.assertEqual(y, y_bf16, atol=0.1, rtol=1e-3)
767                else:
768                    msg = "mkldnn_max_pool%dd: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" % dim
769                    self.assertRaisesRegex(RuntimeError,
770                                           msg,
771                                           lambda: max_pool(x_bf16.to_mkldnn()))
772
773    def test_max_pool2d_bf16(self):
774        N = torch.randint(3, 10, (1,)).item()
775        C = torch.randint(3, 10, (1,)).item()
776        for H, W in [(64, 64), (35, 39), (16, 19), [7, 8]]:
777            x = torch.randn(N, C, H, W, dtype=torch.float32) * 10
778            self._test_max_pool_bf16_base(dim=2, input=x)
779
780    def test_max_pool3d_bf16(self):
781        N = torch.randint(3, 10, (1,)).item()
782        C = torch.randint(3, 10, (1,)).item()
783        for D, H, W in [(64, 64, 64), (35, 39, 35), (16, 19, 20), [7, 8, 9]]:
784            x = torch.randn(N, C, D, H, W, dtype=torch.float32) * 10
785            self._test_max_pool_bf16_base(dim=3, input=x)
786
787    def test_max_pool2d_stride_none(self):
788        N = torch.randint(3, 10, (1,)).item()
789        C = torch.randint(3, 10, (1,)).item()
790
791        for H, W in [(64, 64), (35, 39), (16, 19), [7, 8]]:
792            x = torch.randn(N, C, H, W, dtype=torch.float32) * 10
793            for ceil_mode in [False, True]:
794                y1 = F.max_pool2d(
795                    x,
796                    kernel_size=3 if not ceil_mode else 7,
797                    stride=None,
798                    padding=1,
799                    ceil_mode=ceil_mode)
800
801                y2 = F.max_pool2d(
802                    x.to_mkldnn(),
803                    kernel_size=3 if not ceil_mode else 7,
804                    stride=None,
805                    padding=1,
806                    ceil_mode=ceil_mode)
807
808                self.assertEqual(y1, y2.to_dense())
809
810    # https://github.com/pytorch/pytorch/issues/127111
811    @xfailIfTorchDynamo
812    def test_max_pool_unsupported(self):
813        # OneDNN does not support dilated max pooling; it will be available in v2.0.
814        N = torch.randint(3, 10, (1,)).item()
815        C = torch.randint(3, 10, (1,)).item()
816
817        # 2d dilation case
818        x = torch.randn(N, C, 7, 7, dtype=torch.float32).to_mkldnn()
819        max_pool2d = torch.nn.MaxPool2d(
820            kernel_size=3,
821            stride=3,
822            padding=1,
823            dilation=2)
824        self.assertRaisesRegex(RuntimeError,
825                               'mkldnn_max_pool2d does not support dilation case',
826                               lambda: max_pool2d(x))
827
828        # 3d dilation case
829        x = torch.randn(N, C, 7, 7, 7, dtype=torch.float32).to_mkldnn()
830        max_pool3d = torch.nn.MaxPool3d(
831            kernel_size=3,
832            stride=3,
833            padding=1,
834            dilation=2)
835        self.assertRaisesRegex(RuntimeError,
836                               'mkldnn_max_pool3d does not support dilation case',
837                               lambda: max_pool3d(x))
838
839    def _test_avg_pool_base(self, dim, input):
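        """Compare aten and mkldnn AvgPool2d/3d for both count_include_pad settings,
        checking outputs and input gradients."""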
840        avg_module = {2: torch.nn.AvgPool2d, 3: torch.nn.AvgPool3d}
841        for count_include_pad in [True, False]:
842            avg_pool = avg_module[dim](
843                kernel_size=3,
844                stride=2,
845                padding=1,
846                count_include_pad=count_include_pad)
847
848            x1 = input.clone().requires_grad_()
849            x2 = input.clone().to_mkldnn().requires_grad_()
850            y1 = avg_pool(x1)
851            y2 = avg_pool(x2).to_dense()
852            loss1 = y1.sum()
853            loss2 = y2.sum()
854            loss1.backward()
855            loss2.backward()
856            self.assertEqual(y1, y2)
857            self.assertEqual(x1.grad, x2.grad.to_dense())
858
859    def test_avg_pool2d(self):
860        N = torch.randint(3, 10, (1,)).item()
861        C = torch.randint(3, 10, (1,)).item()
862        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10
863        self._test_avg_pool_base(dim=2, input=x)
864
865    def test_avg_pool3d(self):
866        N = torch.randint(3, 10, (1,)).item()
867        C = torch.randint(3, 10, (1,)).item()
868        x = torch.randn(N, C, 64, 64, 64, dtype=torch.float32) * 10
869        self._test_avg_pool_base(dim=3, input=x)
870
871    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
872    def _test_avg_pool_bf16_base(self, dim, input):
873        avg_module = {2: torch.nn.AvgPool2d, 3: torch.nn.AvgPool3d}
874        x_bf16 = input.bfloat16()
875        for count_include_pad in [True, False]:
876            avg_pool = avg_module[dim](
877                kernel_size=3,
878                stride=2,
879                padding=1,
880                count_include_pad=count_include_pad)
881            if torch.ops.mkldnn._is_mkldnn_bf16_supported():
882                y = avg_pool(input.to_mkldnn()).to_dense()
883                y_bf16 = avg_pool(x_bf16.to_mkldnn()).to_dense(torch.float)
884                self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
885            else:
886                msg = "mkldnn_avg_pool%dd: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" % dim
887                self.assertRaisesRegex(RuntimeError,
888                                       msg,
889                                       lambda: avg_pool(x_bf16.to_mkldnn()))
890
891    def test_avg_pool2d_bf16(self):
892        N = torch.randint(3, 10, (1,)).item()
893        C = torch.randint(3, 10, (1,)).item()
894        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10
895        self._test_avg_pool_bf16_base(dim=2, input=x)
896
897    def test_avg_pool3d_bf16(self):
898        N = torch.randint(3, 10, (1,)).item()
899        C = torch.randint(3, 10, (1,)).item()
900        x = torch.randn(N, C, 64, 64, 64, dtype=torch.float32) * 10
901        self._test_avg_pool_bf16_base(dim=3, input=x)
902
903    def test_avg_pool2d_stride_none(self):
904        N = torch.randint(3, 10, (1,)).item()
905        C = torch.randint(3, 10, (1,)).item()
906        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10
907
908        for count_include_pad in [True, False]:
909            y1 = F.avg_pool2d(
910                x,
911                kernel_size=3,
912                stride=None,
913                padding=1,
914                count_include_pad=count_include_pad)
915            y2 = F.avg_pool2d(
916                x.to_mkldnn(),
917                kernel_size=3,
918                stride=None,
919                padding=1,
920                count_include_pad=count_include_pad)
921
922            self.assertEqual(y1, y2.to_dense())
923
924    def test_adaptive_avg_pool2d(self):
925        N = torch.randint(3, 10, (1,)).item()
926        C = torch.randint(3, 10, (1,)).item()
927        x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100
928
929        adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(7)
930        x1 = x.clone().requires_grad_()
931        x2 = x.clone().to_mkldnn().requires_grad_()
932        y1 = adaptive_avg_pool2d(x1)
933        y2 = adaptive_avg_pool2d(x2).to_dense()
934
935        loss1 = y1.sum()
936        loss2 = y2.sum()
937        loss1.backward()
938        loss2.backward()
939
940        self.assertEqual(y1, y2)
941        self.assertEqual(x1.grad, x2.grad.to_dense())
942
943    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
944    def test_adaptive_avg_pool2d_bf16(self):
945        N = torch.randint(3, 10, (1,)).item()
946        C = torch.randint(3, 10, (1,)).item()
947        x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100
948
949        x_bf16 = x.bfloat16()
950        adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(7)
951
952        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
953            y = adaptive_avg_pool2d(x.to_mkldnn()).to_dense()
954            y_bf16 = adaptive_avg_pool2d(x_bf16.to_mkldnn()).to_dense(torch.float32)
955            self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
956        else:
957            msg = "mkldnn_adaptive_avg_pool2d: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
958            self.assertRaisesRegex(RuntimeError,
959                                   msg,
960                                   lambda: adaptive_avg_pool2d(x_bf16.to_mkldnn()))
961
962    def _test_batch_norm_base(self, dim, channels, input):
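        """Compare aten and mkldnn BatchNorm2d/3d in inference mode and exercise
        serialization and tracing of the converted module."""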
963        bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
964        bn = bn_module[dim](channels).float().train(False)
965        mkldnn_bn = mkldnn_utils.to_mkldnn(copy.deepcopy(bn))
966        self.assertEqual(
967            bn(input),
968            mkldnn_bn(input.to_mkldnn()).to_dense())
969
970        self._test_serialization(mkldnn_bn, (input.to_mkldnn(),))
971        self._test_tracing(mkldnn_bn, (input.to_mkldnn(),))
972
973    def _test_batch_norm_train_base(self, dim, channels, input):
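        """Compare aten and mkldnn BatchNorm2d in training mode, checking outputs,
        gradients, and (when tracked) running statistics."""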
974        # TODO: support 3d batchnorm training.
975        bn_module = {2 : torch.nn.BatchNorm2d}
976        # TODO: support non-affine.
977        options = itertools.product([True], [True, False])
978        for affine, track_running_stats in options:
979            bn = bn_module[dim](
980                num_features=channels,
981                affine=affine,
982                track_running_stats=track_running_stats).float().train(True)
983            mkldnn_bn = copy.deepcopy(bn)
984            x1 = input.clone().requires_grad_()
985            x2 = input.clone().to_mkldnn().requires_grad_()
986            y1 = bn(x1)
987            y2 = mkldnn_bn(x2).to_dense()
988            loss1 = y1.sum()
989            loss2 = y2.sum()
990            loss1.backward()
991            loss2.backward()
992            self.assertEqual(y1, y2)
993            self.assertEqual(x1.grad, x2.grad.to_dense())
994            self.assertEqual(bn.weight.grad, mkldnn_bn.weight.grad, rtol=1e-3, atol=1e-3)
995            if track_running_stats:
996                self.assertEqual(bn.running_mean, mkldnn_bn.running_mean)
997                self.assertEqual(bn.running_var, mkldnn_bn.running_var, rtol=1e-5, atol=1e-5)
998
999    def test_batch_norm_2d(self):
1000        N = torch.randint(3, 10, (1,)).item()
1001        C = torch.randint(3, 100, (1,)).item()
1002        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
1003        self._test_batch_norm_base(dim=2, channels=C, input=x)
1004        self._test_batch_norm_train_base(dim=2, channels=C, input=x)
1005
1006    def test_batch_norm_3d(self):
1007        N = torch.randint(3, 10, (1,)).item()
1008        C = torch.randint(3, 100, (1,)).item()
1009        x = torch.randn(N, C, 30, 30, 30, dtype=torch.float32) * 10
1010        self._test_batch_norm_base(dim=3, channels=C, input=x)
1011
1012    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
1013    def _test_batch_norm_bf16_base(self, dim, channels, input):
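        """Compare fp32 and bf16 batch norm results in inference mode when bf16 is
        supported, otherwise expect the bf16 path to raise a RuntimeError."""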
1014        bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
1015        x_bf16 = input.bfloat16()
1016        # TODO: support training
1017        for train in [False]:
1018            bn = bn_module[dim](channels).float().train(train)
1019            mkldnn_bn = mkldnn_utils.to_mkldnn(copy.deepcopy(bn))
1020            if torch.ops.mkldnn._is_mkldnn_bf16_supported():
1021                y = bn(input.to_mkldnn()).to_dense()
1022                y_bf16 = bn(x_bf16.to_mkldnn()).to_dense(torch.float)
1023                self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
1024            else:
1025                msg = "mkldnn_batch_norm: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
1026                self.assertRaisesRegex(RuntimeError,
1027                                       msg,
1028                                       lambda: bn(x_bf16.to_mkldnn()))
1029
1030    def test_batch_norm_2d_bf16(self):
1031        N = torch.randint(3, 10, (1,)).item()
1032        C = torch.randint(3, 100, (1,)).item()
1033        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
1034        self._test_batch_norm_bf16_base(dim=2, channels=C, input=x)
1035
1036    def test_batch_norm_3d_bf16(self):
1037        N = torch.randint(3, 10, (1,)).item()
1038        C = torch.randint(3, 100, (1,)).item()
1039        x = torch.randn(N, C, 30, 30, 30, dtype=torch.float32) * 10
1040        self._test_batch_norm_bf16_base(dim=3, channels=C, input=x)
1041
1042    def test_add(self):
1043        N = torch.randint(3, 10, (1,)).item()
1044        C = torch.randint(3, 100, (1,)).item()
1045        alpha = torch.randn(1, dtype=torch.float32).item()
1046
1047        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
1048        y = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
1049        mx = x.to_mkldnn()
1050        my = y.to_mkldnn()
1051
1052        # add
1053        self.assertEqual(
1054            x + y,
1055            (mx + my).to_dense())
1056
1057        self.assertEqual(
1058            torch.add(x, y, alpha=alpha),
1059            torch.add(mx, my, alpha=alpha).to_dense())
1060
1061        # add_
1062        x += y
1063        mx += my
1064        self.assertEqual(x, mx.to_dense())
1065
1066        # add_out
1067        out = x.clone()
1068        mkldnn_out = out.to_mkldnn()
1069        torch.add(x, y, alpha=alpha, out=out)
1070        torch.add(mx, my, alpha=alpha, out=mkldnn_out)
1071        self.assertEqual(out, mkldnn_out.to_dense())
1072
1073        # add_out inplace case: first input
1074        torch.add(x, y, alpha=alpha, out=x)
1075        torch.add(mx, my, alpha=alpha, out=mx)
1076        self.assertEqual(x, mx.to_dense())
1077
1078        # add_out inplace case: second input
1079        torch.add(x, y, alpha=alpha, out=y)
1080        torch.add(mx, my, alpha=alpha, out=my)
1081        self.assertEqual(y, my.to_dense())
1082
1083    def test_mul(self):
1084        N = torch.randint(3, 10, (1,)).item()
1085        C = torch.randint(3, 100, (1,)).item()
1086        value = torch.randn(1, dtype=torch.float32).item()
1087
1088        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
1089        y = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
1090        mx = x.to_mkldnn()
1091        my = y.to_mkldnn()
1092
1093        # mul
1094        self.assertEqual(
1095            x * y,
1096            (mx * my).to_dense())
1097
1098        self.assertEqual(
1099            x * value,
1100            (mx * value).to_dense())
1101
1102        self.assertEqual(
1103            torch.mul(x, y),
1104            torch.mul(mx, my).to_dense())
1105
1106        self.assertEqual(
1107            torch.mul(x, value),
1108            torch.mul(mx, value).to_dense())
1109
1110        # mul_
1111        x *= y
1112        mx *= my
1113        self.assertEqual(x, mx.to_dense())
1114
1115        x *= value
1116        mx *= value
1117        self.assertEqual(x, mx.to_dense())
1118
1119        # mul_out
1120        out = x.clone()
1121        mkldnn_out = out.to_mkldnn()
1122        torch.mul(x, y, out=out)
1123        torch.mul(mx, my, out=mkldnn_out)
1124        self.assertEqual(out, mkldnn_out.to_dense())
1125
1126        out = x.clone()
1127        mkldnn_out = out.to_mkldnn()
1128        torch.mul(x, value, out=out)
1129        torch.mul(mx, value, out=mkldnn_out)
1130        self.assertEqual(out, mkldnn_out.to_dense())
1131
1132    def test_0_dimension_tensor(self):
1133        x = torch.rand([20, 20, 1, 1], dtype=torch.float)
1134        y = torch.rand([20, 20, 0, 1], dtype=torch.float)
1135
1136        # unary ops work without modification
1137        out_relu = torch.relu(y)
1138        out_relu_mkldnn = torch.relu(y.to_mkldnn()).to_dense()
1139        self.assertEqual(out_relu, out_relu_mkldnn)
1140
1141        out_mul = x * y
1142        out_mul_mkldnn = (x.to_mkldnn() * y.to_mkldnn()).to_dense()
1143        self.assertEqual(out_mul, out_mul_mkldnn)
1144
1145        out_add = x + y
1146        out_add_mkldnn = (x.to_mkldnn() + y.to_mkldnn()).to_dense()
1147        self.assertEqual(out_add, out_add_mkldnn)
1148
1149        x.requires_grad_(True)
1150        y.requires_grad_(True)
1151        with self.assertRaisesRegex(RuntimeError, "0-dimension Tensor in training"):
1152            x.to_mkldnn() + y.to_mkldnn()
1153
1154        with self.assertRaisesRegex(RuntimeError, "must match"):
1155            torch.rand([5]).to_mkldnn() + torch.rand([0]).to_mkldnn()
1156
1157        C = 7
1158        m = torch.nn.Conv2d(C, C, 3)
1159        x = torch.randn(0, C, C, 8, dtype=torch.float)
1160        out_eager = m(x)
1161        out_mkldnn = mkldnn_utils.to_mkldnn(m)(x)
1162        self.assertEqual(out_eager, out_mkldnn)
1163
1164    # https://github.com/pytorch/pytorch/issues/127111
1165    @xfailIfTorchDynamo
1166    def test_view(self):
1167        x = torch.randn(3, 4, 5, dtype=torch.float32).to_mkldnn()
1168        self.assertRaisesRegex(RuntimeError,
1169                               "Change to use reshape",
1170                               lambda: x.view(x.size(0), -1))
1171
1172    def test_reshape(self):
1173        x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
1174        size = (x.size(0), -1)
1175
1176        self.assertEqual(
1177            x.reshape(size),
1178            x.to_mkldnn().reshape(size).to_dense(),
1179        )
1180        # check that reshape of a plain-format mkldnn tensor shares memory with the original
1181        y = x.to_mkldnn()
1182        z = y.reshape(size).add_(y.reshape(size))
1183        self.assertEqual(
1184            y.reshape(size).to_dense(),
1185            z.to_dense(),
1186        )
1187
1188    def test_reshape_blocked_format(self):
1189        # construct an mkldnn blocked tensor with mkldnn conv2d
1190        C = 7
1191        m = mkldnn_utils.to_mkldnn(torch.nn.Conv2d(C, C, 3))
1192        x = torch.randn(1, C, 8, 8).to_mkldnn()
1193
1194        # mkldnn tensor w/ blocked format
1195        y_block = m(x)
1196        # aten tensor w/ plain format
1197        y_plain = y_block.to_dense()
1198
1199        y_block_reshape = y_block.reshape(C, -1)
1200        y_plain_reshape = y_plain.reshape(C, -1)
1201
1202        self.assertEqual(y_plain_reshape, y_block_reshape.to_dense())
1203
1204    def test_reshape_backward(self):
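            # Gradients w.r.t. the input should match between the dense path and the
            # mkldnn path when reshape feeds a Linear layer.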
1205        x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
1206        size = (x.size(0), -1)
1207
1208        x1 = x.clone().requires_grad_()
1209        x2 = x.clone().to_mkldnn().requires_grad_()
1210        in_features = 20
1211        out_features = torch.randint(3, 100, (1,)).item()
1212        linear = torch.nn.Linear(in_features, out_features).float()
1213
1214        y1 = linear(x1.reshape(size)).sum()
1215        y2 = linear(x2.reshape(size).to_dense()).sum()
1216        y1.backward()
1217        y2.backward()
1218        self.assertEqual(x1.grad, x2.grad.to_dense())
1219
1220    def test_clone(self):
1221        x = torch.randn(4, 5, dtype=torch.float32) * 10
1222        self.assertEqual(
1223            x.clone(),
1224            x.to_mkldnn().clone().to_dense(),
1225        )
1226        # check that clone does not share memory with the original
1227        y = x.to_mkldnn()
1228        z = y.clone().add_(y)
1229        self.assertNotEqual(
1230            y.to_dense(),
1231            z.to_dense(),
1232        )
1233
1234    def test_transpose(self):
1235        x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
1236        for dim1 in range(x.ndim):
1237            for dim2 in range(x.ndim):
1238                self.assertEqual(
1239                    x.transpose(dim1, dim2),
1240                    x.to_mkldnn().transpose(dim1, dim2).to_dense(),
1241                )
1242
1243    def test_transpose_invalid_dim(self):
1244        x = torch.randn(3, 4, 5, dtype=torch.float32).to_mkldnn()
1245        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
1246            torch._mkldnn_transpose(x, 0, 12)
1247
1248    def test_linear_non_contiguous_weight(self):
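            # Linear with a transposed (non-contiguous) weight: the mkldnn path is
            # exercised by feeding mkldnn tensors to a deep copy of the module, and
            # input and parameter gradients should match the dense reference.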
1249        in_features = torch.randint(3, 10, (1,)).item()
1250        out_features = torch.randint(3, 100, (1,)).item()
1251        x = torch.randn(3, in_features, dtype=torch.float32) * 10
1252        w = torch.randn(in_features, out_features, dtype=torch.float32)
1253        for bias in [True, False]:
1254            x1 = x.clone().requires_grad_()
1255            x2 = x.clone().to_mkldnn().requires_grad_()
1256            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
1257            linear.weight = torch.nn.Parameter(w.t())
1258            mkldnn_linear = copy.deepcopy(linear)
1259            y1 = linear(x1).sum()
1260            y2 = mkldnn_linear(x2).to_dense().sum()
1261            y1.backward()
1262            y2.backward()
1263            self.assertEqual(x1.grad, x2.grad.to_dense())
1264            self.assertEqual(linear.weight.grad, mkldnn_linear.weight.grad)
1265            if bias:
1266                self.assertEqual(linear.bias.grad, mkldnn_linear.bias.grad)
1267
1268    def test_linear(self):
1269        in_features = torch.randint(3, 10, (1,)).item()
1270        out_features = torch.randint(3, 100, (1,)).item()
1271        x = torch.randn(3, in_features, dtype=torch.float32) * 10
1272
1273        for bias in [True, False]:
1274            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
1275            mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear))
1276            self.assertEqual(
1277                linear(x),
1278                mkldnn_linear(x.to_mkldnn()).to_dense())
1279
1280            self._test_serialization(mkldnn_linear, (x.to_mkldnn(),))
1281            self._test_tracing(mkldnn_linear, (x.to_mkldnn(),))
1282
1283    def test_linear_backward(self):
1284        in_features = torch.randint(3, 10, (1,)).item()
1285        out_features = torch.randint(3, 100, (1,)).item()
1286        x = torch.randn(3, in_features, dtype=torch.float32) * 10
1287        for bias in [True, False]:
1288            x1 = x.clone().requires_grad_()
1289            x2 = x.clone().to_mkldnn().requires_grad_()
1290            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
1291            mkldnn_linear = copy.deepcopy(linear)
1292            y1 = linear(x1).sum()
1293            y2 = mkldnn_linear(x2).to_dense().sum()
1294            y1.backward()
1295            y2.backward()
1296            self.assertEqual(x1.grad, x2.grad.to_dense())
1297            self.assertEqual(linear.weight.grad, mkldnn_linear.weight.grad)
1298            if bias:
1299                self.assertEqual(linear.bias.grad, mkldnn_linear.bias.grad)
1300
1301    @dtypes(torch.float16, torch.bfloat16)
1302    def test_linear_lowp(self, dtype):
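            # Compare an fp32 mkldnn Linear against a bf16/fp16 mkldnn Linear when the
            # CPU supports the low-precision path; otherwise the low-precision module
            # is expected to raise a RuntimeError.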
1303        in_features = torch.randint(3, 10, (1,)).item()
1304        out_features = torch.randint(3, 100, (1,)).item()
1305        x = torch.randn(3, in_features, dtype=torch.float32) * 10
1306        x_lowp = x.to(dtype=dtype)
1307
1308        for bias in [True, False]:
1309            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
1310            mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear))
1311            mkldnn_linear_lowp = mkldnn_utils.to_mkldnn(
1312                copy.deepcopy(linear), dtype
1313            )
1314            lowp_support = {
1315                torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
1316                torch.half: torch.ops.mkldnn._is_mkldnn_fp16_supported,
1317            }
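                # run the numeric comparison only if the CPU ISA supports this dtype;
                # otherwise check the error message raised by the unsupported path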
1318            if lowp_support[dtype]():
1319                y = mkldnn_linear(x.to_mkldnn()).to_dense()
1320                y_lowp = mkldnn_linear_lowp(x_lowp.to_mkldnn()).to_dense(
1321                    torch.float32
1322                )
1323                if dtype == torch.bfloat16:
1324                    self.assertEqual(y, y_lowp, atol=1e-1, rtol=1e-3)
1325                else:
1326                    self.assertEqual(y, y_lowp, atol=5e-3, rtol=1e-3)
1327            else:
1328                msg = {
1329                    torch.bfloat16: r"bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq",
1330                    torch.half: r"fp16 path needs the cpu support avx_ne_convert or avx512_fp16",
1331                }
1332                self.assertRaisesRegex(
1333                    RuntimeError,
1334                    msg[dtype],
1335                    lambda: mkldnn_linear_lowp(x_lowp.to_mkldnn()),
1336                )
1337
1338    def test_softmax(self):
1339        x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
1340        for dim in range(x.ndim):
1341            softmax = torch.nn.Softmax(dim=dim)
1342            self.assertEqual(
1343                softmax(x),
1344                softmax(x.to_mkldnn()).to_dense())
1345
1346    def test_sigmoid(self):
1347        x = torch.randn(4, 5, dtype=torch.float32) * 10
1348        mkldnn_x = x.to_mkldnn()
1349        self.assertEqual(
1350            torch.sigmoid(x),
1351            torch.sigmoid(mkldnn_x).to_dense(),
1352        )
1353        # inplace
1354        torch.sigmoid_(x)
1355        torch.sigmoid_(mkldnn_x)
1356        self.assertEqual(x, mkldnn_x.to_dense())
1357
1358    def test_tanh(self):
1359        x = torch.randn(4, 5, dtype=torch.float32) * 10
1360        mkldnn_x = x.to_mkldnn()
1361        self.assertEqual(
1362            torch.tanh(x),
1363            torch.tanh(mkldnn_x).to_dense(),
1364        )
1365        # inplace
1366        torch.tanh_(x)
1367        torch.tanh_(mkldnn_x)
1368        self.assertEqual(x, mkldnn_x.to_dense())
1369
1370    def _test_serialization(self, module, inputs):
1371        with TemporaryFileName() as fname:
1372            torch.jit.save(module, fname)
1373            loaded = torch.jit.load(fname)
1374            self.assertEqual(
1375                module(*inputs).to_dense(),
1376                loaded(*inputs).to_dense())
1377
1378    def _test_tracing(self, module, inputs):
1379        traced = torch.jit.trace(module, inputs)
1380        self.assertEqual(
1381            module(*inputs).to_dense(),
1382            traced(*inputs).to_dense())
1383
1384    def test_set_data_tensorimpl_type(self):
1385        # Dense tensor has impl of type `TensorImpl`, while MKL-DNN tensor has impl
1386        # of type `OpaqueTensorImpl<IDeepTensorWrapperPtr>`.
1387        x = torch.randn((1, 2), dtype=torch.float, device=torch.device('cpu'))
1388        x_mkldnn = x.to_mkldnn()
1389        with self.assertRaisesRegex(RuntimeError, 'incompatible tensor type'):
1390            x.data = x_mkldnn
1391
1392    def test_empty(self):
1393        x1 = torch.empty(4, 5, 2, 3, dtype=torch.float32)
1394        x2 = torch.empty(4, 5, 2, 3, dtype=torch.float32, layout=torch._mkldnn)
1395        self.assertEqual(x1.size(), x2.to_dense().size())
1396        self.assertEqual(x1.dtype, x2.to_dense().dtype)
1397
1398    def test_zero_(self):
1399        x1 = torch.randn(4, 5, dtype=torch.float32) * 10
1400        x2 = x1.clone().to_mkldnn()
1401        self.assertEqual(
1402            x1.zero_(),
1403            x2.zero_().to_dense(),
1404        )
1405
1406    def test_is_mkldnn(self):
1407        x = torch.randn(1, dtype=torch.float32)
1408        self.assertFalse(x.is_mkldnn)
1409        self.assertTrue(x.to_mkldnn().is_mkldnn)
1410
1411    # legacy constructor/new doesn't support mkldnn tensors
1412    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1992")
1413    def test_legacy_new_failure(self):
1414        x = torch.randn(1, dtype=torch.float32)
1415        x_mkldnn = x.to_mkldnn()
1416        self.assertRaises(RuntimeError, lambda: x_mkldnn.new(device='cpu'))
1417        self.assertRaises(RuntimeError, lambda: x_mkldnn.new(x.storage()))
1418        self.assertRaises(RuntimeError, lambda: x_mkldnn.new(x))
1419        self.assertRaises(RuntimeError, lambda: x_mkldnn.new(torch.Size([2, 3])))
1420        self.assertRaises(RuntimeError, lambda: x_mkldnn.new([6]))
1421
1422    def test_is_mkldnn_jit(self):
1423        class EnsureMkldnn(torch.jit.ScriptModule):
1424            @torch.jit.script_method
1425            def forward(self, x):
1426                if not x.is_mkldnn:
1427                    x = x.to_mkldnn()
1428                return x
1429
1430        m = EnsureMkldnn()
1431        x = torch.randn(1, dtype=torch.float32)
1432        self.assertTrue(m(x).is_mkldnn)
1433        self.assertTrue(m(x.to_mkldnn()).is_mkldnn)
1434
1435    def _test_imagenet_model(self, model):
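            # Run the model in eval mode and compare eager fp32 inference against the
            # mkldnn-converted copy on a single 224x224 input.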
1436        model = model.train(False).float()
1437        mkldnn_model = mkldnn_utils.to_mkldnn(copy.deepcopy(model))
1438        x = torch.randn(1, 3, 224, 224, dtype=torch.float32)
1439        with torch.no_grad():
1440            self.assertEqual(
1441                model(x),
1442                mkldnn_model(x.to_mkldnn()).to_dense(),
1443            )
1444
1445    @skipIfNoTorchVision
1446    def test_resnet18(self):
1447        model = torchvision.models.resnet.resnet18(weights=None)
1448        self._test_imagenet_model(model)
1449
1450    @skipIfNoTorchVision
1451    def test_resnext50_32x4d(self):
1452        model = torchvision.models.resnet.resnext50_32x4d(weights=None)
1453        self._test_imagenet_model(model)
1454
1455    def _lstm_params_list(self):
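            # Hyperparameter values whose Cartesian product is swept in test_lstm.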
1456        params_dict = {
1457            "input_size": [1, 5],
1458            "hidden_size": [5, 16],
1459            "num_layers": [1, 3],
1460            "bidirectional": [False, True],
1461            "bias": [False, True],
1462            "batch_first": [False, True],
1463            "dropout": [0, 0.4, 0.7, 1],
1464            "batch_size": [1, 2],
1465            "seq_len": [1, 3],
1466            "training": [False, True]
1467        }
1468
1469        params_list = list(params_dict.values())
1470        return params_list
1471
1472    def _cast_dtype(self, input, dtype):
1473        if dtype == torch.bfloat16:
1474            input = input.to(torch.bfloat16)
1475        elif dtype == torch.half:
1476            input = input.to(torch.half)
1477        return input
1478
1479    def test_lstm(self):
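            # Sweep the LSTM parameter grid and compare forward outputs (and, in
            # training mode, gradients) with the mkldnn backend disabled vs. enabled,
            # for fp32/bf16/fp16 with dtype-dependent tolerances.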
1480        seed = 2023
1481        torch.manual_seed(seed)
1482
1483        params_list = self._lstm_params_list()
1484        for dtype in types:
1485            bf16 = dtype == torch.bfloat16
1486            fp16 = dtype == torch.half
1487            rtol = 1.3e-6
1488            atol = 1e-5
1489
1490            if bf16:
1491                rtol = 0.02
1492                atol = 0.02
1493            if fp16:
1494                rtol = 1e-3
1495                atol = 1e-3
1496            for input_size, hidden_size, num_layers, bidirectional, bias, batch_first, dropout, batch_size, seq_len, training \
1497                    in itertools.product(*params_list):
1498                num_directions = 2 if bidirectional else 1
1499                if batch_first:
1500                    input = torch.randn(batch_size, seq_len, input_size, dtype=torch.float32)
1501                else:
1502                    input = torch.randn(seq_len, batch_size, input_size, dtype=torch.float32)
1503                h = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32)
1504                c = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32)
1505                if fp16:
1506                    # TODO: add training support once oneDNN supports LSTM FP16 training
1507                    training = False
1508                model = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional,
1509                                      bias=bias, dropout=dropout, batch_first=batch_first).float()
1510                model.train(training)
1511                input1 = input.clone().requires_grad_(training)
1512                input2 = input.clone().requires_grad_(training)
1513
1514                h1 = h.clone().requires_grad_(training)
1515                h2 = h.clone().requires_grad_(training)
1516                c1 = c.clone().requires_grad_(training)
1517                c2 = c.clone().requires_grad_(training)
1518
1519                model1 = copy.deepcopy(model)
1520                model2 = copy.deepcopy(model)
1521                with torch.no_grad() if not training else nullcontext():
1522                    with torch.backends.mkldnn.flags(enabled=False):
1523                        torch.manual_seed(seed)
1524                        output1, (hn1, cn1) = self._cast_dtype(model1, dtype)(
1525                            self._cast_dtype(input1, dtype),
1526                            (
1527                                self._cast_dtype(h1, dtype),
1528                                self._cast_dtype(c1, dtype),
1529                            ),
1530                        )
1531
1532                    torch.manual_seed(seed)
1533                    output2, (hn2, cn2) = self._cast_dtype(model2, dtype)(
1534                        self._cast_dtype(input2, dtype),
1535                        (
1536                            self._cast_dtype(h2, dtype),
1537                            self._cast_dtype(c2, dtype),
1538                        ),
1539                    )
1540                    self.assertEqual(output1, output2, rtol=rtol, atol=atol)
1541                    self.assertEqual(hn1, hn2, rtol=rtol, atol=atol)
1542                    self.assertEqual(cn1, cn2, rtol=rtol, atol=atol)
1543
1544                    if training:
1545                        with torch.backends.mkldnn.flags(enabled=False):
1546                            torch.manual_seed(seed)
1547                            output1.sum().backward(retain_graph=True)
1548
1549                        torch.manual_seed(seed)
1550                        output2.sum().backward(retain_graph=True)
1551
1552                        self.assertEqual(input1.grad, input2.grad, rtol=rtol, atol=atol)
1553                        for name, para in model1.named_parameters():
1554                            self.assertEqual(para, getattr(model2, name))
1555                            self.assertEqual(
1556                                para.grad,
1557                                getattr(model2, name).grad,
1558                                rtol=rtol,
1559                                atol=atol,
1560                            )
1561
1562                        with torch.backends.mkldnn.flags(enabled=False):
1563                            torch.manual_seed(seed)
1564                            hn1.sum().backward(retain_graph=True)
1565                        torch.manual_seed(seed)
1566                        hn2.sum().backward(retain_graph=True)
1567                        self.assertEqual(h1.grad, h2.grad, rtol=rtol, atol=atol)
1568
1569                        with torch.backends.mkldnn.flags(enabled=False):
1570                            torch.manual_seed(seed)
1571                            cn1.sum().backward(retain_graph=True)
1572                        torch.manual_seed(seed)
1573                        cn2.sum().backward(retain_graph=True)
1574                        self.assertEqual(c1.grad, c2.grad, rtol=rtol, atol=atol)
1575
1576    @dtypes(torch.float16, torch.bfloat16)
1577    def test_matmul_lower_precision(self, dtype):
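            # When the CPU supports the low-precision path, matmul/bmm results in
            # bf16/fp16 should match an fp32 reference, and bmm should accept a
            # contiguous tensor with non-default strides.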
1578        support_check = {
1579            torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
1580            torch.float16: torch.ops.mkldnn._is_mkldnn_fp16_supported,
1581        }
1582
1583        def common(self, shape1, shape2, op, dtype):
1584            a = torch.randn(shape1, dtype=dtype)
1585            a_ref = a.float()
1586            b = torch.randn(shape2, dtype=dtype)
1587            b_ref = b.float()
1588
1589            y = op(a, b)
1590            y_ref = op(a_ref, b_ref)
1591            self.assertEqual(y, y_ref, exact_dtype=False)
1592
1593        if support_check[dtype]():
1594            a1 = torch.randn([64, 1, 33], dtype=dtype)
1595            # a2 is a contiguous tensor, but its strides
1596            # are not the default contiguous strides.
1597            a2 = torch.as_strided(a1.clone(), [64, 1, 33], [33, 3, 1])
1598            self.assertTrue(a2.is_contiguous())
1599            b = torch.randn(64, 33, 256).to(dtype=dtype)
1600            y1 = torch.ops.aten.bmm(a1, b)
1601            y2 = torch.bmm(a2, b)
1602            self.assertEqual(y1, y2)
1603
1604            for shape1, shape2, op in [
1605                ((33, 77), (77, 22), torch.matmul),
1606                ((128, 256), (256, 10), torch.matmul),
1607                ((7, 300), (300, 3), torch.matmul),
1608                ((1, 100), (100, 60), torch.matmul),
1609                ((100, 1), (1, 100), torch.matmul),
1610                ((20, 54, 78), (20, 78, 10), torch.bmm),
1611                ((1, 300, 1), (1, 1, 300), torch.bmm),
1612            ]:
1613                common(self, shape1, shape2, op, dtype)
1614
1615
1616instantiate_device_type_tests(TestMkldnn, globals(), only_for=('cpu',))
1617
1618if __name__ == '__main__':
1619    run_tests()
1620