xref: /aosp_15_r20/external/pytorch/test/nn/test_convolution.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: nn"]
2import itertools
3import math
4import unittest
5import warnings
6from itertools import product
7
8import torch
9import torch.autograd.forward_ad as fwAD
10import torch.backends.cudnn as cudnn
11import torch.nn as nn
12import torch.nn.functional as F
13from torch.testing import make_tensor
14from torch.testing._internal.common_cuda import (
15    TEST_CUDA,
16    TEST_CUDNN,
17    tf32_is_not_fp32,
18    tf32_on_and_off,
19)
20from torch.testing._internal.common_device_type import (
21    disablecuDNN,
22    disableMkldnn,
23    dtypes,
24    dtypesIfCUDA,
25    instantiate_device_type_tests,
26    largeTensorTest,
27    onlyCPU,
28    onlyCUDA,
29    onlyNativeDeviceTypes,
30    precisionOverride,
31    skipCPUIfNoMkldnn,
32    skipCUDAIfCudnnVersionLessThan,
33    skipCUDAIfMiopen,
34    skipCUDAIfNoCudnn,
35    skipCUDAIfNoMiopen,
36    skipCUDAIfNotMiopenSuggestNHWC,
37    skipCUDAIfRocm,
38    skipCUDAIfRocmVersionLessThan,
39    skipMeta,
40)
41from torch.testing._internal.common_dtype import (
42    floating_and_complex_types_and,
43    floating_types_and,
44)
45from torch.testing._internal.common_nn import _test_module_empty_input, NNTestCase
46from torch.testing._internal.common_utils import (
47    download_file,
48    dtype2prec_DONTUSE,
49    gradcheck,
50    GRADCHECK_NONDET_TOL,
51    gradgradcheck,
52    instantiate_parametrized_tests,
53    parametrize as parametrize_test,
54    run_tests,
55    set_default_dtype,
56    skipIfNotMiopenSuggestNHWC,
57    skipIfRocmVersionLessThan,
58    subtest,
59    TEST_SCIPY,
60    TEST_WITH_ROCM,
61)
62
63
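# True on ROCm, and on NVIDIA GPUs where TF32 is available (Ampere and newer);
# used below to extend dtype coverage with bfloat16.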
64AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
65
66
67if TEST_SCIPY:
68    import scipy.ndimage
69    import scipy.signal
70
71
72class TestConvolutionNN(NNTestCase):
73    _do_cuda_memory_leak_check = True
74    _do_cuda_non_default_stream = True
75
76    def test_conv_backcompat(self):
77        from torch.serialization import SourceChangeWarning
78
79        # This file was generated by running on PyTorch 1.0.1 on Python 2:
80        #
81        #     import torch
82        #     from torch import nn
83        #     m = nn.Conv2d(1, 1, 1)
84        #     torch.save(m, 'legacy_conv2d.pt')
85        #
86        # NB: This Pickle also contains some Unicode data!
87        path = download_file("https://download.pytorch.org/test_data/legacy_conv2d.pt")
88        with warnings.catch_warnings():
89            warnings.simplefilter("ignore", SourceChangeWarning)
90            # weights_only=False because this legacy checkpoint pickles the entire module
91            m = torch.load(path, encoding="utf-8", weights_only=False)
92        input = torch.randn((1, 1, 1, 1), dtype=torch.float)
93        self.assertEqual(m(input).size(), (1, 1, 1, 1))
94
95    def test_invalid_conv1d(self):
96        for dtype in [
97            torch.half,
98            torch.bfloat16,
99            torch.float,
100            torch.double,
101            torch.cfloat,
102            torch.cdouble,
103        ]:
104            module = nn.Conv1d(
105                in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True
106            ).to(dtype)
107            input = torch.randn(1, 3, 4).to(dtype)
108            with self.assertRaisesRegex(
109                RuntimeError,
110                r"Calculated padded input size per channel: \(4\). "
111                + r"Kernel size: \(10\). Kernel size can\'t be greater than actual input size",
112            ):
113                module(input)
114
115            # Negative stride check
116            module = nn.Conv1d(
117                in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True
118            ).to(dtype)
119            input = torch.randn(1, 3, 4).to(dtype)
120            with self.assertRaisesRegex(
121                RuntimeError, "non-positive stride is not supported"
122            ):
123                module(input)
124
125    def test_mismatch_shape_conv2d(self):
126        for dtype in (torch.float, torch.cfloat):
127            x = torch.randn(1, 10, 1, 28, 28, dtype=dtype)
128            w = torch.randn(6, 1, 5, 5, dtype=dtype)
129
130            with self.assertRaisesRegex(
131                RuntimeError,
132                r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d, but got "
133                + r"input of size: \[1, 10, 1, 28, 28\]",
134            ):
135                F.conv2d(x, w)
136
137    def test_conv2d_discontiguous_weight(self):
138        for dtype in (torch.float, torch.cfloat):
139            # Test for https://github.com/pytorch/pytorch/issues/55781
140            x = torch.ones(64, 16, 16, 16, dtype=dtype)
141            weight = (
142                torch.arange(0, 1.0, 1 / 2.0**10)
143                .reshape(32, 16, 1, 2)
144                .to(dtype)[:, :, :, ::2]
145            )
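            # The strided slice [:, :, :, ::2] above leaves a non-contiguous (32, 16, 1, 1) weight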
146            self.assertFalse(weight.is_contiguous())
147            y = torch.nn.functional.conv2d(x, weight, None)
148            if torch.backends.mkldnn.is_available():
149                # Disable MKLDNN explicitly, so that either NNPACK or the native implementation is used
150                with torch.backends.mkldnn.flags(enabled=False):
151                    y_ = torch.nn.functional.conv2d(x, weight, None)
152                    self.assertEqual(y, y_)
153            self.assertEqual(y.sum(), 4186112.0)
154
155    def test_invalid_conv2d(self):
156        for dtype in [
157            torch.half,
158            torch.bfloat16,
159            torch.float,
160            torch.double,
161            torch.cfloat,
162            torch.cdouble,
163        ]:
164            module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(
165                dtype
166            )
167            input = torch.empty(1, 1, 4, 4).to(dtype)
168            self.assertRaises(RuntimeError, lambda: module(input))
169
170            module = nn.Conv2d(
171                in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True
172            )
173            input = torch.randn(1, 3, 1, 1)
174            with self.assertRaisesRegex(
175                RuntimeError,
176                r"Calculated padded input size per channel: \(1 x 1\). "
177                + r"Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size",
178            ):
179                module(input)
180
181            # Negative stride check
182            module = nn.Conv2d(
183                in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True
184            ).to(dtype)
185            input = torch.randn(1, 3, 4, 4).to(dtype)
186            with self.assertRaisesRegex(
187                RuntimeError, "non-positive stride is not supported"
188            ):
189                module(input)
190
191            # Zero stride check
192            module = nn.Conv2d(
193                in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True
194            ).to(dtype)
195            input = torch.randn(1, 3, 4, 4).to(dtype)
196            with self.assertRaisesRegex(
197                RuntimeError, "non-positive stride is not supported"
198            ):
199                module(input)
200
201    def test_invalid_conv3d(self):
202        for dtype in [
203            torch.half,
204            torch.bfloat16,
205            torch.float,
206            torch.double,
207            torch.cfloat,
208            torch.cdouble,
209        ]:
210            module = torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2).to(
211                dtype
212            )
213            input = torch.empty(1, 1, 4, 4, 4).to(dtype)
214            self.assertRaises(RuntimeError, lambda: module(input))
215
216            # Negative stride check
217            module = torch.nn.Conv3d(1, 1, kernel_size=3, stride=-2)
218            input = torch.empty(1, 1, 4, 4, 4)
219            with self.assertRaisesRegex(
220                RuntimeError, "non-positive stride is not supported"
221            ):
222                module(input)
223
224    def test_conv_invalid_groups(self):
225        with self.assertRaisesRegex(ValueError, "groups must be a positive integer"):
226            torch.nn.Conv1d(1, 1, kernel_size=3, dilation=2, stride=2, groups=0)
227        with self.assertRaisesRegex(ValueError, "groups must be a positive integer"):
228            torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-1)
229        with self.assertRaisesRegex(ValueError, "groups must be a positive integer"):
230            torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-2)
231
232    def test_Conv1d_module_same_padding(self):
233        # Compare module against functional: without strides/dilation, asymmetric padding
234        x = torch.rand(1, 1, 20)
235        module = nn.Conv1d(
236            in_channels=1, out_channels=1, kernel_size=10, padding="same"
237        )
238        expect = F.conv1d(x, module.weight, module.bias, padding="same")
239        self.assertEqual(expect, module(x))
240
241        # Test dilation, symmetric padding
242        module = nn.Conv1d(
243            in_channels=1, out_channels=1, kernel_size=10, padding="same", dilation=2
244        )
245        expect = F.conv1d(x, module.weight, module.bias, padding="same", dilation=2)
246        self.assertEqual(expect, module(x))
247
248        # Test non-zero padding_mode, requiring explicit padding
249        module = nn.Conv1d(
250            in_channels=1,
251            out_channels=1,
252            kernel_size=10,
253            padding="same",
254            padding_mode="replicate",
255        )
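        # kernel_size=10 needs 9 total padding elements; 'same' splits this as 4 on the left and 5 on the right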
256        x_padded = F.pad(x, [4, 5], mode="replicate")
257        expect = F.conv1d(x_padded, module.weight, module.bias, padding="valid")
258        self.assertEqual(expect, module(x))
259        self.assertEqual(x.size(), expect.size())
260
261        # Test construction with invalid padding string raises
262        with self.assertRaisesRegex(ValueError, "Invalid padding string"):
263            module = nn.Conv1d(
264                in_channels=3, out_channels=33, kernel_size=10, padding="foo"
265            )
266
267        # Test construction with same padding and strides raises
268        with self.assertRaisesRegex(ValueError, "padding='same'"):
269            module = nn.Conv1d(
270                in_channels=3, out_channels=33, kernel_size=10, padding="same", stride=2
271            )
272
273    def test_Conv2d_module_same_padding(self):
274        # Compare module against functional:
275        # without strides/dilation, both symmetric and asymmetric padding
276        x = torch.rand(1, 1, 9, 20)
277        module = nn.Conv2d(
278            in_channels=1, out_channels=1, kernel_size=(5, 10), padding="same"
279        )
280        expect = F.conv2d(x, module.weight, module.bias, padding="same")
281        self.assertEqual(expect, module(x))
282
283        # with dilation, symmetric padding
284        module = nn.Conv2d(
285            in_channels=1,
286            out_channels=1,
287            kernel_size=(3, 4),
288            padding="same",
289            dilation=(1, 2),
290        )
291        expect = F.conv2d(
292            x, module.weight, module.bias, padding="same", dilation=(1, 2)
293        )
294        self.assertEqual(expect, module(x))
295
296        # Test non-zero padding_mode, requiring explicit padding
297        module = nn.Conv2d(
298            in_channels=1,
299            out_channels=1,
300            kernel_size=(3, 4),
301            padding="same",
302            padding_mode="reflect",
303        )
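        # F.pad order is (w_left, w_right, h_left, h_right): kernel (3, 4) needs (1, 1) of height and (1, 2) of width padding for 'same'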
304        x_padded = F.pad(x, [1, 2, 1, 1], mode="reflect")
305        expect = F.conv2d(x_padded, module.weight, module.bias, padding="valid")
306        self.assertEqual(expect, module(x))
307        self.assertEqual(x.size(), expect.size())
308
309        # Test construction with invalid padding string raises
310        with self.assertRaisesRegex(ValueError, "Invalid padding string"):
311            module = nn.Conv2d(
312                in_channels=3, out_channels=33, kernel_size=10, padding="foo"
313            )
314
315        # Test construction with same padding and strides raises
316        with self.assertRaisesRegex(ValueError, "padding='same'"):
317            module = nn.Conv2d(
318                in_channels=3, out_channels=33, kernel_size=10, padding="same", stride=2
319            )
320        with self.assertRaisesRegex(ValueError, "padding='same'"):
321            module = nn.Conv2d(
322                in_channels=3,
323                out_channels=33,
324                kernel_size=10,
325                padding="same",
326                stride=(1, 3),
327            )
328        with self.assertRaisesRegex(ValueError, "padding='same'"):
329            module = nn.Conv2d(
330                in_channels=3,
331                out_channels=33,
332                kernel_size=10,
333                padding="same",
334                stride=(4, 1),
335            )
336
337    def test_Conv3d_module_same_padding(self):
338        # Compare module against functional:
339        x = torch.rand(1, 1, 4, 4, 4)
340        # without dilation, both symmetric and asymmetric padding
341        module = nn.Conv3d(
342            in_channels=1, out_channels=1, kernel_size=(2, 3, 4), padding="same"
343        )
344        expect = F.conv3d(x, module.weight, module.bias, padding="same")
345        self.assertEqual(expect, module(x))
346
347        # with dilation, both symmetric and asymmetric padding
348        module = nn.Conv3d(
349            in_channels=1,
350            out_channels=1,
351            kernel_size=(2, 3, 4),
352            padding="same",
353            dilation=(3, 2, 1),
354        )
355        expect = F.conv3d(
356            x, module.weight, module.bias, padding="same", dilation=(3, 2, 1)
357        )
358        self.assertEqual(expect, module(x))
359
360        # Test non-zero padding_mode, requiring explicit padding
361        module = nn.Conv3d(
362            in_channels=1,
363            out_channels=1,
364            kernel_size=(2, 3, 4),
365            padding="same",
366            padding_mode="circular",
367        )
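        # F.pad lists padding from the last dim inward: (1, 2) for W (kernel 4), (1, 1) for H (kernel 3), (0, 1) for D (kernel 2)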
368        x_padded = F.pad(x, [1, 2, 1, 1, 0, 1], mode="circular")
369        expect = F.conv3d(x_padded, module.weight, module.bias, padding="valid")
370        self.assertEqual(expect, module(x))
371        self.assertEqual(x.size(), expect.size())
372
373        # Test construction with invalid padding string raises
374        with self.assertRaisesRegex(ValueError, "Invalid padding string"):
375            module = nn.Conv3d(
376                in_channels=3, out_channels=33, kernel_size=10, padding="foo"
377            )
378
379        # Test construction with same padding and strides raises
380        with self.assertRaisesRegex(ValueError, "padding='same'"):
381            module = nn.Conv3d(
382                in_channels=3, out_channels=33, kernel_size=10, padding="same", stride=2
383            )
384        with self.assertRaisesRegex(ValueError, "padding='same'"):
385            module = nn.Conv3d(
386                in_channels=3,
387                out_channels=33,
388                kernel_size=10,
389                padding="same",
390                stride=(1, 1, 3),
391            )
392        with self.assertRaisesRegex(ValueError, "padding='same'"):
393            module = nn.Conv3d(
394                in_channels=3,
395                out_channels=33,
396                kernel_size=10,
397                padding="same",
398                stride=(1, 4, 1),
399            )
400        with self.assertRaisesRegex(ValueError, "padding='same'"):
401            module = nn.Conv3d(
402                in_channels=3,
403                out_channels=33,
404                kernel_size=10,
405                padding="same",
406                stride=(5, 1, 1),
407            )
408
409    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
410    def test_thnn_conv_strided_padded_dilated(self):
411        for convfn, dims, transposed in (
412            (torch.nn.functional.conv2d, 2, False),
413            (torch.nn.functional.conv_transpose2d, 2, True),
414            (torch.nn.functional.conv3d, 3, False),
415            (torch.nn.functional.conv_transpose3d, 3, True),
416        ):
417            for stride, padding, dilation in (
418                (2, 0, 1),
419                (1, 1, 1),
420                (2, 1, 1),
421                (1, 0, 2),
422            ):
423                kwargs = {"stride": stride, "padding": padding, "dilation": dilation}
424                inp_shape = (1, 2) + dims * (4,)
425                weight_shape = (2, 2) + dims * (1,)
426                inputs = torch.randn(
427                    inp_shape, dtype=torch.double, device="cuda", requires_grad=True
428                )
429                weight = torch.randn(
430                    weight_shape, dtype=torch.double, device="cuda", requires_grad=True
431                )
432                bias = torch.randn(
433                    2, dtype=torch.double, device="cuda", requires_grad=True
434                )
435                with torch.backends.cudnn.flags(enabled=False):
436                    res = convfn(inputs, weight, bias, **kwargs)
437                res_cpu = convfn(inputs.cpu(), weight.cpu(), bias.cpu(), **kwargs)
438                self.assertEqual(res, res_cpu)
439                with torch.backends.cudnn.flags(enabled=False):
440                    torch.autograd.gradcheck(
441                        lambda x, w, b: convfn(x, w, b, **kwargs),
442                        (inputs, weight, bias),
443                    )
444                    torch.autograd.gradcheck(
445                        lambda x, w, b: convfn(x, w, b, **kwargs),
446                        (inputs.cpu(), weight.cpu(), bias.cpu()),
447                    )
448
449    def test_Conv2d_inconsistent_types(self):
450        inputs = torch.randn(4, 1, 7, 7, dtype=torch.float)
451        weights = torch.randn(1, 1, 3, 3, dtype=torch.double)
452        # inconsistent types should raise an exception
453        self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights))
454        # but it should work with the same type
455        nn.functional.conv2d(inputs.float(), weights.float())
456
457    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
458    def test_Conv2d_inconsistent_types_on_GPU_without_cudnn(self):
459        inputs = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda")
460        weights = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda")
461        bias = torch.randn(1, dtype=torch.double, device="cuda")
462
463        with torch.backends.cudnn.flags(enabled=False):
464            # inconsistent types should raise an exception
465            self.assertRaises(
466                RuntimeError, lambda: nn.functional.conv2d(inputs, weights)
467            )
468            self.assertRaises(
469                RuntimeError,
470                lambda: nn.functional.conv2d(inputs, weights.float(), bias),
471            )
472
473            # but it should work with the same type
474            nn.functional.conv2d(inputs.float(), weights.float(), bias.float())
475
476    def test_Conv2d_1x1(self):
477        in_channels = 2
478        out_channels = 2
479        mod = torch.nn.Conv2d(2, 2, 1, bias=False).to(dtype=torch.double)
480        input = torch.randn(
481            1, in_channels, 5, 5, requires_grad=True, dtype=torch.double
482        )
483        for enabled in (False, True):
484            with torch.backends.mkldnn.flags(enabled=enabled):
485                gradcheck(F.conv2d, (input, mod.weight))
486
487    def test_Conv2d_OneDNN(self):
488        def run_once(group_val=24, dilation=1):
489            ifm = torch.ones([1, group_val, 6, 6], dtype=torch.float32)
490            weights = torch.ones([group_val, 1, 3, 3], dtype=torch.float32)
491            op = torch.nn.Conv2d(
492                in_channels=group_val,
493                out_channels=group_val,
494                kernel_size=[3, 3],
495                stride=[2, 2],
496                padding=[1, 1],
497                dilation=[dilation, dilation],
498                groups=group_val,
499                bias=False,
500                padding_mode="zeros",
501            )
502
503            op.weight.data = weights
504            res = op(ifm)
505            grad_in = torch.ones(res.shape, dtype=torch.float32)
506            res.backward(grad_in)
507            return op.weight.grad
508
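        # Depthwise convolution (groups == in_channels == out_channels): the weight
        # gradient must match with oneDNN enabled and disabled.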
509        for group_val in (24, 48, 23, 25):
510            for dilation in (1, 2):
511                with torch.backends.mkldnn.flags(enabled=False):
512                    without_onednn = run_once(group_val, dilation)
513
514                with torch.backends.mkldnn.flags(enabled=True):
515                    with_onednn = run_once(group_val, dilation)
516
517                self.assertEqual(without_onednn, with_onednn)
518
519    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
520    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
521    def test_cudnn_non_contiguous(self):
522        x = torch.randn(192, 16, 50).cuda()
523        x = x.permute(0, 2, 1).contiguous().permute(0, 2, 1)
524        m = torch.nn.Conv1d(
525            in_channels=16, out_channels=32, kernel_size=2, bias=True
526        ).cuda()
527        result = m(x)
528
529    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
530    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
531    def test_cudnn_not_mutate_stride(self):
532        weight = torch.randn(64, 64, 1, 1)
533        x = torch.randn(2, 64, 10, 10).to(memory_format=torch.channels_last)
534        weight_stride = weight.stride()
535
536        def conv(x, weight):
537            return torch.convolution(
538                x,
539                weight,
540                stride=(1, 1),
541                padding=(0, 0),
542                dilation=(1, 1),
543                transposed=False,
544                output_padding=(0, 0),
545                groups=1,
546                bias=None,
547            )
548
549        # should have run in nhwc without mutating input strides
550        out_nhwc = conv(x, weight)
551        self.assertEqual(weight.stride(), weight_stride)
552        self.assertTrue(out_nhwc.is_contiguous(memory_format=torch.channels_last))
553
554        x = x.contiguous(memory_format=torch.contiguous_format)
555        out_c = conv(x, weight)
556        self.assertTrue(out_c.is_contiguous(memory_format=torch.contiguous_format))
557        self.assertEqual(out_c, out_nhwc)
558        self.assertEqual(weight.stride(), weight_stride)
559
560    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
561    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
562    def test_Conv2d_inconsistent_types_on_GPU_with_cudnn(self):
563        inputs = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda")
564        weights = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda")
565        bias = torch.randn(1, dtype=torch.double, device="cuda")
566
567        with torch.backends.cudnn.flags(enabled=True):
568            # inconsistent types should raise an exception
569            self.assertRaises(
570                RuntimeError, lambda: nn.functional.conv2d(inputs, weights)
571            )
572            self.assertRaises(
573                RuntimeError,
574                lambda: nn.functional.conv2d(inputs, weights.float(), bias),
575            )
576
577            # but it should work with the same type
578            nn.functional.conv2d(inputs.float(), weights.float(), bias.float())
579
580    def test_Conv2d_missing_argument(self):
581        c = nn.Conv2d(3, 3, 3)
582        self.assertRaises(TypeError, lambda: c(None))
583
584    def test_Conv2d_backward_twice(self):
585        input = torch.randn(2, 3, 5, 5)
586        c = nn.Conv2d(3, 3, 3)
587        o1 = c(input)
588        o1.sum().backward()
589        self.assertRaisesRegex(
590            RuntimeError, "Specify retain_graph=True", lambda: o1.sum().backward()
591        )
592
593    def test_conv_modules_raise_error_on_incorrect_input_size(self):
594        for dtype in [torch.half, torch.bfloat16, torch.double, torch.float]:
595            modules = [
596                nn.Conv1d(3, 8, 3).to(dtype),
597                nn.ConvTranspose1d(3, 8, 3).to(dtype),
598                nn.Conv2d(3, 8, 3).to(dtype),
599                nn.ConvTranspose2d(3, 8, 3).to(dtype),
600                nn.Conv3d(3, 8, 3).to(dtype),
601                nn.ConvTranspose3d(3, 8, 3).to(dtype),
602            ]
603
604            invalid_input_dims = [(1, 4), (1, 4), (2, 5), (2, 5), (3, 6), (3, 6)]
605
606            for invalid_dims, module in zip(invalid_input_dims, modules):
607                for dims in invalid_dims:
608                    input = torch.empty(torch.Size((3,) * dims))
609                    self.assertRaises(RuntimeError, lambda: module(input))
610
611    def test_conv_shapecheck(self):
612        def test(should_raise, module, input_size, dtype):
613            input = torch.empty(3, *input_size).to(dtype)
614            if should_raise:
615                self.assertRaises(RuntimeError, lambda: module(input))
616            else:
617                # just run it to ensure no exception raised.
618                module(input)
619
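        # A configuration should raise exactly when the (padded) input is smaller
        # than the kernel, i.e. the computed output size would be < 1.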
620        for dtype in [
621            torch.half,
622            torch.bfloat16,
623            torch.float,
624            torch.double,
625            torch.cfloat,
626            torch.cdouble,
627        ]:
628            # Conv1d
629            test(True, nn.Conv1d(1, 1, 3).to(dtype), (1, 2), dtype)
630            test(True, nn.Conv1d(1, 1, 3, stride=2).to(dtype), (1, 2), dtype)
631            test(False, nn.Conv1d(1, 1, 2).to(dtype), (1, 2), dtype)
632            test(False, nn.Conv1d(1, 1, 2, stride=2).to(dtype), (1, 2), dtype)
633            test(
634                False, nn.Conv1d(1, 1, 3, stride=2, padding=1).to(dtype), (1, 2), dtype
635            )
636
637            # Conv2d
638            test(True, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 2, 2), dtype)
639            test(False, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 3, 3), dtype)
640            test(False, nn.Conv2d(1, 1, (3, 3), padding=1).to(dtype), (1, 2, 2), dtype)
641
642            # Conv3D
643            test(True, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 2, 2, 2), dtype)
644            test(False, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 3, 3, 3), dtype)
645            test(
646                False,
647                nn.Conv3d(1, 1, (3, 3, 3), padding=1).to(dtype),
648                (1, 2, 2, 2),
649                dtype,
650            )
651
652    def test_ConvTranspose2d_output_size(self):
653        m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2)
654        i = torch.randn(2, 3, 6, 6)
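        # With kernel=3, stride=3, padding=0 and a 6x6 input, valid output sizes are
        # 18-20: (6 - 1) * 3 + 3, plus an implicit output_padding of 0..2 (< stride).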
655        for h in range(15, 22):
656            for w in range(15, 22):
657                if 18 <= h <= 20 and 18 <= w <= 20:
658                    output = m(i, output_size=(h, w))
659                    self.assertEqual(output.size()[2:], (h, w))
660                else:
661                    self.assertRaises(ValueError, lambda: m(i, (h, w)))
662
663    def test_ConvTranspose2d_output_size_downsample_upsample(self):
664        b, c, hid_c = 2, 3, 2
665        for h in range(13, 24):
666            for w in range(13, 17):
667                for k in range(2, 5):
668                    for d in range(1, 5):
669                        for s in range(1, 4):
670                            for p in range(3):
671                                conv = nn.Conv2d(
672                                    in_channels=c,
673                                    out_channels=hid_c,
674                                    kernel_size=k,
675                                    stride=s,
676                                    padding=p,
677                                    dilation=d,
678                                )
679
680                                t_conv = nn.ConvTranspose2d(
681                                    in_channels=hid_c,
682                                    out_channels=c,
683                                    kernel_size=k,
684                                    stride=s,
685                                    padding=p,
686                                    dilation=d,
687                                )
688
689                                i = torch.randn(b, c, h, w)
690
691                                out = t_conv(conv(i), output_size=i.shape)
692
693                                self.assertEqual(out.size()[2:], i.size()[2:])
694
695    def test_ConvTranspose3d_correct_output_size(self):
696        # Check that ConvTranspose3d can take a 5d output_size.
697        m = nn.ConvTranspose3d(2, 2, 2)
698        i = torch.rand(1, 2, 1, 1, 1)
699        out = m(i, output_size=(1, 2, 2, 2, 2))
700
701    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
702    def test_ConvTranspose2d_half_cublas_gemm(self):
703        with torch.backends.cudnn.flags(enabled=False):
704            inputs = torch.randn(1, 1, 16, 16, device="cuda", dtype=torch.half)
705            deconv = (
706                nn.ConvTranspose2d(1, 1, 3, stride=2, padding=1, output_padding=1)
707                .cuda()
708                .half()
709            )
710            output = deconv(inputs)
711            output.mean().backward()
712
713    # For https://github.com/pytorch/pytorch/pull/1273
714    # Almost identical to the above `test_Conv2d_naive_groups`
715    @torch.backends.cudnn.flags(enabled=True, benchmark=False)
716    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
717    def test_Conv2d_groups_nobias(self):
718        dev_dtypes = [("cpu", torch.float)]
719        if TEST_CUDA:
720            dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)]
721        if AMPERE_OR_ROCM:
722            dev_dtypes += [("cuda", torch.bfloat16)]
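        # Split the 2-group convolution into two independent convolutions and check
        # that outputs and gradients match channel-wise.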
723        for device, dtype in dev_dtypes:
724            m = nn.Conv2d(4, 4, kernel_size=3, groups=2, bias=False).to(device, dtype)
725            i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
726            output = m(i)
727            grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype)
728            output.backward(grad_output)
729
730            m1 = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype)
731            m1.weight.data.copy_(m.weight.data[:2])
732            i1 = i.data[:, :2].contiguous().requires_grad_(True)
733            output1 = m1(i1)
734            output1.backward(grad_output[:, :2].contiguous())
735
736            m2 = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype)
737            m2.weight.data.copy_(m.weight.data[2:])
738            i2 = i.data[:, 2:].contiguous().requires_grad_(True)
739            output2 = m2(i2)
740            output2.backward(grad_output[:, 2:].contiguous())
741
742            self.assertEqual(output, torch.cat([output1, output2], 1))
743            self.assertEqual(
744                i.grad.data,
745                torch.cat([i1.grad.data, i2.grad.data], 1),
746                atol=dtype2prec_DONTUSE[dtype],
747                rtol=0,
748            )
749            self.assertEqual(
750                m.weight.grad.data,
751                torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
752                atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype],
753                rtol=0,
754            )
755
756    # Almost identical to the above `test_Conv2d_naive_groups`
757    # Covers the special case where groups > 1, input-channels / groups < 16, and output-channels is a multiple of 16
758    # See also https://github.com/pytorch/pytorch/pull/18463#issuecomment-476563686
759    # and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024
760    @torch.backends.cudnn.flags(enabled=True, benchmark=False)
761    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
762    def test_Conv2d_groups_nobias_v2(self):
763        torch.manual_seed(123)
764        dev_dtypes = [("cpu", torch.float)]
765        if TEST_CUDA:
766            dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)]
767        if AMPERE_OR_ROCM:
768            dev_dtypes += [("cuda", torch.bfloat16)]
769        for device, dtype in dev_dtypes:
770            m = nn.Conv2d(4, 16, kernel_size=3, groups=2, bias=False).to(device, dtype)
771            i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
772            output = m(i)
773            grad_output = torch.randn(2, 16, 4, 4, device=device, dtype=dtype)
774            output.backward(grad_output)
775
776            m1 = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype)
777            m1.weight.data.copy_(m.weight.data[:8])
778            i1 = i.data[:, :2].contiguous().requires_grad_(True)
779            output1 = m1(i1)
780            output1.backward(grad_output[:, :8].contiguous())
781
782            m2 = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype)
783            m2.weight.data.copy_(m.weight.data[8:])
784            i2 = i.data[:, 2:].contiguous().requires_grad_(True)
785            output2 = m2(i2)
786            output2.backward(grad_output[:, 8:].contiguous())
787
788            self.assertEqual(output, torch.cat([output1, output2], 1))
789            self.assertEqual(
790                i.grad.data,
791                torch.cat([i1.grad.data, i2.grad.data], 1),
792                atol=dtype2prec_DONTUSE[dtype],
793                rtol=0,
794            )
795            self.assertEqual(
796                m.weight.grad.data,
797                torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
798                atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype],
799                rtol=0,
800            )
801
802    # CPU-only test for group conv3d fast implementation using bmm
803    # See: https://github.com/pytorch/pytorch/pull/36355
804    def test_Conv3d_groups_nobias(self):
805        torch.manual_seed(123)
806        m = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=False).to("cpu", torch.float)
807        i = torch.randn(
808            2, 4, 6, 6, 6, device="cpu", dtype=torch.float, requires_grad=True
809        )
810        output = m(i)
811        grad_output = torch.randn(2, 16, 4, 4, 4, device="cpu", dtype=torch.float)
812        output.backward(grad_output)
813
814        m1 = nn.Conv3d(2, 8, kernel_size=3, bias=False).to("cpu", torch.float)
815        m1.weight.data.copy_(m.weight.data[:8])
816        i1 = i.data[:, :2].contiguous().requires_grad_(True)
817        output1 = m1(i1)
818        output1.backward(grad_output[:, :8].contiguous())
819
820        m2 = nn.Conv3d(2, 8, kernel_size=3, bias=False).to("cpu", torch.float)
821        m2.weight.data.copy_(m.weight.data[8:])
822        i2 = i.data[:, 2:].contiguous().requires_grad_(True)
823        output2 = m2(i2)
824        output2.backward(grad_output[:, 8:].contiguous())
825
826        self.assertEqual(output, torch.cat([output1, output2], 1))
827        self.assertEqual(
828            i.grad.data,
829            torch.cat([i1.grad.data, i2.grad.data], 1),
830            atol=dtype2prec_DONTUSE[torch.float],
831            rtol=0,
832        )
833        self.assertEqual(
834            m.weight.grad.data,
835            torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
836            atol=dtype2prec_DONTUSE[torch.float],
837            rtol=dtype2prec_DONTUSE[torch.float],
838        )
839
840    def test_Conv3d_groups_wbias(self):
841        torch.manual_seed(123)
842        m = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=True).to("cpu", torch.float)
843        i = torch.randn(
844            2, 4, 6, 6, 6, device="cpu", dtype=torch.float, requires_grad=True
845        )
846        output = m(i)
847        grad_output = torch.randn(2, 16, 4, 4, 4, device="cpu", dtype=torch.float)
848        output.backward(grad_output)
849
850        m1 = nn.Conv3d(2, 8, kernel_size=3, bias=True).to("cpu", torch.float)
851        m1.weight.data.copy_(m.weight.data[:8])
852        m1.bias.data.copy_(m.bias.data[:8])
853        i1 = i.data[:, :2].contiguous().requires_grad_(True)
854        output1 = m1(i1)
855        output1.backward(grad_output[:, :8].contiguous())
856
857        m2 = nn.Conv3d(2, 8, kernel_size=3, bias=True).to("cpu", torch.float)
858        m2.weight.data.copy_(m.weight.data[8:])
859        m2.bias.data.copy_(m.bias.data[8:])
860        i2 = i.data[:, 2:].contiguous().requires_grad_(True)
861        output2 = m2(i2)
862        output2.backward(grad_output[:, 8:].contiguous())
863
864        self.assertEqual(output, torch.cat([output1, output2], 1))
865        self.assertEqual(
866            i.grad.data,
867            torch.cat([i1.grad.data, i2.grad.data], 1),
868            atol=dtype2prec_DONTUSE[torch.float],
869            rtol=dtype2prec_DONTUSE[torch.float],
870        )
871        self.assertEqual(
872            m.weight.grad.data,
873            torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
874            atol=dtype2prec_DONTUSE[torch.float],
875            rtol=dtype2prec_DONTUSE[torch.float],
876        )
877        self.assertEqual(
878            m.bias.grad.data,
879            torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
880            atol=dtype2prec_DONTUSE[torch.float],
881            rtol=dtype2prec_DONTUSE[torch.float],
882        )
883
884    def test_conv_tbc(self):
885        with set_default_dtype(torch.double):
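            # conv_tbc takes (time, batch, channels) input and (kernel_width, in_channels,
            # out_channels) weight; the trailing argument is the padding.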
886            inp = torch.randn(9, 4, 5, requires_grad=True)
887            weight = torch.randn(3, 5, 6, requires_grad=True)
888            bias = torch.randn(6, requires_grad=True)
889
890            gradcheck(
891                lambda i, w, b, pad: F.conv_tbc(i, w, b, pad), (inp, weight, bias, 3)
892            )
893
894    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
895    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
896    @skipIfRocmVersionLessThan((4, 3))
897    @skipIfNotMiopenSuggestNHWC
898    def test_grouped_conv_cudnn_nhwc_support(self):
899        # in order to catch holes in grouped convolution NHWC support for earlier cuDNN versions
900        input = torch.randn((16, 16, 8, 8), dtype=torch.float16, device="cuda").to(
901            memory_format=torch.channels_last
902        )
903        weight = torch.randn((8, 4, 3, 3), dtype=torch.float16, device="cuda").to(
904            memory_format=torch.channels_last
905        )
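        # torch.convolution positional args after input/weight: bias, stride, padding,
        # dilation, transposed, output_padding, groups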
906        out = torch.convolution(
907            input, weight, None, (1, 1), (1, 1), (1, 1), False, (0, 0), 4
908        )
909        input = torch.randn((16, 8, 8, 8), dtype=torch.float16, device="cuda").to(
910            memory_format=torch.channels_last
911        )
912        out_transpose = torch.convolution(
913            input, weight, None, (1, 1), (1, 1), (1, 1), True, (0, 0), 4
914        )
915
916    @unittest.expectedFailure
917    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
918    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
919    def test_conv_cudnn_memory_layout_dominance(self):
920        # The desired behavior here is for the memory layout of conv.weight to
921        # dominate the layout of the output.
922        # This is not the current behavior; it will be fixed in
923        # follow-up PRs, at which point the `expectedFailure` tag can be removed.
924        input = torch.randint(
925            1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True
926        )
927        conv = nn.Conv2d(8, 4, 3).cuda().float()
928
929        out = conv(input)
930        self.assertTrue(out.is_contiguous())
931
932        input = input.contiguous(memory_format=torch.channels_last)
933        out = conv(input)
934        self.assertTrue(out.is_contiguous())
935
936        conv.weight.data = conv.weight.contiguous(memory_format=torch.channels_last)
937        out = conv(input)
938        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
939
940        input = input.contiguous()
941        out = conv(input)
942        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
943
944    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
945    def test_cudnn_noncontiguous_weight(self):
946        # Noncontiguous weights must be contiguous() before being
947        # passed to cuDNN
948        input = torch.tensor([1, 1, 1], dtype=torch.double, device="cuda").view(1, 1, 3)
949        weights1 = torch.tensor([1], dtype=torch.double, device="cuda").expand(1, 1, 2)
950        weights2 = (
951            torch.tensor([1], dtype=torch.double, device="cuda")
952            .expand(1, 1, 2)
953            .contiguous()
954        )
955        self.assertEqual(
956            F.conv1d(input, weights1, bias=None, stride=2, dilation=2),
957            F.conv1d(input, weights2, bias=None, stride=2, dilation=2),
958        )
959
960    def run_grad_conv_test(self, func_forward, func_backward, dim=1, gradient="input"):
961        for kern, inp_size in [(3, 6), (3, 7), (4, 9)]:
962            for batch, stride, padding, chan_in, chan_out, dilation in product(
963                [1, 2], [1, 2], [0, 1, 2], [2], [3], [1]
964            ):
965                for has_bias in [True, False]:
966                    input_shape = [batch, chan_in]
967                    weight_shape = [chan_out, chan_in]
968                    for _ in range(dim):
969                        input_shape.append(inp_size)
970                        weight_shape.append(kern)
971
972                    input = torch.randn(input_shape, requires_grad=True)
973                    weight = torch.randn(weight_shape, requires_grad=True)
974                    # Use bias=None when has_bias is False instead of reusing a stale tensor
975                    bias = torch.randn([chan_out], requires_grad=True) if has_bias else None
976                    output = func_forward(
977                        input,
978                        weight,
979                        stride=stride,
980                        padding=padding,
981                        dilation=dilation,
982                        bias=bias,
983                    )
984
985                    gradient_o = torch.randn(output.shape)
986                    gradient_w = torch.autograd.grad(
987                        output, input if (gradient == "input") else weight, gradient_o
988                    )
989
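                    # The functional gradient (torch.nn.grad.convNd_input / convNd_weight)
                    # must match the gradient computed by autograd.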
990                    self.assertEqual(
991                        gradient_w[0],
992                        func_backward(
993                            input_shape if (gradient == "input") else input,
994                            weight_shape if (gradient == "weight") else weight,
995                            gradient_o,
996                            stride=stride,
997                            padding=padding,
998                            dilation=dilation,
999                        ),
1000                    )
1001
1002    def test_grad_conv1d_input(self):
1003        self.run_grad_conv_test(F.conv1d, F.grad.conv1d_input, 1, "input")
1004
1005    def test_grad_conv1d_weight(self):
1006        self.run_grad_conv_test(F.conv1d, F.grad.conv1d_weight, 1, "weight")
1007
1008    def test_grad_conv2d_input(self):
1009        self.run_grad_conv_test(F.conv2d, F.grad.conv2d_input, 2, "input")
1010
1011    def test_grad_conv2d_weight(self):
1012        self.run_grad_conv_test(F.conv2d, F.grad.conv2d_weight, 2, "weight")
1013
1014    def test_grad_conv3d_input(self):
1015        self.run_grad_conv_test(F.conv3d, F.grad.conv3d_input, 3, "input")
1016
1017    def test_grad_conv3d_weight(self):
1018        self.run_grad_conv_test(F.conv3d, F.grad.conv3d_weight, 3, "weight")
1019
1020    @unittest.skipIf(not torch._nnpack_available(), "NNPACK unavailable")
1021    def test_nnpack_conv(self):
1022        for kern, inp_size in [(3, 6), (3, 7), (4, 9)]:
1023            for batch, stride, padding, chan_in, chan_out in product(
1024                [1, 2, 3, 4], [1, 2], [0, 1, 2], [2], [3]
1025            ):
1026                for has_bias in [True, False]:
1027                    input_shape = [batch, chan_in]
1028                    weight_shape = [chan_out, chan_in]
1029                    for _ in range(2):
1030                        input_shape.append(inp_size)
1031                        weight_shape.append(kern)
1032
1033                    input = torch.randn(
1034                        input_shape, requires_grad=True, dtype=torch.float
1035                    )
1036                    weight = torch.randn(
1037                        weight_shape, requires_grad=True, dtype=torch.float
1038                    )
1039                    # Use bias=None when has_bias is False instead of reusing a stale tensor
1040                    bias = None
1041                    if has_bias:
1042                        bias = torch.randn([chan_out], requires_grad=True, dtype=torch.float)
1043                    output = torch._nnpack_spatial_convolution(
1044                        input, weight, stride=stride, padding=padding, bias=bias
1045                    )
1046                    output_expected = torch.nn.functional.conv2d(
1047                        input, weight, stride=stride, padding=padding, bias=bias
1048                    )
1049                    self.assertEqual(output, output_expected, atol=3e-4, rtol=0)
1050
1051                    gradient_o = torch.randn(output.shape, dtype=torch.float)
1052
1053                    grads = torch.autograd.grad(output, [input, weight], gradient_o)
1054                    grads_expected = torch.autograd.grad(
1055                        output_expected, [input, weight], gradient_o
1056                    )
1057                    for gr, gr_expected in zip(grads, grads_expected):
1058                        self.assertEqual(gr, gr_expected, atol=3e-4, rtol=0)
1059
1060    def test_conv_padding_mode(self):
1061        with self.assertRaisesRegex(ValueError, "padding_mode must be one of"):
1062            nn.Conv2d(3, 3, 3, padding_mode="xyz")
1063
1064        with self.assertRaisesRegex(ValueError, "padding_mode must be one of"):
1065            nn.Conv2d(3, 3, 3, padding_mode=3)
1066
1067        with self.assertRaisesRegex(ValueError, 'Only "zeros" '):
1068            nn.ConvTranspose2d(3, 3, 3, padding_mode="reflect")
1069
1070    def test_functional_grad_conv(self):
1071        # Conv 1D
1072        input = torch.randn(1, 1, 5, requires_grad=True)
1073        weight = torch.randn(1, 1, 3, requires_grad=True)
1074        output = F.conv1d(input, weight, dilation=2)
1075        grad_output = torch.randn(output.shape)
1076
1077        grad_input_autograd, grad_weight_autograd = torch.autograd.grad(
1078            output, (input, weight), grad_output
1079        )
1080
1081        grad_input_functional = torch.nn.grad.conv1d_input(
1082            input.shape, weight, grad_output, dilation=2
1083        )
1084        self.assertEqual(grad_input_functional, grad_input_autograd)
1085
1086        grad_weight_functional = torch.nn.grad.conv1d_weight(
1087            input, weight.shape, grad_output, dilation=2
1088        )
1089        self.assertEqual(grad_weight_functional, grad_weight_autograd)
1090
1091        # Conv 2D
1092        input = torch.randn(1, 1, 5, 5, requires_grad=True)
1093        weight = torch.randn(1, 1, 3, 3, requires_grad=True)
1094        output = F.conv2d(input, weight, dilation=2)
1095        grad_output = torch.randn(output.shape)
1096
1097        (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(
1098            output, (input, weight), grad_output
1099        )
1100
1101        grad_input_functional = torch.nn.grad.conv2d_input(
1102            input.shape, weight, grad_output, dilation=2
1103        )
1104        self.assertEqual(grad_input_functional, grad_input_autograd)
1105
1106        grad_weight_functional = torch.nn.grad.conv2d_weight(
1107            input, weight.shape, grad_output, dilation=2
1108        )
1109        self.assertEqual(grad_weight_functional, grad_weight_autograd)
1110
1111        # Conv 3D
1112        input = torch.randn(1, 1, 5, 5, 5, requires_grad=True)
1113        weight = torch.randn(1, 1, 3, 3, 3, requires_grad=True)
1114        output = F.conv3d(input, weight, dilation=2)
1115        grad_output = torch.randn(output.shape)
1116
1117        (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(
1118            output, (input, weight), grad_output
1119        )
1120
1121        grad_input_functional = torch.nn.grad.conv3d_input(
1122            input.shape, weight, grad_output, dilation=2
1123        )
1124        self.assertEqual(grad_input_functional, grad_input_autograd)
1125
1126        grad_weight_functional = torch.nn.grad.conv3d_weight(
1127            input, weight.shape, grad_output, dilation=2
1128        )
1129        self.assertEqual(grad_weight_functional, grad_weight_autograd)
1130
1131    def test_functional_grad_conv2d(self):
1132        BATCH_SIZE = 4
1133        IN_CH = 8
1134        OUT_CH = 16
1135        SPATIAL = 32
1136
1137        def _test_conv2d(stride, kernel_size, groups, dilation):
1138            padding = kernel_size // 2
1139
1140            input = (
1141                torch.empty(BATCH_SIZE, IN_CH, SPATIAL, SPATIAL)
1142                .uniform_(-8.0, 8.0)
1143                .requires_grad_(True)
1144            )
1145
1146            weight = (
1147                torch.empty(OUT_CH, IN_CH // groups, kernel_size, kernel_size)
1148                .uniform_(-4.0, 4.0)
1149                .requires_grad_(True)
1150            )
1151
1152            output = F.conv2d(
1153                input,
1154                weight,
1155                stride=stride,
1156                padding=padding,
1157                dilation=dilation,
1158                groups=groups,
1159            )
1160
1161            grad_output = torch.randn(output.shape)
1162
1163            (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(
1164                output, (input, weight), grad_output
1165            )
1166
1167            grad_input_functional = torch.nn.grad.conv2d_input(
1168                input.shape,
1169                weight,
1170                grad_output,
1171                stride=stride,
1172                padding=padding,
1173                dilation=dilation,
1174                groups=groups,
1175            )
1176            self.assertEqual(grad_input_functional, grad_input_autograd)
1177
1178            grad_weight_functional = torch.nn.grad.conv2d_weight(
1179                input,
1180                weight.shape,
1181                grad_output,
1182                stride=stride,
1183                padding=padding,
1184                dilation=dilation,
1185                groups=groups,
1186            )
1187            self.assertEqual(grad_weight_functional, grad_weight_autograd)
1188
1189        strides = [1, 2]
1190        kernel_sizes = [1, 3, 5]
1191        groups = [1, 2, 4]
1192        dilates = [1, 2]
1193
1194        for s, k, g, d in product(strides, kernel_sizes, groups, dilates):
1195            _test_conv2d(s, k, g, d)
1196
1197    def test_permute_conv2d_issue_120211(self):
1198        def reproducer(radius: int):
1199            image = torch.rand(1, 1024, 1024, 3)
1200            image = image.permute(0, 3, 1, 2)
1201            kernel_x = torch.zeros([3, 1, 1, radius * 2 + 1], device=image.device)
1202            image = torch.nn.functional.conv2d(image, kernel_x, groups=image.shape[-3])
1203
1204        for i in range(0, 128):
1205            # This should not fail
1206            reproducer(radius=i)
1207
1208    def test_conv3d_issue_120406(self):
1209        # This should not fail
1210        F.conv3d(torch.ones(2, 3, 8, 9, 26), torch.ones(3, 1, 1, 1, 17), groups=3)
1211
1212    def test_conv1d_issue_120547(self):
1213        weight = torch.ones([16, 1, 32])
1214        bias = torch.ones([16])
1215        stride, padding, dilation, groups = (1, 16, 1, 16)
1216        input = torch.rand((1, 1, 16))
1217        input = input.transpose(1, 2)
1218        # This should not fail
1219        F.conv1d(input, weight, bias, stride, padding, dilation, groups)
1220
1221
1222class TestConvolutionNNDeviceType(NNTestCase):
1223    def run_conv_double_back_test(
1224        self,
1225        kern,
1226        stride,
1227        padding,
1228        chan_in,
1229        chan_out,
1230        batch_size,
1231        inp_size,
1232        dilation,
1233        no_weight,
1234        groups=1,
1235        use_cuda=False,
1236        use_bias=True,
1237        dtype=torch.double,
1238    ):
1239        if use_cuda:
1240            device = torch.device("cuda")
1241        else:
1242            device = torch.device("cpu")
1243
1244        x = torch.randn(
1245            batch_size,
1246            chan_in,
1247            inp_size,
1248            inp_size,
1249            device=device,
1250            dtype=dtype,
1251            requires_grad=True,
1252        )
1253        weight = torch.randn(
1254            chan_out,
1255            chan_in // groups,
1256            kern,
1257            kern,
1258            device=device,
1259            dtype=dtype,
1260            requires_grad=not no_weight,
1261        )
1262        if use_bias:
1263            bias = torch.randn(chan_out, device=device, dtype=dtype, requires_grad=True)
1264        else:
1265            bias = None
1266
1267        def func(*inputs):
1268            if use_bias:
1269                lx, lweight, lbias = inputs
1270            else:
1271                lx, lweight = inputs
1272                lbias = None
1273            # We disable cudnn during forward to avoid finite difference imprecision issues
1274            with cudnn.flags(enabled=False):
1275                out = F.conv2d(lx, lweight, lbias, stride, padding, dilation, groups)
1276            return out
1277
1278        if use_bias:
1279            inputs = x, weight, bias
1280        else:
1281            inputs = x, weight
1282
1283        dummy_out = func(*inputs)
1284        grad_y = torch.randn_like(
1285            dummy_out, device=device, dtype=dtype, requires_grad=True
1286        )
1287
1288        # Issue #15353: test mkldnn double backward, don't run gradgradcheck due
1289        # to imprecision issues
1290        if dtype == torch.float:
1291            (g,) = torch.autograd.grad(dummy_out.sum(), x, create_graph=True)
1292            return g.requires_grad
1293
1294        return gradgradcheck(func, inputs, (grad_y,))
1295
1296    @onlyCUDA
1297    @skipCUDAIfNoCudnn
1298    @dtypes(
1299        *floating_and_complex_types_and(
1300            torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []
1301        )
1302    )
1303    def test_Conv2d_deterministic_cudnn(self, device, dtype):
1304        inputs = torch.randn(2, 3, 5, 5, device=device, dtype=dtype, requires_grad=True)
1305        with cudnn.flags(enabled=True, benchmark=True, deterministic=True):
1306            conv1 = torch.nn.Conv2d(3, 3, 3).to(device, dtype)
1307            conv2 = torch.nn.Conv2d(3, 3, 3).to(device, dtype)
1308            conv2.bias.data.copy_(conv1.bias.data)
1309            conv2.weight.data.copy_(conv1.weight.data)
1310            out1 = conv1(inputs)
1311            out2 = conv2(inputs)
1312            self.assertEqual(out1, out2, atol=0.0, rtol=0)
1313            y = torch.randn(out1.size(), device=device, dtype=dtype)
1314            out1.backward(y)
1315            out2.backward(y)
1316            self.assertEqual(
1317                conv1.bias.grad.data, conv2.bias.grad.data, atol=0.0, rtol=0
1318            )
1319            self.assertEqual(
1320                conv1.weight.grad.data, conv2.weight.grad.data, atol=0.0, rtol=0
1321            )
1322
1323    @onlyCUDA
1324    @dtypes(
1325        *floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])
1326    )
1327    def test_Conv2d_large_workspace(self, device, dtype):
1328        # These sizes require huge cuDNN workspaces. Make sure we choose a
1329        # reasonable algorithm that does not run out of memory
1330        sizes = [
1331            (1, 256, 109, 175),
1332            (1, 256, 80, 128),
1333            (1, 256, 120, 192),
1334        ]
1335
1336        def run_test(benchmark):
1337            with torch.backends.cudnn.flags(enabled=True, benchmark=benchmark):
1338                conv = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1).to(
1339                    device, dtype
1340                )
1341                for size in sizes:
1342                    x = torch.randn(size, device=device, dtype=dtype)
1343                    out = conv(x.detach().clone().requires_grad_())
1344                    out.backward(torch.ones_like(out))
1345
1346        run_test(benchmark=False)
1347        run_test(benchmark=True)
1348
1349    @onlyCUDA
1350    @dtypes(torch.half, torch.float)
1351    def test_ConvTranspose2d_large_output_padding(self, device, dtype):
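        # Each stride-2 ConvTranspose2d with padding=1 and output_padding=1 doubles
        # the spatial size ((H - 1) * 2 - 2 + 3 + 1 = 2 * H), so 6x6 grows to 48x48.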
1352        net1 = torch.nn.ConvTranspose2d(
1353            128, 64, kernel_size=3, stride=2, padding=1, output_padding=1
1354        ).to(device=device, dtype=dtype)
1355        net2 = torch.nn.ConvTranspose2d(
1356            64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
1357        ).to(device=device, dtype=dtype)
1358        net3 = torch.nn.ConvTranspose2d(
1359            32, 3, kernel_size=3, stride=2, padding=1, output_padding=1
1360        ).to(device=device, dtype=dtype)
1361        x = torch.rand(1, 128, 6, 6, device=device, dtype=dtype, requires_grad=True)
1362        x = net1(x)
1363        x = net2(x)
1364        x = net3(x)
1365        x.backward(torch.randn_like(x))
1366        torch.cuda.synchronize()
1367
1368    @onlyCUDA
1369    @dtypes(torch.float, torch.double, torch.half)
1370    # Very similar to test_Conv2d_naive_groups, but with special care for the
1371    # case where the number of groups equals the number of input channels
1372    @torch.backends.cudnn.flags(enabled=True, benchmark=False)
1373    @tf32_on_and_off(0.01)
1374    def test_Conv2d_depthwise_naive_groups(self, device, dtype):
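        # Split a groups=2 convolution into two independent single-group convolutions
        # built from slices of its weight and bias; outputs, input gradients, and
        # parameter gradients must all agree.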
1375        for depth_multiplier in [1, 2]:
1376            m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(
1377                device, dtype
1378            )
1379            i = (
1380                torch.randn(2, 2, 6, 6, device=device, dtype=dtype)
1381                .div_(2)
1382                .requires_grad_()
1383            )
1384            output = m(i)
1385            grad_output = (
1386                torch.randn(2, 2 * depth_multiplier, 4, 4, device=device, dtype=dtype)
1387                / 2
1388            )
1389            output.backward(grad_output)
1390
1391            offset = 1 * depth_multiplier
1392
1393            m1 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
1394            m1.weight.data = m.weight.data[:offset].clone()
1395            m1.bias.data = m.bias.data[:offset].clone()
1396            i1 = i.detach()[:, :1].clone().requires_grad_()
1397            output1 = m1(i1)
1398            output1.backward(grad_output[:, :offset].contiguous())
1399
1400            m2 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
1401            m2.weight.data.copy_(m.weight.data[offset:])
1402            m2.bias.data.copy_(m.bias.data[offset:])
1403            i2 = i.detach()[:, 1:].clone().requires_grad_()
1404            output2 = m2(i2)
1405            output2.backward(grad_output[:, offset:].contiguous())
1406
1407            self.assertEqual(
1408                output,
1409                torch.cat([output1, output2], 1),
1410                atol=dtype2prec_DONTUSE[dtype],
1411                rtol=0,
1412            )
1413            self.assertEqual(
1414                i.grad.data,
1415                torch.cat([i1.grad.data, i2.grad.data], 1),
1416                atol=dtype2prec_DONTUSE[dtype],
1417                rtol=0,
1418            )
1419            self.assertEqual(
1420                m.bias.grad.data,
1421                torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
1422                atol=dtype2prec_DONTUSE[dtype],
1423                rtol=0,
1424            )
1425            self.assertEqual(
1426                m.weight.grad.data,
1427                torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
1428                atol=dtype2prec_DONTUSE[dtype],
1429                rtol=0,
1430            )
1431
1432    @onlyCUDA
1433    @dtypes(torch.float, torch.double, torch.half)
1434    @torch.backends.cudnn.flags(enabled=True, benchmark=False)
1435    @tf32_on_and_off(0.01)
1436    def test_Conv3d_depthwise_naive_groups(self, device, dtype):
1437        for depth_multiplier in [1, 2]:
1438            m = nn.Conv3d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(
1439                device, dtype
1440            )
1441            i = (
1442                torch.randn(2, 2, 6, 6, 6, device=device, dtype=dtype)
1443                .div_(2)
1444                .requires_grad_()
1445            )
1446            output = m(i)
1447            grad_output = (
1448                torch.randn(
1449                    2, 2 * depth_multiplier, 4, 4, 4, device=device, dtype=dtype
1450                )
1451                / 2
1452            )
1453            output.backward(grad_output)
1454
1455            offset = 1 * depth_multiplier
1456
1457            m1 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
1458            m1.weight.data = m.weight.data[:offset].clone()
1459            m1.bias.data = m.bias.data[:offset].clone()
1460            i1 = i.detach()[:, :1].clone().requires_grad_()
1461            output1 = m1(i1)
1462            output1.backward(grad_output[:, :offset].contiguous())
1463
1464            m2 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
1465            m2.weight.data.copy_(m.weight.data[offset:])
1466            m2.bias.data.copy_(m.bias.data[offset:])
1467            i2 = i.detach()[:, 1:].clone().requires_grad_()
1468            output2 = m2(i2)
1469            output2.backward(grad_output[:, offset:].contiguous())
1470            is_cuda_sm86 = device.startswith(
1471                "cuda"
1472            ) and torch.cuda.get_device_capability(0) == (8, 6)
1473            atol, rtol = (
1474                (3e-4, 3e-2)
1475                if dtype == torch.float32 and is_cuda_sm86
1476                else (dtype2prec_DONTUSE[dtype], 0)
1477            )
1478
1479            self.assertEqual(
1480                output, torch.cat([output1, output2], 1), atol=atol, rtol=rtol
1481            )
1482            self.assertEqual(
1483                i.grad.data,
1484                torch.cat([i1.grad.data, i2.grad.data], 1),
1485                atol=dtype2prec_DONTUSE[dtype],
1486                rtol=0,
1487            )
1488            self.assertEqual(
1489                m.bias.grad.data,
1490                torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
1491                atol=dtype2prec_DONTUSE[dtype],
1492                rtol=0,
1493            )
1494            self.assertEqual(
1495                m.weight.grad.data,
1496                torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
1497                atol=atol,
1498                rtol=rtol,
1499            )
1500
1501    @onlyCUDA
1502    @dtypes(
1503        *floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])
1504    )
1505    def test_noncontig_conv_grad(self, device, dtype):
1506        # FIXME: remove after adding non-contiguous grad tests for all modules
1507        module = nn.Conv2d(3, 5, kernel_size=3, padding=1).to(device, dtype)
1508        input = torch.randn(
1509            2, 3, 10, 10, dtype=dtype, device=device, requires_grad=True
1510        )
1511        output = module(input)
1512
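        # Taking channel 1 of a (2, 2, 5, 10, 10) tensor gives a non-contiguous
        # (2, 5, 10, 10) view to use as grad_output.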
1513        grad = torch.randn(2, 2, 5, 10, 10, dtype=dtype, device=device)[:, 1]
1514        assert not grad.is_contiguous()
1515        output.backward(grad, retain_graph=True)
1516        self.assertIsNotNone(input.grad)
1517        result = input.grad.data.clone()
1518        input.grad.data.zero_()
1519
1520        output.backward(grad.contiguous())
1521        self.assertEqual(
1522            result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0
1523        )
1524
1525    @onlyCUDA
1526    @dtypes(torch.double)
1527    def test_conv_double_backward(self, device, dtype):
1528        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
1529            # Double backward only runs with DoubleTensor due to precision reasons
1530            batch_size = 1
1531            for kern, inp_size, dilations in [(3, 5, [1, 2]), (4, 9, [1])]:
1532                for stride, padding, chan_in, chan_out, dilation in product(
1533                    [1], [2], [2], [3], dilations
1534                ):
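                    # ggW cannot be provided when stride > 1, so the weight is left
                    # out of the double backward check in that case.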
1535                    no_weight = stride == 2
1536                    result = self.run_conv_double_back_test(
1537                        kern,
1538                        stride,
1539                        padding,
1540                        chan_in,
1541                        chan_out,
1542                        batch_size,
1543                        inp_size,
1544                        dilation,
1545                        no_weight,
1546                        use_cuda=True,
1547                        dtype=dtype,
1548                    )
1549                    self.assertTrue(
1550                        result,
1551                        "Conv double backward test failed with parameters:"
1552                        f"\nkern: {kern}"
1553                        f"\nstride: {stride}"
1554                        f"\npadding: {padding}"
1555                        f"\nchan_in: {chan_in}"
1556                        f"\nchan_out: {chan_out}"
1557                        f"\nbatch_size: {batch_size}"
1558                        f"\ninp_size: {inp_size}"
1559                        f"\ndilation: {dilation}",
1560                    )
1569
1570    def test_conv_double_backward_no_bias(self):
1571        kern = 3
1572        stride = 2
1573        chan_in, chan_out = 2, 4
1574        batch_size = 2
1575        inp_size = 5
1576        padding = 1
1577        dilation = 1
1578        no_weight = False
1579        use_bias = False
1580        result = self.run_conv_double_back_test(
1581            kern,
1582            stride,
1583            padding,
1584            chan_in,
1585            chan_out,
1586            batch_size,
1587            inp_size,
1588            dilation,
1589            no_weight,
1590            use_bias=use_bias,
1591        )
1592        self.assertTrue(
1593            result,
1594            "Conv double backward test failed with parameters:"
1595            f"\nkern: {kern}"
1596            f"\nstride: {stride}"
1597            f"\npadding: {padding}"
1598            f"\nchan_in: {chan_in}"
1599            f"\nchan_out: {chan_out}"
1600            f"\nbatch_size: {batch_size}"
1601            f"\ninp_size: {inp_size}"
1602            f"\ndilation: {dilation}",
1603        )
1612
1613    def test_conv_double_backward_groups(self):
1614        kern = 3
1615        stride = 1
1616        padding = 2
1617        chan_in, chan_out = 2, 4
1618        batch_size = 2
1619        inp_size = 6
1620        dilation = 1
1621        no_weight = False
1622        groups = 2
1623        result = self.run_conv_double_back_test(
1624            kern,
1625            stride,
1626            padding,
1627            chan_in * groups,
1628            chan_out * groups,
1629            batch_size,
1630            inp_size,
1631            dilation,
1632            no_weight,
1633            groups=groups,
1634        )
1635        self.assertTrue(
1636            result,
1637            "Conv double backward test failed with parameters:"
1638            f"\nkern: {kern}"
1639            f"\nstride: {stride}"
1640            f"\npadding: {padding}"
1641            f"\nchan_in: {chan_in}"
1642            f"\nchan_out: {chan_out}"
1643            f"\nbatch_size: {batch_size}"
1644            f"\ninp_size: {inp_size}"
1645            f"\ndilation: {dilation}"
1646            f"\ngroups: {groups}",
1647        )
1657
1658    def test_conv_double_backward_stride(self):
1659        batch_size = 2
1660
1661        # Cannot provide ggW when stride is > 1
1662        for kern, inp_size, dilations in [(3, 5, [1, 2]), (3, 7, [1])]:
1663            for stride, padding, chan_in, chan_out, dilation in product(
1664                [2], [0, 1], [1], [2], dilations
1665            ):
1666                no_weight = False
1667                self.run_conv_double_back_test(
1668                    kern,
1669                    stride,
1670                    padding,
1671                    chan_in,
1672                    chan_out,
1673                    batch_size,
1674                    inp_size,
1675                    dilation,
1676                    no_weight,
1677                )
1678
1679    @dtypes(torch.float, torch.cfloat)
1680    @torch.backends.cudnn.flags(enabled=True, benchmark=False)
1681    def test_conv1d_same_padding(self, device, dtype):
1682        # Test padding='same' outputs the correct shape
1683        test_args = [
1684            # in_size
1685            range(50, 55),
1686            # kernel_size
1687            [1, 2, 3, 8],
1688            # dilation
1689            range(1, 4),
1690            # stride
1691            [1],
1692        ]
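        # padding='same' is only supported for stride=1, which is why stride is
        # fixed to 1 above.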
1693        for in_size, k_size, dilation, stride in itertools.product(*test_args):
1694            x = torch.rand(1, 1, in_size, device=device, dtype=dtype)
1695            y = torch.rand(1, 1, k_size, device=device, dtype=dtype)
1696            z = F.conv1d(x, y, padding="same", dilation=dilation, stride=stride)
1697            self.assertEqual(z.size(2), int(math.ceil(in_size / stride)))
1698
1699        # Compare F.conv1d padding='same' output against manual padding
1700        # Without strides/dilation
1701        x = torch.rand(1, 1, 12, device=device, dtype=dtype)
1702        y = torch.rand(1, 1, 3, device=device, dtype=dtype)
1703        expect = F.conv1d(x, y, padding=1)
1704        actual = F.conv1d(x, y, padding="same")
1705        self.assertEqual(expect, actual)
1706
1707        # With dilation
1708        x = torch.rand(1, 1, 12, device=device, dtype=dtype)
1709        y = torch.rand(1, 1, 4, device=device, dtype=dtype)
1710        expect = F.conv1d(x, y, padding=3, dilation=2)
1711        actual = F.conv1d(x, y, padding="same", dilation=2)
1712        self.assertEqual(expect, actual)
1713
1714        # Dilation with asymmetric padding
1715        expect = F.conv1d(x, y, padding=5, dilation=3)[..., 1:]
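        # The required total padding (dilation * (kernel_size - 1) = 9) is odd;
        # 'same' puts the smaller half on the left, so the symmetrically over-padded
        # result has one extra leading output, which is sliced off here.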
1716        actual = F.conv1d(x, y, padding="same", dilation=3)
1717        self.assertEqual(expect, actual)
1718
1719    @dtypes(torch.float, torch.cfloat)
1720    def test_conv2d_same_padding(self, device, dtype):
1721        if dtype is torch.cfloat:
1722            rtol, atol = 2e-6, 2e-6
1723        else:
1724            rtol, atol = None, None
1725        # Compare F.conv2d padding='same' output against manual padding
1726        # Without strides/dilation
1727        x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype)
1728        y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype)
1729        expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :]
1730        actual = F.conv2d(x, y, padding="same")
1731        self.assertEqual(expect, actual, rtol=rtol, atol=atol)
1732
1733        # With dilation
1734        y = torch.rand(1, 1, 3, 4, device=device, dtype=dtype)
1735        expect = F.conv2d(x, y, padding=(2, 3), dilation=2)
1736        actual = F.conv2d(x, y, padding="same", dilation=2)
1737        self.assertEqual(expect, actual, rtol=rtol, atol=atol)
1738
1739        # Dilation with asymmetric padding
1740        y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype)
1741        expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:]
1742        actual = F.conv2d(x, y, padding="same", dilation=3)
1743        self.assertEqual(expect, actual, rtol=rtol, atol=atol)
1744
1745    @dtypes(torch.float, torch.cfloat)
1746    def test_conv3d_same_padding(self, device, dtype):
1747        if dtype is torch.cfloat:
1748            rtol, atol = 2e-6, 2e-6
1749        else:
1750            rtol, atol = None, None
1751        # Compare F.conv3d padding='same' output against manual padding
1752        # Without strides/dilation
1753        x = torch.rand(1, 1, 10, 11, 12, device=device, dtype=dtype)
1754        y = torch.rand(1, 1, 1, 2, 5, device=device, dtype=dtype)
1755        expect = F.conv3d(x, y, padding=(0, 1, 2))[..., :, 1:, :]
1756        actual = F.conv3d(x, y, padding="same")
1757        self.assertEqual(expect, actual, rtol=rtol, atol=atol)
1758
1759        # With dilation
1760        expect = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
1761        actual = F.conv3d(x, y, padding="same", dilation=2)
1762        self.assertEqual(expect, actual, rtol=rtol, atol=atol)
1763
1764        # Dilation with asymmetric padding
1765        y = torch.rand(1, 1, 4, 4, 4, device=device, dtype=dtype)
1766        expect = F.conv3d(x, y, padding=5, dilation=3)[..., 1:, 1:, 1:]
1767        actual = F.conv3d(x, y, padding="same", dilation=3)
1768        self.assertEqual(expect, actual, rtol=rtol, atol=atol)
1769
1770    @dtypes(torch.float, torch.cfloat)
1771    def test_conv1d_valid_padding(self, device, dtype):
1772        # Test F.conv1d padding='valid' is the same as no padding
1773        x = torch.rand(1, 1, 10, device=device, dtype=dtype)
1774        y = torch.rand(1, 1, 4, device=device, dtype=dtype)
1775        expect = F.conv1d(x, y)
1776        actual = F.conv1d(x, y, padding="valid")
1777        self.assertEqual(expect, actual)
1778
1779    @dtypes(torch.float, torch.cfloat)
1780    def test_conv2d_valid_padding(self, device, dtype):
1781        # Test F.conv2d padding='valid' is the same as no padding
1782        x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype)
1783        y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype)
1784        expect = F.conv2d(x, y)
1785        actual = F.conv2d(x, y, padding="valid")
1786        self.assertEqual(expect, actual)
1787
1788    @dtypes(torch.float, torch.cfloat)
1789    def test_conv3d_valid_padding(self, device, dtype):
1790        # Test F.conv3d padding='valid' is the same as no padding
1791        x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device)
1792        y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device)
1793        expect = F.conv3d(x, y)
1794        actual = F.conv3d(x, y, padding="valid")
1795        self.assertEqual(expect, actual)
1796
1797    @dtypes(torch.float, torch.cfloat)
1798    def test_conv1d_same_padding_backward(self, device, dtype):
1799        # Test F.conv1d gradients work with padding='same'
1800        x = torch.rand(1, 1, 12, dtype=dtype, device=device, requires_grad=True)
1801        y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)
1802
1803        # Symmetric padding
1804        z = F.conv1d(x, y, padding=3, dilation=2)
1805        z.sum().abs().backward()
1806        gx_expect, gy_expect = x.grad, y.grad
1807        x.grad, y.grad = None, None
1808
1809        z = F.conv1d(x, y, padding="same", dilation=2)
1810        z.sum().abs().backward()
1811        self.assertEqual(gx_expect, x.grad)
1812        self.assertEqual(gy_expect, y.grad)
1813        x.grad, y.grad = None, None
1814
1815        # Asymmetric padding
1816        z = F.conv1d(x, y, padding=2)[..., 1:]
1817        z.sum().abs().backward()
1818        gx_expect, gy_expect = x.grad, y.grad
1819        x.grad, y.grad = None, None
1820
1821        z = F.conv1d(x, y, padding="same")
1822        z.sum().abs().backward()
1823        self.assertEqual(gx_expect, x.grad)
1824        self.assertEqual(gy_expect, y.grad)
1825
1826    @dtypes(torch.float, torch.cfloat)
1827    @tf32_on_and_off(0.001)
1828    def test_conv2d_same_padding_backward(self, device, dtype):
1829        # Test F.conv2d gradients work with padding='same'
1830        x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype, requires_grad=True)
1831        y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype, requires_grad=True)
1832
1833        # Symmetric padding
1834        z = F.conv2d(x, y, padding=(3, 4), dilation=2)
1835        z.sum().abs().backward()
1836        gx_expect, gy_expect = x.grad, y.grad
1837        x.grad, y.grad = None, None
1838
1839        z = F.conv2d(x, y, padding="same", dilation=2)
1840        z.sum().abs().backward()
1841        self.assertEqual(gx_expect, x.grad)
1842        self.assertEqual(gy_expect, y.grad)
1843        x.grad, y.grad = None, None
1844
1845        # Asymmetric padding
1846        y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype, requires_grad=True)
1847        z = F.conv2d(x, y, padding=2)[..., 1:, 1:]
1848        z.sum().abs().backward()
1849        gx_expect, gy_expect = x.grad, y.grad
1850        x.grad, y.grad = None, None
1851
1852        z = F.conv2d(x, y, padding="same")
1853        z.sum().abs().backward()
1854        self.assertEqual(gx_expect, x.grad)
1855        self.assertEqual(gy_expect, y.grad)
1856
1857    @dtypes(torch.double, torch.cdouble)
1858    def test_conv3d_same_padding_backward(self, device, dtype):
1859        check_forward_ad = torch.device(device).type != "xla"
1860
1861        # Test F.conv3d gradients work with padding='same'
1862        x = torch.rand(1, 1, 1, 11, 12, dtype=dtype, device=device, requires_grad=True)
1863        y = torch.rand(1, 1, 1, 2, 5, dtype=dtype, device=device, requires_grad=True)
1864
1865        # Symmetric padding
1866        z = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
1867        z.sum().abs().backward()
1868        gx_expect, gy_expect = x.grad, y.grad
1869        x.grad, y.grad = None, None
1870
1871        z = F.conv3d(x, y, padding="same", dilation=2)
1872        z.sum().abs().backward()
1873        self.assertEqual(gx_expect, x.grad)
1874        self.assertEqual(gy_expect, y.grad)
1875        x.grad, y.grad = None, None
1876
1877        gradcheck(
1878            lambda x, y: F.conv3d(x, y, padding="same", dilation=2),
1879            (x, y),
1880            check_forward_ad=check_forward_ad,
1881            nondet_tol=1e-5,
1882        )
1883        if torch.device(device).type != "cuda":
1884            # https://github.com/pytorch/pytorch/issues/70702
1885            gradgradcheck(
1886                lambda x, y: F.conv3d(x, y, padding="same", dilation=2),
1887                (x, y),
1888                check_fwd_over_rev=True,
1889            )
1890
1891        # Asymmetric padding
1892        y = torch.rand(1, 1, 1, 4, 4, dtype=dtype, device=device, requires_grad=True)
1893        z = F.conv3d(x, y, padding=2)[..., 1:, 1:]
1894        z.sum().abs().backward()
1895        gx_expect, gy_expect = x.grad, y.grad
1896        x.grad, y.grad = None, None
1897
1898        z = F.conv3d(x, y, padding="same")
1899        z.sum().abs().backward()
1900        self.assertEqual(gx_expect, x.grad)
1901        self.assertEqual(gy_expect, y.grad)
1902
1903        gradcheck(
1904            lambda x, y: F.conv3d(x, y, padding="same"),
1905            (x, y),
1906            check_forward_ad=check_forward_ad,
1907            nondet_tol=1e-5,
1908        )
1909        if torch.device(device).type != "cuda":
1910            # https://github.com/pytorch/pytorch/issues/70702
1911            gradgradcheck(
1912                lambda x, y: F.conv3d(x, y, padding="same"),
1913                (x, y),
1914                check_fwd_over_rev=True,
1915            )
1916
1917    @dtypes(torch.float, torch.cfloat)
1918    def test_conv1d_valid_padding_backward(self, device, dtype):
1919        # Test F.conv1d gradients work with padding='valid'
1920        x = torch.rand(1, 1, 10, dtype=dtype, device=device, requires_grad=True)
1921        y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)
1922        F.conv1d(x, y, padding=0).sum().abs().backward()
1923        gx_expect, gy_expect = x.grad, y.grad
1924        x.grad, y.grad = None, None
1925
1926        F.conv1d(x, y, padding="valid").sum().abs().backward()
1927        gx_actual, gy_actual = x.grad, y.grad
1928        self.assertEqual(gx_expect, gx_actual)
1929        self.assertEqual(gy_expect, gy_actual)
1930
1931    @unittest.skipIf(not TEST_SCIPY, "SciPy required for the test.")
1932    @dtypes(torch.float, torch.cfloat)
1933    @parametrize_test("mode", ("valid", "same"))
1934    def test_conv1d_vs_scipy(self, device, dtype, mode):
1935        t = make_tensor((1, 10), device=device, dtype=dtype)
1936        feat_dim = t.shape[1]
1937        weight_even = make_tensor((1, 1, 4), device=device, dtype=dtype)
1938        weight_odd = make_tensor((1, 1, 5), device=device, dtype=dtype)
1939
1940        def _test(t, weight, mode):
1941            # SciPy expects two 1-D inputs.
1942            t_a = t.view(-1).cpu().numpy()
1943            w_a = weight.view(-1).cpu().numpy()
1944            expected = scipy.signal.convolve(t_a, w_a, mode=mode)
1945
1946            kwargs = {"padding": mode}
1947            if mode == "same":
1948                # `same` padding in PyTorch conv1d is different
1949                # from SciPy
1950                p = weight.shape[2] // 2
1951                t = torch.nn.functional.pad(t, (p, p))
1952                # We have already taken care of padding
1953                kwargs.pop("padding")
1954
1955            # second input is flipped in SciPy's convolve
1956            weight_flipped = torch.flip(weight, (2,))
1957            actual = torch.nn.functional.conv1d(t, weight_flipped, **kwargs).squeeze(0)
1958            if mode == "same":
1959                actual = actual[:feat_dim]
1960
1961            self.assertEqual(actual, expected, atol=2e-5, rtol=2e-5)
1962
1963        # The global dtype for this test suite is torch.double, which changes
1964        # type promotion and makes conv1d output `complex128` for `complex64`
1965        # input, so pin the default dtype to float for the comparison.
1966        with set_default_dtype(torch.float):
1967            _test(t, weight_even, mode)
1968            _test(t, weight_odd, mode)
1969
1970    @unittest.skipIf(not TEST_SCIPY, "SciPy required for the test.")
1971    @dtypes(torch.float, torch.cfloat)
1972    @parametrize_test("mode", ("valid", "same"))
1973    def test_conv2d_vs_scipy(self, device, dtype, mode):
1974        t = make_tensor((1, 5, 10), device=device, dtype=dtype)
1975        weight_even = make_tensor((1, 1, 2, 4), device=device, dtype=dtype)
1976        weight_odd = make_tensor((1, 1, 3, 5), device=device, dtype=dtype)
1977
1978        def _test(t, weight, mode):
1979            # SciPy expects two 2-D inputs.
1980            t_a = t.squeeze(0).cpu().numpy()
1981            w_a = weight.squeeze(0).squeeze(0).cpu().numpy()
1982            expected = scipy.signal.convolve2d(t_a, w_a, mode=mode)
1983
1984            kwargs = {"padding": mode}
1985            if mode == "same":
1986                # `same` padding in PyTorch conv2d is different
1987                # from SciPy
1988                left_right_pad = weight.shape[3] // 2
1989                top_bottom_pad = weight.shape[2] // 2
1990                p = (left_right_pad, left_right_pad, top_bottom_pad, top_bottom_pad)
1991                t = torch.nn.functional.pad(t, p)
1992                # We have already taken care of padding
1993                kwargs.pop("padding")
1994
1995            # second input is flipped in SciPy's convolve2d
1996            weight_flipped = torch.flip(weight, (2, 3))
1997            actual = torch.nn.functional.conv2d(t, weight_flipped, **kwargs).squeeze(0)
1998            if mode == "same":
1999                actual = actual[:5, :10]
2000
2001            self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6)
2002
2003        # The global dtype for this test suite is torch.double, which changes
2004        # type promotion and makes conv2d output `complex128` for `complex64`
2005        # input, so pin the default dtype to float for the comparison.
2006        with set_default_dtype(torch.float):
2007            _test(t, weight_even, mode)
2008            _test(t, weight_odd, mode)
2009
2010    @unittest.skipIf(not TEST_SCIPY, "SciPy required for the test.")
2011    @dtypes(torch.float, torch.cfloat)
2012    @parametrize_test("mode", ("valid", "same"))
2013    def test_conv3d_vs_scipy(self, device, dtype, mode):
2014        t = make_tensor((1, 5, 5, 10), device=device, dtype=dtype)
2015        weight_even = make_tensor((1, 1, 2, 2, 4), device=device, dtype=dtype)
2016        weight_odd = make_tensor((1, 1, 2, 3, 5), device=device, dtype=dtype)
2017
2018        def _test(t, weight, mode):
2019            # SciPy expects two 3-D inputs.
2020            t_a = t.squeeze(0).cpu().numpy()
2021            w_a = weight.squeeze(0).squeeze(0).cpu().numpy()
2022            expected = scipy.signal.convolve(t_a, w_a, mode=mode)
2023
2024            kwargs = {"padding": mode}
2025            if mode == "same":
2026                # `same` padding in PyTorch conv3d is different
2027                # from SciPy
2028                left_right_pad = weight.shape[4] // 2
2029                top_bottom_pad = weight.shape[3] // 2
2030                front_back_pad = weight.shape[2] // 2
2031                p = (
2032                    left_right_pad,
2033                    left_right_pad,
2034                    top_bottom_pad,
2035                    top_bottom_pad,
2036                    front_back_pad,
2037                    front_back_pad,
2038                )
2039                t = torch.nn.functional.pad(t, p)
2040                # We have already taken care of padding
2041                kwargs.pop("padding")
2042
2043            # second input is flipped in SciPy's convolve
2044            weight_flipped = torch.flip(weight, (2, 3, 4))
2045            actual = torch.nn.functional.conv3d(t, weight_flipped, **kwargs).squeeze(0)
2046            if mode == "same":
2047                actual = actual[:5, :5, :10]
2048
2049            if tf32_is_not_fp32() and (
2050                dtype == torch.float or dtype == torch.complex64
2051            ):
2052                self.assertEqual(actual, expected, atol=0.05, rtol=0.05)
2053            else:
2054                self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6)
2055
2056        # The global dtype for this test suite is torch.double, which changes
2057        # type promotion and makes conv3d output `complex128` for `complex64`
2058        # input, so pin the default dtype to float for the comparison.
2059        with set_default_dtype(torch.float):
2060            _test(t, weight_even, mode)
2061            _test(t, weight_odd, mode)
2062
2063    @dtypes(torch.float, torch.complex64)
2064    def test_conv2d_valid_padding_backward(self, device, dtype):
2065        # Test F.conv2d gradients work with padding='valid'
2066        x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype, requires_grad=True)
2067        y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype, requires_grad=True)
2068        F.conv2d(x, y, padding=0).sum().abs().backward()
2069        gx_expect, gy_expect = x.grad, y.grad
2070        x.grad, y.grad = None, None
2071
2072        F.conv2d(x, y, padding="valid").sum().abs().backward()
2073        gx_actual, gy_actual = x.grad, y.grad
2074        self.assertEqual(gx_expect, gx_actual)
2075        self.assertEqual(gy_expect, gy_actual)
2076
2077    @dtypes(torch.double, torch.cdouble)
2078    def test_conv3d_valid_padding_backward(self, device, dtype):
2079        check_forward_ad = torch.device(device).type != "xla"
2080
2081        # Test F.conv3d gradients work with padding='valid'
2082        x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device, requires_grad=True)
2083        y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device, requires_grad=True)
2084        F.conv3d(x, y, padding=0).sum().abs().backward()
2085        gx_expect, gy_expect = x.grad, y.grad
2086        x.grad, y.grad = None, None
2087
2088        F.conv3d(x, y, padding="valid").sum().abs().backward()
2089        gx_actual, gy_actual = x.grad, y.grad
2090        self.assertEqual(gx_expect, gx_actual)
2091        self.assertEqual(gy_expect, gy_actual)
2092
2093        gradcheck(
2094            lambda x, y: F.conv3d(x, y, padding="valid"),
2095            (x, y),
2096            check_forward_ad=check_forward_ad,
2097        )
2098        gradgradcheck(
2099            lambda x, y: F.conv3d(x, y, padding="valid"),
2100            (x, y),
2101            check_fwd_over_rev=check_forward_ad,
2102        )
2103
2104    @parametrize_test("N", range(2, 4), name_fn=lambda N: f"ConvTranspose{N}d")
2105    def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N):
2106        # For inputs with no batch dim, verify output is the correct shape when output_size is set.
2107        # See https://github.com/pytorch/pytorch/issues/75889
2108        inp = torch.randn((1, 15, 13) if N == 2 else (1, 15, 13, 13), device=device)
2109        output_size = (1, 240, 200) if N == 2 else (1, 240, 200, 200)
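        # With stride > 1, several output sizes are valid for a given input;
        # output_size selects which one to produce.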
2110        ConvTransposeNd = getattr(nn, f"ConvTranspose{N}d")
2111        m = ConvTransposeNd(
2112            1, 1, kernel_size=16, stride=16, padding=7, bias=False, device=device
2113        )
2114        output = m(inp, output_size=output_size)
2115        self.assertEqual(output.shape, output_size)
2116
2117    @skipMeta
2118    @parametrize_test(
2119        "input_shape,transposed,dilated,groups,layout,backend_expected",
2120        [
2121            # === slow ===
2122            subtest(
2123                (
2124                    (2, 6, 7),
2125                    False,
2126                    False,
2127                    3,
2128                    torch.strided,
2129                    torch._C._ConvBackend.Slow2d,
2130                ),
2131                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
2132                name="slow1d",
2133            ),
2134            subtest(
2135                (
2136                    (2, 6, 7),
2137                    True,
2138                    False,
2139                    3,
2140                    torch.strided,
2141                    torch._C._ConvBackend.SlowTranspose2d,
2142                ),
2143                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
2144                name="slow1d_transposed",
2145            ),
2146            subtest(
2147                (
2148                    (2, 6, 7),
2149                    False,
2150                    True,
2151                    3,
2152                    torch.strided,
2153                    torch._C._ConvBackend.SlowDilated2d,
2154                ),
2155                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
2156                name="slow1d_dilated",
2157            ),
2158            subtest(
2159                (
2160                    (2, 6, 7),
2161                    True,
2162                    True,
2163                    3,
2164                    torch.strided,
2165                    torch._C._ConvBackend.SlowTranspose2d,
2166                ),
2167                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
2168                name="slow1d_dilated_transposed",
2169            ),
2170            subtest(
2171                (
2172                    (2, 6, 7, 8),
2173                    False,
2174                    False,
2175                    3,
2176                    torch.strided,
2177                    torch._C._ConvBackend.Slow2d,
2178                ),
2179                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
2180                name="slow2d",
2181            ),
2182            subtest(
2183                (
2184                    (2, 6, 7, 8),
2185                    True,
2186                    False,
2187                    3,
2188                    torch.strided,
2189                    torch._C._ConvBackend.SlowTranspose2d,
2190                ),
2191                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
2192                name="slow2d_transposed",
2193            ),
2194            subtest(
2195                (
2196                    (2, 6, 7, 8),
2197                    False,
2198                    True,
2199                    3,
2200                    torch.strided,
2201                    torch._C._ConvBackend.SlowDilated2d,
2202                ),
2203                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
2204                name="slow2d_dilated",
2205            ),
2206            subtest(
2207                (
2208                    (2, 6, 7, 8),
2209                    True,
2210                    True,
2211                    3,
2212                    torch.strided,
2213                    torch._C._ConvBackend.SlowTranspose2d,
2214                ),
2215                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
2216                name="slow2d_dilated_transposed",
2217            ),
2218            subtest(
2219                (
2220                    (2, 6, 7, 8, 9),
2221                    False,
2222                    False,
2223                    3,
2224                    torch.strided,
2225                    torch._C._ConvBackend.Slow3d,
2226                ),
2227                decorators=[onlyCPU, disableMkldnn],
2228                name="slow3d_cpu",
2229            ),
2230            # CUDA doesn't have a slow 3D implementation, so it goes to the dilated 3D implementation instead
2231            subtest(
2232                (
2233                    (2, 6, 7, 8, 9),
2234                    False,
2235                    False,
2236                    3,
2237                    torch.strided,
2238                    torch._C._ConvBackend.SlowDilated3d,
2239                ),
2240                decorators=[onlyCUDA, disablecuDNN],
2241                name="slow3d_cuda",
2242            ),
2243            # FIXME: RuntimeError: CUDA out of memory.
2244            # subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d),
2245            #         decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_transposed'),
2246            subtest(
2247                (
2248                    (2, 6, 7, 8, 9),
2249                    False,
2250                    True,
2251                    3,
2252                    torch.strided,
2253                    torch._C._ConvBackend.SlowDilated3d,
2254                ),
2255                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
2256                name="slow3d_dilated",
2257            ),
2258            # FIXME: RuntimeError: CUDA out of memory.
2259            # subtest(((2, 6, 7, 8, 9), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d),
2260            #         decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_dilated_transposed'),
2261            subtest(
2262                (
2263                    (0, 6, 7),
2264                    False,
2265                    False,
2266                    3,
2267                    torch.strided,
2268                    torch._C._ConvBackend.Empty,
2269                ),
2270                decorators=[onlyNativeDeviceTypes, disableMkldnn],
2271                name="empty_batch1d",
2272            ),
2273            subtest(
2274                (
2275                    (2, 0, 7),
2276                    False,
2277                    False,
2278                    3,
2279                    torch.strided,
2280                    torch._C._ConvBackend.Empty,
2281                ),
2282                decorators=[onlyNativeDeviceTypes, disableMkldnn],
2283                name="empty_channel1d",
2284            ),
2285            subtest(
2286                (
2287                    (0, 0, 7),
2288                    False,
2289                    False,
2290                    3,
2291                    torch.strided,
2292                    torch._C._ConvBackend.Empty,
2293                ),
2294                decorators=[onlyNativeDeviceTypes, disableMkldnn],
2295                name="empty_batch_channel1d",
2296            ),
2297            subtest(
2298                (
2299                    (0, 6, 7, 8),
2300                    False,
2301                    False,
2302                    3,
2303                    torch.strided,
2304                    torch._C._ConvBackend.Empty,
2305                ),
2306                decorators=[onlyNativeDeviceTypes, disableMkldnn],
2307                name="empty_batch2d",
2308            ),
2309            subtest(
2310                (
2311                    (2, 0, 7, 8),
2312                    False,
2313                    False,
2314                    3,
2315                    torch.strided,
2316                    torch._C._ConvBackend.Empty,
2317                ),
2318                decorators=[onlyNativeDeviceTypes, disableMkldnn],
2319                name="empty_channel2d",
2320            ),
2321            subtest(
2322                (
2323                    (0, 0, 7, 8),
2324                    False,
2325                    False,
2326                    3,
2327                    torch.strided,
2328                    torch._C._ConvBackend.Empty,
2329                ),
2330                decorators=[onlyNativeDeviceTypes, disableMkldnn],
2331                name="empty_batch_channel2d",
2332            ),
2333            subtest(
2334                (
2335                    (0, 6, 7, 8, 9),
2336                    False,
2337                    False,
2338                    3,
2339                    torch.strided,
2340                    torch._C._ConvBackend.Empty,
2341                ),
2342                decorators=[onlyNativeDeviceTypes, disableMkldnn],
2343                name="empty_batch3d",
2344            ),
2345            subtest(
2346                (
2347                    (2, 0, 7, 8, 9),
2348                    False,
2349                    False,
2350                    3,
2351                    torch.strided,
2352                    torch._C._ConvBackend.Empty,
2353                ),
2354                decorators=[onlyNativeDeviceTypes, disableMkldnn],
2355                name="empty_channel3d",
2356            ),
2357            subtest(
2358                (
2359                    (0, 0, 7, 8, 9),
2360                    False,
2361                    False,
2362                    3,
2363                    torch.strided,
2364                    torch._C._ConvBackend.Empty,
2365                ),
2366                decorators=[onlyNativeDeviceTypes, disableMkldnn],
2367                name="empty_batch_channel3d",
2368            ),
2369            # === cuda ===
2370            # Note that disablecuDNN disables miopen as well.
2371            subtest(
2372                (
2373                    (2, 6, 7),
2374                    False,
2375                    False,
2376                    6,
2377                    torch.strided,
2378                    torch._C._ConvBackend.CudaDepthwise2d,
2379                ),
2380                decorators=[onlyCUDA, disablecuDNN],
2381                name="cuda_depthwise1d",
2382            ),
2383            subtest(
2384                (
2385                    (2, 6, 7, 8),
2386                    False,
2387                    False,
2388                    6,
2389                    torch.strided,
2390                    torch._C._ConvBackend.CudaDepthwise2d,
2391                ),
2392                decorators=[onlyCUDA, disablecuDNN],
2393                name="cuda_depthwise2d",
2394            ),
2395            subtest(
2396                (
2397                    (2, 6, 7, 8, 9),
2398                    False,
2399                    False,
2400                    6,
2401                    torch.strided,
2402                    torch._C._ConvBackend.CudaDepthwise3d,
2403                ),
2404                decorators=[onlyCUDA, disablecuDNN],
2405                name="cuda_depthwise3d",
2406            ),
2407            # === cudnn ===
2408            subtest(
2409                (
2410                    (2, 6, 7),
2411                    False,
2412                    False,
2413                    3,
2414                    torch.strided,
2415                    torch._C._ConvBackend.Cudnn,
2416                ),
2417                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
2418                name="cudnn1d",
2419            ),
2420            subtest(
2421                (
2422                    (2, 6, 7, 8),
2423                    False,
2424                    False,
2425                    3,
2426                    torch.strided,
2427                    torch._C._ConvBackend.Cudnn,
2428                ),
2429                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
2430                name="cudnn2d",
2431            ),
2432            subtest(
2433                (
2434                    (2, 6, 7, 8, 9),
2435                    False,
2436                    False,
2437                    3,
2438                    torch.strided,
2439                    torch._C._ConvBackend.Cudnn,
2440                ),
2441                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
2442                name="cudnn3d",
2443            ),
2444            subtest(
2445                (
2446                    (2, 6, 7),
2447                    True,
2448                    False,
2449                    3,
2450                    torch.strided,
2451                    torch._C._ConvBackend.CudnnTranspose,
2452                ),
2453                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
2454                name="cudnn1d_transposed",
2455            ),
2456            subtest(
2457                (
2458                    (2, 6, 7, 8),
2459                    True,
2460                    False,
2461                    3,
2462                    torch.strided,
2463                    torch._C._ConvBackend.CudnnTranspose,
2464                ),
2465                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
2466                name="cudnn2d_transposed",
2467            ),
2468            # FIXME: RuntimeError: CUDA out of memory.
2469            # subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose),
2470            #         decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn3d_transposed'),
2471            # === miopen ===
2472            subtest(
2473                (
2474                    (2, 6, 7),
2475                    False,
2476                    False,
2477                    3,
2478                    torch.strided,
2479                    torch._C._ConvBackend.Miopen,
2480                ),
2481                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
2482                name="miopen1d",
2483            ),
2484            subtest(
2485                (
2486                    (2, 6, 7, 8),
2487                    False,
2488                    False,
2489                    3,
2490                    torch.strided,
2491                    torch._C._ConvBackend.Miopen,
2492                ),
2493                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
2494                name="miopen2d",
2495            ),
2496            subtest(
2497                (
2498                    (2, 6, 7, 8, 9),
2499                    False,
2500                    False,
2501                    3,
2502                    torch.strided,
2503                    torch._C._ConvBackend.Miopen,
2504                ),
2505                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
2506                name="miopen3d",
2507            ),
2508            subtest(
2509                (
2510                    (2, 6, 7),
2511                    True,
2512                    False,
2513                    3,
2514                    torch.strided,
2515                    torch._C._ConvBackend.MiopenTranspose,
2516                ),
2517                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
2518                name="miopen1d_transposed",
2519            ),
2520            subtest(
2521                (
2522                    (2, 6, 7, 8),
2523                    True,
2524                    False,
2525                    3,
2526                    torch.strided,
2527                    torch._C._ConvBackend.MiopenTranspose,
2528                ),
2529                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
2530                name="miopen2d_transposed",
2531            ),
2532            subtest(
2533                (
2534                    (2, 6, 7, 8, 9),
2535                    True,
2536                    False,
2537                    3,
2538                    torch.strided,
2539                    torch._C._ConvBackend.MiopenTranspose,
2540                ),
2541                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
2542                name="miopen3d_transposed",
2543            ),
2544            subtest(
2545                (
2546                    (2, 6, 7),
2547                    False,
2548                    False,
2549                    6,
2550                    torch.strided,
2551                    torch._C._ConvBackend.MiopenDepthwise,
2552                ),
2553                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
2554                name="miopen_depthwise1d",
2555            ),
2556            subtest(
2557                (
2558                    (2, 6, 7, 8),
2559                    False,
2560                    False,
2561                    6,
2562                    torch.strided,
2563                    torch._C._ConvBackend.MiopenDepthwise,
2564                ),
2565                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
2566                name="miopen_depthwise2d",
2567            ),
2568            subtest(
2569                (
2570                    (2, 6, 7, 8, 9),
2571                    False,
2572                    False,
2573                    6,
2574                    torch.strided,
2575                    torch._C._ConvBackend.MiopenDepthwise,
2576                ),
2577                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
2578                name="miopen_depthwise3d",
2579            ),
2580            # === mkldnn ===
2581            subtest(
2582                (
2583                    (2, 6, 7),
2584                    False,
2585                    False,
2586                    3,
2587                    torch._mkldnn,
2588                    torch._C._ConvBackend.Mkldnn,
2589                ),
2590                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2591                name="mkldnn1d",
2592            ),
2593            subtest(
2594                (
2595                    (2, 6, 7, 8),
2596                    False,
2597                    False,
2598                    3,
2599                    torch._mkldnn,
2600                    torch._C._ConvBackend.Mkldnn,
2601                ),
2602                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2603                name="mkldnn2d",
2604            ),
2605            subtest(
2606                (
2607                    (2, 6, 7, 8, 9),
2608                    False,
2609                    False,
2610                    3,
2611                    torch._mkldnn,
2612                    torch._C._ConvBackend.Mkldnn,
2613                ),
2614                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2615                name="mkldnn3d",
2616            ),
2617            # Transposed convolution is broken for mkldnn. See https://github.com/pytorch/pytorch/issues/68775.
2618            subtest(
2619                (
2620                    (2, 6, 7),
2621                    True,
2622                    False,
2623                    3,
2624                    torch._mkldnn,
2625                    torch._C._ConvBackend.Mkldnn,
2626                ),
2627                decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure],
2628                name="mkldnn1d_transposed",
2629            ),
2630            subtest(
2631                (
2632                    (2, 6, 7, 8),
2633                    True,
2634                    False,
2635                    3,
2636                    torch._mkldnn,
2637                    torch._C._ConvBackend.Mkldnn,
2638                ),
2639                decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure],
2640                name="mkldnn2d_transposed",
2641            ),
2642            subtest(
2643                (
2644                    (2, 6, 7, 8, 9),
2645                    True,
2646                    False,
2647                    3,
2648                    torch._mkldnn,
2649                    torch._C._ConvBackend.Mkldnn,
2650                ),
2651                decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure],
2652                name="mkldnn3d_transposed",
2653            ),
2654            subtest(
2655                (
2656                    (2, 6, 7),
2657                    False,
2658                    True,
2659                    3,
2660                    torch.strided,
2661                    torch._C._ConvBackend.Mkldnn,
2662                ),
2663                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2664                name="mkldnn1d_cpu_input",
2665            ),
2666            subtest(
2667                (
2668                    (2, 6, 7, 8),
2669                    False,
2670                    True,
2671                    3,
2672                    torch.strided,
2673                    torch._C._ConvBackend.Mkldnn,
2674                ),
2675                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2676                name="mkldnn2d_cpu_input",
2677            ),
2678            subtest(
2679                (
2680                    (2, 6, 7, 8, 9),
2681                    False,
2682                    True,
2683                    3,
2684                    torch.strided,
2685                    torch._C._ConvBackend.Mkldnn,
2686                ),
2687                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2688                name="mkldnn3d_cpu_input",
2689            ),
2690            subtest(
2691                (
2692                    (0, 6, 7),
2693                    False,
2694                    False,
2695                    3,
2696                    torch._mkldnn,
2697                    torch._C._ConvBackend.MkldnnEmpty,
2698                ),
2699                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2700                name="mkldnn_empty_batch1d",
2701            ),
2702            subtest(
2703                (
2704                    (2, 0, 7),
2705                    False,
2706                    False,
2707                    3,
2708                    torch._mkldnn,
2709                    torch._C._ConvBackend.MkldnnEmpty,
2710                ),
2711                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2712                name="mkldnn_empty_channel1d",
2713            ),
2714            subtest(
2715                (
2716                    (0, 0, 7),
2717                    False,
2718                    False,
2719                    3,
2720                    torch._mkldnn,
2721                    torch._C._ConvBackend.MkldnnEmpty,
2722                ),
2723                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2724                name="mkldnn_empty_batch_channel1d",
2725            ),
2726            subtest(
2727                (
2728                    (0, 6, 7, 8),
2729                    False,
2730                    False,
2731                    3,
2732                    torch._mkldnn,
2733                    torch._C._ConvBackend.MkldnnEmpty,
2734                ),
2735                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2736                name="mkldnn_empty_batch2d",
2737            ),
2738            subtest(
2739                (
2740                    (2, 0, 7, 8),
2741                    False,
2742                    False,
2743                    3,
2744                    torch._mkldnn,
2745                    torch._C._ConvBackend.MkldnnEmpty,
2746                ),
2747                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2748                name="mkldnn_empty_channel2d",
2749            ),
2750            subtest(
2751                (
2752                    (0, 0, 7, 8),
2753                    False,
2754                    False,
2755                    3,
2756                    torch._mkldnn,
2757                    torch._C._ConvBackend.MkldnnEmpty,
2758                ),
2759                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2760                name="mkldnn_empty_batch_channel2d",
2761            ),
2762            subtest(
2763                (
2764                    (0, 6, 7, 8, 9),
2765                    False,
2766                    False,
2767                    3,
2768                    torch._mkldnn,
2769                    torch._C._ConvBackend.MkldnnEmpty,
2770                ),
2771                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2772                name="mkldnn_empty_batch3d",
2773            ),
2774            subtest(
2775                (
2776                    (2, 0, 7, 8, 9),
2777                    False,
2778                    False,
2779                    3,
2780                    torch._mkldnn,
2781                    torch._C._ConvBackend.MkldnnEmpty,
2782                ),
2783                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2784                name="mkldnn_empty_channel3d",
2785            ),
2786            subtest(
2787                (
2788                    (0, 0, 7, 8, 9),
2789                    False,
2790                    False,
2791                    3,
2792                    torch._mkldnn,
2793                    torch._C._ConvBackend.MkldnnEmpty,
2794                ),
2795                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2796                name="mkldnn_empty_batch_channel3d",
2797            ),
2798            # Note: Tests for mobile backends are not currently supported. This comprises
2799            # NnpackSpatial, Winograd3x3Depthwise, and Xnnpack2d backends. Testing these
2800            # requires the ability to gate tests by whether PyTorch is built with USE_MOBILE=1.
2801        ],
2802    )
2803    # Test with both bias and no bias.
2804    @parametrize_test("has_bias", [False, True])
2805    # Test with both stride=1 and stride>1 cases.
2806    @parametrize_test("strided", [False, True])
2807    # Test with both contiguous and non-contiguous inputs.
2808    @parametrize_test("contiguous", [False, True])
2809    def test_conv_backend(
2810        self,
2811        device,
2812        input_shape,
2813        has_bias,
2814        strided,
2815        contiguous,
2816        transposed,
2817        dilated,
2818        groups,
2819        layout,
2820        backend_expected,
2821    ):
2822        # Build up inputs.
2823        dtype = torch.float32
2824        C_in, C_out, dim, kernel_size = input_shape[1], 12, len(input_shape) - 2, 3
2825        x = torch.randn(*input_shape, device=device, dtype=dtype, requires_grad=True)
2826        weight = torch.randn(
2827            C_in if transposed else C_out,
2828            C_out // groups if transposed else C_in // groups,
2829            *[kernel_size for _ in range(dim)],
2830            device=device,
2831            dtype=dtype,
2832            requires_grad=True,
2833        )
2834        bias = (
2835            torch.randn(C_out, device=device, dtype=dtype, requires_grad=True)
2836            if has_bias
2837            else None
2838        )
2839
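            # Helper: build a non-contiguous tensor with the same values as `inp` by
            # repeating along the last dim and slicing back with stride 2, preserving requires_grad.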
2840        def _make_noncontiguous(inp):
2841            if inp is None:
2842                return None
2843            old_requires_grad = inp.requires_grad
2844            inp = torch.repeat_interleave(inp, 2, dim=-1)
2845            inp = inp[..., ::2].detach().requires_grad_(old_requires_grad)
2846            return inp
2847
2848        if not contiguous:
2849            x = _make_noncontiguous(x)
2850            weight = _make_noncontiguous(weight)
2851            bias = _make_noncontiguous(bias)
2852
2853        if layout is torch._mkldnn:
2854            x = x.to_mkldnn()
2855            # Note that weight and bias are not supported as mkldnn tensors during training.
2856
2857        stride = (2,) * dim if strided else (1,) * dim
2858        padding = (0,) * dim
2859        dilation = (2,) * dim if dilated else (1,) * dim
2860        output_padding = (0,) * dim
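            # Positional args in the order expected by aten::convolution (and _select_conv_backend):
            # input, weight, bias, stride, padding, dilation, transposed, output_padding, groups.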
2861        inputs = [
2862            x,
2863            weight,
2864            bias,
2865            stride,
2866            padding,
2867            dilation,
2868            transposed,
2869            output_padding,
2870            groups,
2871        ]
2872
2873        # Ensure correct backend is selected.
2874        backend_actual = torch._C._select_conv_backend(*inputs)
2875        self.assertEqual(backend_actual, backend_expected)
2876
2877        # Ensure backward call succeeds.
2878        convolution = torch.ops.aten.convolution
2879        output = convolution(*inputs)
2880        grad_output = torch.randn(output.shape, device=device, dtype=dtype)
2881        if not contiguous:
2882            grad_output = _make_noncontiguous(grad_output)
2883        if layout is torch._mkldnn:
2884            grad_output = grad_output.to_mkldnn()
2885        output.backward(grad_output)
2886
2887        # mkldnn doesn't support gradcheck :(
2888        if layout is torch._mkldnn:
2889            return
2890
2891        if backend_actual != torch._C._ConvBackend.Empty:  # FIXME: forward AD fails
2892            # Forward AD and forward-over-reverse AD smoke test in float32
2893            # TODO: remove this if we introduce per-op gradient tests for float32
2894            with fwAD.dual_level():
2895                dual_inputs = [
2896                    (
2897                        fwAD.make_dual(i, torch.rand_like(i))
2898                        if isinstance(i, torch.Tensor)
2899                        else i
2900                    )
2901                    for i in inputs
2902                ]
2903                # Forward AD
2904                output = convolution(*dual_inputs)
2905                # Forward over reverse AD
2906                grad_output_d = fwAD.make_dual(
2907                    torch.rand_like(output), torch.rand_like(output)
2908                )
2909                if has_bias:
2910                    torch.autograd.grad(output, [x, weight, bias], grad_output_d)
2911                else:
2912                    torch.autograd.grad(output, [x, weight], grad_output_d)
2913
2914        # Convert to float64 for gradcheck; double precision is needed for reliable numerical Jacobians.
2915        x = x.to(torch.float64).detach().requires_grad_(True)
2916        weight = weight.to(torch.float64).detach().requires_grad_(True)
2917        if bias is not None:
2918            bias = bias.to(torch.float64).detach().requires_grad_(True)
2919        inputs = [
2920            x,
2921            weight,
2922            bias,
2923            stride,
2924            padding,
2925            dilation,
2926            transposed,
2927            output_padding,
2928            groups,
2929        ]
2930
2931        # Set some backend-specific validation settings.
2932        gradcheck_nondet_tol = 0.0
2933        if torch.backends.cudnn.is_available():
2934            # cuDNN introduces non-determinism
2935            gradcheck_nondet_tol = GRADCHECK_NONDET_TOL
2936
2937        self.assertTrue(gradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol))
2938
2939        # double backward doesn't support bias gradients
2940        if bias is not None:
2941            bias.requires_grad_(False)
2942        self.assertTrue(
2943            gradgradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol)
2944        )
2945
2946    @onlyCPU
2947    def test_conv_contiguous_for_oneDNN(self):
2948        # See https://github.com/pytorch/pytorch/issues/80837.
2949        for dtype in [torch.float, torch.bfloat16, torch.half]:
2950            conv = nn.Conv2d(
2951                1,
2952                128,
2953                kernel_size=(5, 2),
2954                stride=(2, 1),
2955                padding=(0, 1),
2956                dilation=(1, 1),
2957                groups=1,
2958                bias=True,
2959                padding_mode="zeros",
2960            ).to(dtype=dtype)
2961
2962            x = torch.rand([1, 2, 321, 201, 1]).to(dtype=dtype)
2963            x = torch.transpose(x, 1, 4)
2964            x2 = x[..., 0]
2965            inputs = [
2966                x2,
2967                conv.weight,
2968                conv.bias,
2969                (2, 1),
2970                (0, 1),
2971                (1, 1),
2972                False,
2973                (0, 1),
2974                1,
2975            ]
2976            if torch.backends.mkldnn.is_available():
2977                y = conv(x2)
2978                # Disable MKLDNN explicitly
2979                with torch.backends.mkldnn.flags(enabled=False):
2980                    y_ = conv(x2)
2981                    self.assertEqual(y, y_)
2982
2983    @onlyCPU
2984    def test_conv_ic1_channels_last_for_oneDNN(self):
2985        # See https://github.com/pytorch/pytorch/issues/82060; N > 1 will call into the oneDNN path.
2986        for dtype in [torch.float, torch.bfloat16, torch.half]:
2987            conv = torch.nn.Conv2d(
2988                1, 64, kernel_size=(3, 3), padding=(1, 1), bias=False
2989            )
2990            conv = conv.to(memory_format=torch.channels_last).to(dtype=dtype)
2991            x = torch.rand(2, 1, 100, 100).to(dtype=dtype)
2992            if torch.backends.mkldnn.is_available():
2993                y = conv(x)
2994                # Disable MKLDNN explicitly
2995                with torch.backends.mkldnn.flags(enabled=False):
2996                    y_ = conv(x)
2997                    self.assertEqual(y, y_)
2998
2999    @dtypes(torch.float, torch.cfloat)
3000    def test_conv_empty_channel(self, device, dtype):
3001        in_channels = 0
3002        mod = torch.nn.Conv1d(in_channels, 8, 2, stride=2, dtype=dtype).to(device)
3003        inp = torch.randn(2, 0, 15, device=device, dtype=dtype)
3004        _test_module_empty_input(self, mod, inp, check_size=False)
3005
3006        with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
3007            inp = torch.randn(2, 1, 0, device=device, dtype=dtype)
3008            mod(inp)
3009
3010        mod = torch.nn.Conv2d(in_channels, 33, 3, stride=2, dtype=dtype).to(device)
3011        inp = torch.randn(2, 0, 50, 100, device=device, dtype=dtype)
3012        _test_module_empty_input(self, mod, inp, check_size=False)
3013
3014        with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
3015            inp = torch.randn(2, 1, 40, 0, device=device, dtype=dtype)
3016            mod(inp)
3017
3018        mod = torch.nn.Conv3d(in_channels, 33, 3, stride=2, dtype=dtype).to(device)
3019        inp = torch.randn(2, 0, 50, 20, 40, device=device, dtype=dtype)
3020        _test_module_empty_input(self, mod, inp, check_size=False)
3021
3022        with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
3023            inp = torch.randn(2, 1, 50, 0, 40, device=device, dtype=dtype)
3024            mod(inp)
3025
3026    def test_group_conv_empty(self, device):
3027        mod = torch.nn.Conv2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(
3028            device
3029        )
3030        inp = torch.randn(0, 4, 4, 4, device=device)
3031        _test_module_empty_input(self, mod, inp, check_size=False)
3032        if self.device_type == "cuda" and self.has_cudnn():
3033            with torch.backends.cudnn.flags(enabled=False):
3034                _test_module_empty_input(self, mod, inp, check_size=False)
3035
3036    def test_group_convTranspose_empty(self, device):
3037        mod = torch.nn.ConvTranspose2d(
3038            4, 4, stride=2, kernel_size=3, padding=1, groups=4
3039        ).to(device)
3040        inp = torch.randn(0, 4, 4, 4, device=device)
3041        _test_module_empty_input(self, mod, inp, check_size=False)
3042        if self.device_type == "cuda" and self.has_cudnn():
3043            with torch.backends.cudnn.flags(enabled=False):
3044                _test_module_empty_input(self, mod, inp, check_size=False)
3045
3046    def test_convTranspose_empty(self, device):
3047        mod = torch.nn.ConvTranspose2d(4, 4, stride=2, kernel_size=3, padding=1).to(
3048            device
3049        )
3050        inp = torch.randn(0, 4, 4, 4, device=device)
3051        _test_module_empty_input(self, mod, inp, check_size=False)
3052        if self.device_type == "cuda" and self.has_cudnn():
3053            with torch.backends.cudnn.flags(enabled=False):
3054                _test_module_empty_input(self, mod, inp, check_size=False)
3055
3056    @onlyCUDA
3057    @largeTensorTest("12GB")
3058    def test_conv_large_nosplit(self, device):
3059        # Here we just test that the convolution correctly routes to the fallback implementation,
3060        # that is, it does not crash. The correctness of the fallback implementation should be
3061        # covered in other tests.
3062        dtype = torch.half if self.device_type == "cuda" else torch.float
3063        conv1 = nn.Conv2d(2, 2, 8, 8).to(device).to(dtype)
3064        input_large = torch.randn(1, 2, 1024, 1024 * 1024, dtype=dtype, device=device)
3065        conv1(input_large)
3066        conv2 = torch.nn.Conv2d(1, 1024, 1, 1).to(device).to(dtype)
3067        input_large = torch.randn(1, 1, 2048, 1024, dtype=dtype, device=device)
3068        conv2(input_large)
3069
3070    def test_conv_noncontig_weights(self, device):
3071        for dim in (1, 2, 3):
3072            for grouped in (False, True):
3073                nc = 3
3074                groups = 3 if grouped else 1
3075                w = torch.randn([3] * dim, device=device)
3076                w = w.expand([nc, int(nc / groups)] + list(w.shape))
3077                w = w.detach().requires_grad_()
3078                x = torch.randn(
3079                    [1, nc] + ([5] * dim), device=device, requires_grad=True
3080                )
3081                y = getattr(F, f"conv{dim}d")(x, w, groups=groups)
3082                y.sum().backward()
3083                y = getattr(F, f"conv_transpose{dim}d")(x, w, groups=groups)
3084                y.sum().backward()
3085
3086    def test_conv_noncontig_weights_and_bias(self, device):
3087        # need floats to exercise https://github.com/pytorch/pytorch/issues/16018
3088        for bias in [True, False]:
3089            conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=bias).to(
3090                device, torch.float
3091            )
3092
3093            input_nc = torch.randn(
3094                (1, 3, 224, 224, 2), device=device, dtype=torch.float
3095            )[:, :, :, :, 1]
3096            input_c = input_nc.contiguous()
3097
3098            weight_nc = torch.randn((64, 3, 7, 7, 2), device=device, dtype=torch.float)[
3099                :, :, :, :, 1
3100            ]
3101            conv1.weight = nn.Parameter(weight_nc)
3102            weight_c = conv1.weight.contiguous()
3103
3104            if bias:
3105                bias_nc = torch.randn((64, 2), device=device, dtype=torch.float)[:, 1]
3106                conv1.bias = nn.Parameter(bias_nc)
3107                bias_c = conv1.bias.contiguous()
3108
3109            out1 = conv1(input_nc)
3110            conv1.weight = nn.Parameter(weight_c)
3111            if bias:
3112                conv1.bias = nn.Parameter(bias_c)
3113            out2 = conv1(input_c)
3114            self.assertEqual(out1, out2)
3115
3116    @onlyCUDA
3117    @largeTensorTest("12GB")
3118    @skipIfRocmVersionLessThan((6, 0))
3119    def test_conv_transposed_large(self, device):
3120        dtype = torch.half if self.device_type == "cuda" else torch.float
3121        conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype)
3122        input_large = torch.randn(4096, 1, 512, 1024, dtype=dtype, device=device)
3123        # forward
3124        ret = conv(input_large)
3125        maxdiff0 = (
3126            (ret.narrow(0, 0, 1024) - conv(input_large.narrow(0, 0, 1024)))
3127            .abs_()
3128            .max()
3129            .item()
3130        )
3131        maxdiff1 = (
3132            (ret.narrow(0, 1024, 1024) - conv(input_large.narrow(0, 1024, 1024)))
3133            .abs_()
3134            .max()
3135            .item()
3136        )
3137        maxdiff2 = (
3138            (ret.narrow(0, 2048, 1024) - conv(input_large.narrow(0, 2048, 1024)))
3139            .abs_()
3140            .max()
3141            .item()
3142        )
3143        maxdiff3 = (
3144            (ret.narrow(0, 3072, 1024) - conv(input_large.narrow(0, 3072, 1024)))
3145            .abs_()
3146            .max()
3147            .item()
3148        )
3149        if self.device_type == "cuda":
3150            # cuDNN may use algorithms such as FFT that don't guarantee a diff of 0
3151            self.assertEqual(maxdiff0, 0, atol=2e-3, rtol=1e-5)
3152            self.assertEqual(maxdiff1, 0, atol=2e-3, rtol=1e-5)
3153            self.assertEqual(maxdiff2, 0, atol=2e-3, rtol=1e-5)
3154            self.assertEqual(maxdiff3, 0, atol=2e-3, rtol=1e-5)
3155        else:
3156            self.assertEqual(maxdiff0, 0)
3157            self.assertEqual(maxdiff1, 0)
3158            self.assertEqual(maxdiff2, 0)
3159            self.assertEqual(maxdiff3, 0)
3160
3161    @onlyCUDA
3162    @skipCUDAIfRocm
3163    @largeTensorTest("12GB")
3164    def test_conv_large(self, device):
3165        dtype = torch.half if self.device_type == "cuda" else torch.float
3166        conv = nn.Conv2d(2, 2, 8, 8, bias=False).to(device).to(dtype)
3167        input_large = torch.randn(4097, 2, 512, 512, dtype=dtype, device=device)
3168        # forward
3169        ret = conv(input_large)
3170        self.assertEqual(ret[:2048], conv(input_large[:2048]))
3171        self.assertEqual(ret[2048:4096], conv(input_large[2048:4096]))
3172        self.assertEqual(ret[4096:], conv(input_large[4096:]))
3173
3174        # backward
3175        conv.zero_grad()
3176        # When computing the backward, we use `max(dim=1)` to create
3177        # some sparsity. Without this sparsity, the rounding error would be
3178        # too large (as large as 1e-5) to satisfy the criterion (1e-6) of `assertEqual`
3179        ret.view(4097, -1).max(dim=1).values.sum().backward()
3180        del ret
3181        grad1 = conv.weight.grad.detach().clone()
3182        conv.zero_grad()
3183        conv(input_large[:2048]).view(2048, -1).max(dim=1).values.sum().backward()
3184        conv(input_large[2048:4096]).view(2048, -1).max(dim=1).values.sum().backward()
3185        conv(input_large[4096:]).view(1, -1).max(dim=1).values.sum().backward()
3186        grad2 = conv.weight.grad.detach().clone()
3187        # gradients are on the order of hundreds; we need to scale them to
3188        # the order of one so that we can compare
3189        scale = 1 / grad2.abs().mean()
3190        grad1 = grad1 * scale
3191        grad2 = grad2 * scale
3192        self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3)
3193
3194    @onlyCUDA
3195    @skipCUDAIfRocm
3196    @largeTensorTest("20GB", "cpu")
3197    @largeTensorTest("60GB", "cuda")
3198    def test_conv_large_batch_1(self, device):
3199        in_channels = 514
3200        dim = 2048
3201        out_channels = 1
3202        kernel_size = 3
3203        stride = 1
3204        padding = 1
3205
3206        input_tensor = torch.ones(1, in_channels, dim, dim).cuda().half()
3207        model = (
3208            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
3209            .cuda()
3210            .half()
3211        )
3212        output = model(input_tensor)
3213        model_cpu = model.cpu().float()
3214        output_cpu = model_cpu(input_tensor.float().cpu())
3215        self.assertEqual(output.cpu().float(), output_cpu, atol=1e-3, rtol=1e-3)
3216
3217    @onlyCUDA
3218    @skipCUDAIfRocm
3219    @largeTensorTest("24GB", "cpu")
3220    @largeTensorTest("20GB", "cuda")
3221    def test_conv3d_large_batch_1(self, device):
3222        x = torch.rand(1, 32, 512, 512, 256)
3223        m = torch.nn.Conv3d(32, 1, kernel_size=1, padding=0, stride=1, bias=False)
3224        yref = m(x)
3225        y = m.to(device=device)(x.to(device=device))
3226        self.assertEqual(yref, y.cpu())
3227
3228    @onlyCUDA
3229    @skipCUDAIfNoCudnn
3230    def test_contig_wrong_stride_cudnn(self, device):
3231        # x has to have batch_size 1 to test contiguous checks
3232        x = torch.randn(1, 16, 5, 5, device=device)
3233        stride = list(x.stride())
3234        stride[0] = 20
3235        # Change the stride in dimension 0; the tensor is still contiguous because size[0] is 1.
3236        x.set_(x.storage(), 0, x.size(), stride)
3237        self.assertTrue(x.is_contiguous())
3238        F.conv_transpose2d(x, torch.randn(16, 1, 1, 1, device=device))
3239        F.conv2d(x, torch.randn(1, 16, 1, 1, device=device))
3240
3241    @onlyCUDA
3242    @tf32_on_and_off(0.005)
3243    def test_Conv2d_size_1_kernel(self, device):
3244        x_cpu = torch.randn(2, 3, 5, 5)
3245        conv_cpu = torch.nn.Conv2d(3, 3, kernel_size=1)
3246        y_cpu = conv_cpu(x_cpu)
3247        y = torch.rand_like(y_cpu)
3248        y_cpu.backward(y)
3249
3250        with cudnn.flags(enabled=False):
3251            conv_cuda = torch.nn.Conv2d(3, 3, kernel_size=1).to(device)
3252            conv_cuda.bias.data.copy_(conv_cpu.bias.data)
3253            conv_cuda.weight.data.copy_(conv_cpu.weight.data)
3254            y_cuda = conv_cuda(x_cpu.to(device))
3255            y_cuda.backward(y.to(device))
3256
3257        self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
3258        self.assertEqual(
3259            conv_cpu.bias.grad.data,
3260            conv_cuda.bias.grad.data,
3261            atol=1e-5,
3262            rtol=0,
3263            exact_device=False,
3264        )
3265        self.assertEqual(
3266            conv_cpu.weight.grad.data,
3267            conv_cuda.weight.grad.data,
3268            atol=1e-5,
3269            rtol=0,
3270            exact_device=False,
3271        )
3272
3273    @onlyCUDA
3274    @tf32_on_and_off(0.005)
3275    def test_ConvTranspose2d_size_1_kernel(self, device):
3276        x_cpu = torch.randn(2, 3, 5, 5)
3277        conv_cpu = torch.nn.ConvTranspose2d(3, 3, kernel_size=1)
3278        y_cpu = conv_cpu(x_cpu)
3279        y = torch.rand_like(y_cpu)
3280        y_cpu.backward(y)
3281
3282        with cudnn.flags(enabled=False):
3283            conv_cuda = torch.nn.ConvTranspose2d(3, 3, kernel_size=1).to(device)
3284            conv_cuda.bias.data.copy_(conv_cpu.bias.data)
3285            conv_cuda.weight.data.copy_(conv_cpu.weight.data)
3286            y_cuda = conv_cuda(x_cpu.to(device))
3287            y_cuda.backward(y.to(device))
3288
3289        self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
3290        self.assertEqual(
3291            conv_cpu.bias.grad.data,
3292            conv_cuda.bias.grad.data,
3293            atol=1e-5,
3294            rtol=0,
3295            exact_device=False,
3296        )
3297        self.assertEqual(
3298            conv_cpu.weight.grad.data,
3299            conv_cuda.weight.grad.data,
3300            atol=1e-5,
3301            rtol=0,
3302            exact_device=False,
3303        )
3304
3305    @onlyCUDA
3306    def test_ConvTranspose3d_size_1_kernel(self, device):
3307        with set_default_dtype(torch.double):
3308            x_cpu = torch.randn(2, 3, 3, 5, 5)
3309            conv_cpu = torch.nn.ConvTranspose3d(3, 3, kernel_size=1)
3310            y_cpu = conv_cpu(x_cpu)
3311            y = torch.rand_like(y_cpu)
3312            y_cpu.backward(y)
3313
3314            with cudnn.flags(enabled=False):
3315                conv_cuda = torch.nn.ConvTranspose3d(3, 3, kernel_size=1).to(device)
3316                conv_cuda.bias.data.copy_(conv_cpu.bias.data)
3317                conv_cuda.weight.data.copy_(conv_cpu.weight.data)
3318                y_cuda = conv_cuda(x_cpu.to(device))
3319                y_cuda.backward(y.to(device))
3320
3321            self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
3322            self.assertEqual(
3323                conv_cpu.bias.grad.data,
3324                conv_cuda.bias.grad.data,
3325                atol=1e-5,
3326                rtol=0,
3327                exact_device=False,
3328            )
3329            self.assertEqual(
3330                conv_cpu.weight.grad.data,
3331                conv_cuda.weight.grad.data,
3332                atol=1e-5,
3333                rtol=0,
3334                exact_device=False,
3335            )
3336
3337    @dtypesIfCUDA(
3338        *floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])
3339    )
3340    @dtypes(torch.float)
3341    @torch.backends.cudnn.flags(enabled=True, benchmark=False)
3342    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
3343    def test_Conv2d_naive_groups(self, device, dtype):
3344        # Check that a grouped convolution matches two half convolutions
3345        m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype)
3346        i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
3347        output = m(i)
3348        grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype)
3349        output.backward(grad_output)
3350
3351        m1 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype)
3352        m1.weight.data.copy_(m.weight.data[:2])
3353        m1.bias.data.copy_(m.bias.data[:2])
3354        i1 = i.data[:, :2].contiguous().requires_grad_(True)
3355        output1 = m1(i1)
3356        output1.backward(grad_output[:, :2].contiguous())
3357
3358        m2 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype)
3359        m2.weight.data.copy_(m.weight.data[2:])
3360        m2.bias.data.copy_(m.bias.data[2:])
3361        i2 = i.data[:, 2:].contiguous().requires_grad_(True)
3362        output2 = m2(i2)
3363        output2.backward(grad_output[:, 2:].contiguous())
3364
3365        self.assertEqual(output, torch.cat([output1, output2], 1))
3366        self.assertEqual(
3367            i.grad.data,
3368            torch.cat([i1.grad.data, i2.grad.data], 1),
3369            atol=dtype2prec_DONTUSE[dtype],
3370            rtol=0,
3371        )
3372        self.assertEqual(
3373            m.bias.grad.data,
3374            torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
3375            atol=dtype2prec_DONTUSE[dtype],
3376            rtol=0,
3377        )
3378        self.assertEqual(
3379            m.weight.grad.data,
3380            torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
3381            atol=dtype2prec_DONTUSE[dtype],
3382            rtol=0,
3383        )
3384
3385    @dtypes(torch.double, torch.cdouble)
3386    def test_Conv2d_backward_depthwise(self, device, dtype):
3387        x = torch.randn(2, 2, 4, 20, device=device, dtype=dtype, requires_grad=True)
3388        weight = torch.randn(2, 1, 3, 5, device=device, dtype=dtype, requires_grad=True)
3389
3390        def conv2d_depthwise(x, weight):
3391            return torch.nn.functional.conv2d(
3392                x, weight, bias=None, stride=(1, 10), groups=2
3393            )
3394
3395        for cudnn_enabled in [False, True]:
3396            with torch.backends.cudnn.flags(enabled=cudnn_enabled):
3397                torch.autograd.gradcheck(conv2d_depthwise, (x, weight))
3398
3399    @onlyCPU
3400    @dtypes(torch.float, torch.double)
3401    def test_conv_thnn_nhwc(self, device, dtype):
3402        def helper(
3403            mod,
3404            n,
3405            c,
3406            h,
3407            w,
3408            out_channels,
3409            kernel_size,
3410            dilation,
3411            groups,
3412            input_format,
3413            weight_format,
3414        ):
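                # Build a conv in the given weight memory format, feed an input in the given
                # input memory format, and compare forward/backward against a contiguous reference.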
3415            input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to(
3416                memory_format=input_format
3417            )
3418            input.requires_grad_()
3419            conv = mod(
3420                c, out_channels, kernel_size, dilation=dilation, groups=groups
3421            ).to(device="cpu", dtype=dtype, memory_format=weight_format)
3422            for p in conv.parameters():
3423                p.data = torch.randint_like(p, -3, 3)
3424
3425            ref_input = input.detach().clone().contiguous().requires_grad_()
3426            ref_conv = mod(
3427                c, out_channels, kernel_size, dilation=dilation, groups=groups
3428            )
3429            # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
3430            ref_conv.load_state_dict(conv.state_dict())
3431            ref_conv = ref_conv.to(
3432                device="cpu", dtype=dtype, memory_format=torch.contiguous_format
3433            )
3434
3435            out = conv(input)
3436            ref_out = ref_conv(ref_input)
3437
3438            grad = torch.randint_like(out, -3, 3)
3439            ref_grad = grad.detach().clone().contiguous()
3440
3441            out.backward(grad)
3442            ref_out.backward(ref_grad)
3443
3444            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
3445            self.assertTrue(ref_out.is_contiguous())
3446            self.assertEqual(out, ref_out, exact_dtype=False)
3447            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
3448            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
3449            self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)
3450
3451        with torch.backends.mkldnn.flags(enabled=False):
3452            formats = [
3453                [torch.channels_last, torch.channels_last],
3454                [torch.channels_last, torch.contiguous_format],
3455                [torch.contiguous_format, torch.channels_last],
3456            ]
3457            for input_format, weight_format in formats:
3458                # non-dilated conv: thnn_conv2d normal path (with im2col)
3459                helper(
3460                    nn.Conv2d,
3461                    2,
3462                    8,
3463                    4,
3464                    4,
3465                    out_channels=4,
3466                    kernel_size=3,
3467                    dilation=1,
3468                    groups=1,
3469                    input_format=input_format,
3470                    weight_format=weight_format,
3471                )
3472                helper(
3473                    nn.Conv2d,
3474                    2,
3475                    8,
3476                    4,
3477                    4,
3478                    out_channels=8,
3479                    kernel_size=3,
3480                    dilation=1,
3481                    groups=8,
3482                    input_format=input_format,
3483                    weight_format=weight_format,
3484                )
3485                # test when the input channel count is 1 and the input is not converted to channels last
3486                helper(
3487                    nn.Conv2d,
3488                    2,
3489                    1,
3490                    10,
3491                    10,
3492                    out_channels=8,
3493                    kernel_size=3,
3494                    dilation=1,
3495                    groups=1,
3496                    input_format=torch.contiguous_format,
3497                    weight_format=torch.channels_last,
3498                )
3499                # non-dilated conv: thnn_conv2d fast path (skip im2col)
3500                helper(
3501                    nn.Conv2d,
3502                    1,
3503                    16,
3504                    56,
3505                    56,
3506                    out_channels=16,
3507                    kernel_size=1,
3508                    dilation=1,
3509                    groups=1,
3510                    input_format=input_format,
3511                    weight_format=weight_format,
3512                )
3513                # ic == oc == 1 (per group) here, so the input must be channels-last to activate the channels-last path
3514                helper(
3515                    nn.Conv2d,
3516                    1,
3517                    16,
3518                    56,
3519                    56,
3520                    out_channels=16,
3521                    kernel_size=1,
3522                    dilation=1,
3523                    groups=16,
3524                    input_format=torch.channels_last,
3525                    weight_format=weight_format,
3526                )
3527                # dilated conv: slow_conv_dilated2d
3528                helper(
3529                    nn.Conv2d,
3530                    2,
3531                    8,
3532                    11,
3533                    13,
3534                    out_channels=16,
3535                    kernel_size=3,
3536                    dilation=2,
3537                    groups=1,
3538                    input_format=input_format,
3539                    weight_format=weight_format,
3540                )
3541                helper(
3542                    nn.Conv2d,
3543                    2,
3544                    16,
3545                    11,
3546                    13,
3547                    out_channels=32,
3548                    kernel_size=3,
3549                    dilation=2,
3550                    groups=16,
3551                    input_format=input_format,
3552                    weight_format=weight_format,
3553                )
3554                # transposed-conv: slow_conv_transpose2d
3555                helper(
3556                    nn.ConvTranspose2d,
3557                    2,
3558                    8,
3559                    4,
3560                    4,
3561                    out_channels=4,
3562                    kernel_size=3,
3563                    dilation=1,
3564                    groups=1,
3565                    input_format=input_format,
3566                    weight_format=weight_format,
3567                )
3568                helper(
3569                    nn.ConvTranspose2d,
3570                    2,
3571                    8,
3572                    4,
3573                    4,
3574                    out_channels=8,
3575                    kernel_size=3,
3576                    dilation=1,
3577                    groups=8,
3578                    input_format=input_format,
3579                    weight_format=weight_format,
3580                )
3581                helper(
3582                    nn.ConvTranspose2d,
3583                    1,
3584                    16,
3585                    56,
3586                    56,
3587                    out_channels=16,
3588                    kernel_size=1,
3589                    dilation=1,
3590                    groups=1,
3591                    input_format=input_format,
3592                    weight_format=weight_format,
3593                )
3594                helper(
3595                    nn.ConvTranspose2d,
3596                    1,
3597                    16,
3598                    56,
3599                    56,
3600                    out_channels=32,
3601                    kernel_size=1,
3602                    dilation=1,
3603                    groups=16,
3604                    input_format=input_format,
3605                    weight_format=weight_format,
3606                )
3607
3608    @onlyCUDA
3609    @skipCUDAIfRocmVersionLessThan((4, 3))
3610    @skipCUDAIfNotMiopenSuggestNHWC
3611    @skipCUDAIfCudnnVersionLessThan(7603)
3612    @dtypes(torch.half, torch.float, torch.cfloat)
3613    def test_conv_cudnn_nhwc(self, device, dtype):
3614        def helper(n, c, h, w, out_channels, kernel_size, groups):
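                # Run a channels-last conv in `dtype` on CUDA and compare forward/backward
                # against an FP64 channels-first reference built from the same state_dict.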
3615            input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to(
3616                memory_format=torch.channels_last
3617            )
3618            input.requires_grad_()
3619            conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to(
3620                device="cuda", dtype=dtype, memory_format=torch.channels_last
3621            )
3622            for p in conv.parameters():
3623                p.data = torch.randint_like(p, -3, 3)
3624
3625            # use FP64 channels-first conv as reference
3626            ref_input = input.detach().clone().contiguous().double().requires_grad_()
3627            ref_conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups)
3628            # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
3629            ref_conv.load_state_dict(conv.state_dict())
3630            ref_conv = ref_conv.to(
3631                device="cuda", dtype=torch.double, memory_format=torch.contiguous_format
3632            )
3633
3634            out = conv(input)
3635            ref_out = ref_conv(ref_input)
3636
3637            grad = torch.randint_like(out, -3, 3)
3638            ref_grad = grad.detach().clone().double().contiguous()
3639
3640            out.backward(grad)
3641            ref_out.backward(ref_grad)
3642
3643            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
3644            self.assertTrue(input.grad.is_contiguous(memory_format=torch.channels_last))
3645            self.assertTrue(
3646                conv.weight.grad.is_contiguous(memory_format=torch.channels_last)
3647            )
3648
3649            self.assertTrue(ref_out.is_contiguous())
3650            self.assertTrue(ref_input.grad.is_contiguous())
3651            self.assertTrue(ref_conv.weight.grad.is_contiguous())
3652
3653            self.assertEqual(out, ref_out, exact_dtype=False)
3654            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
3655            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
3656            self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)
3657
3658        helper(2, 8, 4, 4, out_channels=4, kernel_size=3, groups=1)
3659        helper(2, 8, 4, 4, out_channels=8, kernel_size=3, groups=8)
3660        helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=1)
3661        helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16)
3662
3663    @onlyCUDA
3664    @skipCUDAIfRocm
3665    @skipCUDAIfCudnnVersionLessThan(8005)
3666    @dtypes(torch.half, torch.float)
3667    def test_conv_cudnn_ndhwc(self, device, dtype):
3668        def helper(n, c, d, h, w, out_channels, kernel_size, groups):
3669            input = torch.randint(
3670                -2, 2, (n, c, d, h, w), dtype=dtype, device=device
3671            ).to(memory_format=torch.channels_last_3d)
3672            input.requires_grad_()
3673            conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups).to(
3674                device="cuda", dtype=dtype, memory_format=torch.channels_last_3d
3675            )
3676            for p in conv.parameters():
3677                p.data = torch.randint_like(p, -2, 2)
3678
3679            # use FP64 channels-first conv as reference
3680            ref_input = input.detach().clone().contiguous().double().requires_grad_()
3681            ref_conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups)
3682            # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
3683            ref_conv.load_state_dict(conv.state_dict())
3684            ref_conv = ref_conv.to(
3685                device="cuda", dtype=torch.double, memory_format=torch.contiguous_format
3686            )
3687
3688            out = conv(input)
3689            ref_out = ref_conv(ref_input)
3690
3691            grad = torch.randint_like(out, -2, 2)
3692            ref_grad = grad.detach().clone().double().contiguous()
3693
3694            out.backward(grad)
3695            ref_out.backward(ref_grad)
3696
3697            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last_3d))
3698            self.assertTrue(
3699                input.grad.is_contiguous(memory_format=torch.channels_last_3d)
3700            )
3701            self.assertTrue(
3702                conv.weight.grad.is_contiguous(memory_format=torch.channels_last_3d)
3703            )
3704
3705            self.assertTrue(ref_out.is_contiguous())
3706            self.assertTrue(ref_input.grad.is_contiguous())
3707            self.assertTrue(ref_conv.weight.grad.is_contiguous())
3708
3709            self.assertEqual(out, ref_out, exact_dtype=False)
3710            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
3711            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
3712            self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)
3713
3714        helper(2, 8, 4, 4, 4, out_channels=4, kernel_size=3, groups=1)
3715        helper(2, 8, 4, 4, 4, out_channels=8, kernel_size=3, groups=8)
3716        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=1)
3717        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=16)
3718
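        # Helper: clone ref_conv into a fresh conv whose weight, input, and grad use the given
        # memory formats, then check the output memory format and that outputs/grads match the reference.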
3719    def _run_conv(
3720        self,
3721        layer,
3722        device,
3723        inp,
3724        grad,
3725        ref_conv,
3726        ref_input,
3727        ref_out,
3728        input_format,
3729        weight_format,
3730        grad_format,
3731        output_format,
3732    ):
3733        conv = (
3734            layer(inp.size(1), grad.size(1), ref_conv.weight.size(2)).float().to(device)
3735        )
3736        # load_state_dict will restore the stride & memory_layout on conv.weight.
3737        conv.load_state_dict(ref_conv.state_dict())
3738        weight_data = (
3739            conv.weight.detach().clone().contiguous(memory_format=weight_format)
3740        )
3741        conv.weight.data = weight_data.resize_(
3742            weight_data.size(), memory_format=weight_format
3743        )
3744        input = inp.clone().contiguous(memory_format=input_format)
3745        input.resize_(input.size(), memory_format=input_format)
3746        input = input.requires_grad_()
3747        grad = grad.contiguous(memory_format=grad_format)
3748        grad.resize_(grad.size(), memory_format=grad_format)
3749        out = conv(input)
3750        out.backward(grad)
3751        self.assertTrue(out.is_contiguous(memory_format=output_format))
3752        self.assertEqual(out, ref_out)
3753        self.assertEqual(conv.weight.grad, ref_conv.weight.grad)
3754        self.assertEqual(conv.bias.grad, ref_conv.bias.grad)
3755        self.assertEqual(input.grad, ref_input.grad)
3756
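        # Run a contiguous reference conv, then exercise every combination of input/weight/grad
        # memory formats via _run_conv, tracking the memory format expected for the output.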
3757    def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device):
3758        data = torch.randint(1, 10, (n, c, h, w), dtype=torch.float32, device=device)
3759        ref_input = data.clone().contiguous().requires_grad_(True)
3760        ref_conv = layer(c, k, filter_size).float().to(device)
3761        ref_out = ref_conv(ref_input)
3762        grad = torch.randint(1, 10, ref_out.size(), dtype=torch.float32, device="cuda")
3763        ref_out.backward(grad)
3764
3765        for w_f in [torch.contiguous_format, torch.channels_last]:
3766            for g_f in [torch.contiguous_format, torch.channels_last]:
3767                for input_format in [torch.contiguous_format, torch.channels_last]:
3768                    output_format = torch.contiguous_format
3769                    # Older versions of cuDNN have channels-last support disabled
3770                    if torch.backends.cudnn.version() >= 7603:
3771                        if input_format == torch.channels_last:
3772                            output_format = torch.channels_last
3773                        # This is because we have an N111 weight that cannot handle
3774                        # the ambiguous memory_format
3775                        if w_f == torch.channels_last:
3776                            if layer == nn.Conv2d and filter_size * c != 1:
3777                                output_format = torch.channels_last
3778                            if layer == nn.ConvTranspose2d and filter_size * k != 1:
3779                                output_format = torch.channels_last
3780                    self._run_conv(
3781                        layer,
3782                        device,
3783                        data,
3784                        grad,
3785                        ref_conv,
3786                        ref_input,
3787                        ref_out,
3788                        input_format,
3789                        w_f,
3790                        g_f,
3791                        output_format,
3792                    )
3793
3794    @onlyCUDA
3795    @skipCUDAIfRocmVersionLessThan((4, 3))
3796    @skipCUDAIfNotMiopenSuggestNHWC
3797    @skipCUDAIfCudnnVersionLessThan(7603)
3798    @tf32_on_and_off(0.05)
3799    def test_conv_cudnn_mismatch_memory_format(self, device):
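            # Each config is (n, c, h, w, k, filter_size).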
3800        configs = [
3801            [4, 2, 8, 8, 4, 2],
3802            [4, 1, 8, 8, 4, 2],
3803            [1, 1, 8, 8, 4, 2],
3804            [4, 2, 2, 8, 4, 1],
3805            [4, 2, 1, 8, 4, 1],
3806            [4, 2, 8, 8, 4, 1],
3807            [4, 1, 8, 8, 4, 1],
3808        ]
3809        for n, c, h, w, k, filter_size in configs:
3810            self._test_conv_cudnn_nhwc_nchw(
3811                nn.Conv2d, n, c, h, w, k, filter_size, device
3812            )
3813            self._test_conv_cudnn_nhwc_nchw(
3814                nn.ConvTranspose2d, n, c, h, w, k, filter_size, device
3815            )
3816
3817    # torch.half errors out on Windows with CUDA 10.1 + cuDNN 7.6.4
3818    # returning CUDNN_STATUS_BAD_PARAM
3819    # Disabling that specific test for now [see issue # 33918]
3820    @onlyCUDA
3821    @skipCUDAIfNoCudnn
3822    @dtypes(torch.float, torch.double)
3823    def test_conv_cudnn_nhwc_support(self, device, dtype):
3824        input = torch.randn(
3825            (1, 16, 1, 1), dtype=dtype, device="cuda", requires_grad=True
3826        )
3827        weight = torch.randn(
3828            (8, 16, 3, 3), dtype=dtype, device="cuda", requires_grad=True
3829        )
3830        weight = weight.to(memory_format=torch.channels_last)
3831        o = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1)
3832        self.assertTrue(o.is_contiguous(memory_format=torch.channels_last))
3833        o.sum().backward()
3834
3835    # Test that faster algorithms used for inference produce the same results
3836    # Validates depthwise3x3 bug reported in https://github.com/pytorch/pytorch/issues/60176
3837    @onlyCPU
3838    @dtypes(torch.float)
3839    def test_conv2d_no_grad(self, device, dtype):
3840        for batch in [1, 2, 3]:
3841            for groups in [1, 2, 4]:
3842                input = torch.rand(batch, groups, 8, 8, dtype=dtype, device=device)
3843                m = nn.Conv2d(
3844                    groups,
3845                    8,
3846                    kernel_size=(3, 3),
3847                    groups=groups,
3848                    dtype=dtype,
3849                    device=device,
3850                )
3851                with torch.no_grad():
3852                    output_ng = m(input)
3853                output = m(input)
3854                self.assertEqual(output, output_ng, rtol=1e-2, atol=1e-5)
3855
3856    @onlyCUDA
3857    @skipCUDAIfNoCudnn
3858    @dtypes(torch.float, torch.float16)
3859    @precisionOverride({torch.half: 0.002, torch.float: 1e-4})
3860    def test_cudnn_convolution_relu(self, device, dtype):
3861        for batch, groups, image_size, kernel_size, memory_format in product(
3862            (1, 2, 3),
3863            (1, 2, 4),
3864            ((1, 1), (8, 8)),
3865            ((1, 1), (3, 3)),
3866            (torch.channels_last, torch.contiguous_format),
3867        ):
3868            if image_size[0] < kernel_size[0]:
3869                continue
3870            inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device)
3871            w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device)
3872            conv2d_out = torch.conv2d(inp, w, None, (1, 1), (0, 0), (1, 1), 1)
3873            inp = inp.to(memory_format=memory_format)
3874            w = w.to(memory_format=memory_format)
3875            if torch.version.hip:
3876                cudnn_out = torch.miopen_convolution_relu(
3877                    inp, w, None, (1, 1), (0, 0), (1, 1), 1
3878                )
3879            else:
3880                cudnn_out = torch.cudnn_convolution_relu(
3881                    inp, w, None, (1, 1), (0, 0), (1, 1), 1
3882                )
3883            self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format))
3884            if tf32_is_not_fp32() and dtype == torch.float:
3885                self.assertEqual(conv2d_out.relu(), cudnn_out, atol=4e-3, rtol=0.006)
3886            else:
3887                self.assertEqual(conv2d_out.relu(), cudnn_out)
3888
3889    @onlyCUDA
3890    @skipCUDAIfNoCudnn
3891    @dtypes(torch.float, torch.float16)
3892    @precisionOverride({torch.half: 0.002, torch.float: 1e-4})
3893    def test_cudnn_convolution_add_relu(self, device, dtype):
3894        for batch, groups, image_size, kernel_size, memory_format in product(
3895            (1, 2, 3),
3896            (1, 2, 4),
3897            ((1, 1), (8, 8)),
3898            ((1, 1), (3, 3)),
3899            (torch.channels_last, torch.contiguous_format),
3900        ):
3901            if image_size[0] < kernel_size[0]:
3902                continue
3903            inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device)
3904            w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device)
3905            conv2d_out = torch.conv2d(inp, w, None, (1, 1), (0, 0), (1, 1), 1)
3906            alpha = 2.0
3907            z = torch.randn_like(conv2d_out)
3908
3909            inp = inp.to(memory_format=memory_format)
3910            w = w.to(memory_format=memory_format)
3911            z = z.to(memory_format=memory_format)
3912            if torch.version.hip:
3913                cudnn_out = torch.miopen_convolution_add_relu(
3914                    inp, w, z, alpha, None, (1, 1), (0, 0), (1, 1), 1
3915                )
3916            else:
3917                cudnn_out = torch.cudnn_convolution_add_relu(
3918                    inp, w, z, alpha, None, (1, 1), (0, 0), (1, 1), 1
3919                )
3920
3921            self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format))
3922            if tf32_is_not_fp32() and dtype == torch.float:
3923                self.assertEqual(
3924                    F.relu(conv2d_out + alpha * z), cudnn_out, atol=2e-3, rtol=0.006
3925                )
3926            else:
3927                self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out)
3928
3929    @onlyCUDA
3930    @skipCUDAIfRocm
3931    @skipCUDAIfCudnnVersionLessThan(7603)
3932    def test_convert_conv2d_weight_memory_format(self, device):
3933        input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device)
3934        model = nn.Sequential(nn.Conv2d(8, 4, 3), nn.BatchNorm2d(4)).to(device).float()
3935        for memory_format in [torch.channels_last, torch.contiguous_format]:
3936            model = nn.utils.convert_conv2d_weight_memory_format(model, memory_format)
3937            out = model(input)
3938            self.assertTrue(out.is_contiguous(memory_format=memory_format))
3939
3940        model = (
3941            nn.Sequential(nn.ConvTranspose2d(8, 4, 3), nn.BatchNorm2d(4))
3942            .to(device)
3943            .float()
3944        )
3945        for memory_format in [torch.channels_last, torch.contiguous_format]:
3946            model = nn.utils.convert_conv2d_weight_memory_format(model, memory_format)
3947            out = model(input)
3948            self.assertTrue(out.is_contiguous(memory_format=memory_format))
3949
3950    @onlyCUDA
3951    @skipCUDAIfRocm
3952    @skipCUDAIfCudnnVersionLessThan(7603)
3953    def test_convert_conv3d_weight_memory_format(self, device):
3954        input = torch.randint(
3955            1, 10, (2, 8, 4, 4, 4), dtype=torch.float32, device=device
3956        )
3957        model = (
3958            nn.Sequential(nn.ConvTranspose3d(8, 4, 3), nn.BatchNorm3d(4))
3959            .to(device)
3960            .float()
3961        )
3962        for memory_format in [torch.channels_last_3d, torch.contiguous_format]:
3963            model = nn.utils.convert_conv3d_weight_memory_format(model, memory_format)
3964            out = model(input)
3965            self.assertTrue(out.is_contiguous(memory_format=memory_format))
3966
3967    def test_conv_double_backward_strided_with_3D_input_and_weight(self, device):
3968        # Test that _convolution_double_backward() outputs the correct grad shapes
3969        # for 3D input / weight when stride > 1. This is an ad-hoc regression test for a
3970        # specific case that was uncovered during the convolution consolidation effort.
3971        # The test can be safely deleted if _convolution_double_backward() is removed.
3972
3973        input = torch.randn(2, 3, 6, device=device)
3974        weight = torch.randn(3, 3, 3, device=device)
3975        bias = torch.randn(3, device=device)
3976        stride = (2,)
3977        padding = (1,)
3978        dilation = (1,)
3979        transposed = False
3980        output_padding = (0,)
3981        groups = 1
3982        output = torch.ops.aten.convolution(
3983            input,
3984            weight,
3985            bias,
3986            stride,
3987            padding,
3988            dilation,
3989            transposed,
3990            output_padding,
3991            groups,
3992        )
3993
3994        ggI = torch.randn(input.shape, device=device)
3995        ggW = torch.randn(weight.shape, device=device)
3996        ggB = torch.randn(bias.shape, device=device)
3997        gO = torch.randn(output.shape, device=device)
3998        output_mask = [True, True, True]
3999        (
4000            grad_grad_output,
4001            grad_input,
4002            grad_weight,
4003        ) = torch.ops.aten._convolution_double_backward(
4004            ggI,
4005            ggW,
4006            ggB,
4007            gO,
4008            weight,
4009            input,
4010            stride,
4011            padding,
4012            dilation,
4013            transposed,
4014            output_padding,
4015            groups,
4016            output_mask,
4017        )
4018
4019        # Make sure the correct shapes are computed.
4020        self.assertEqual(grad_grad_output.shape, gO.shape)
4021        self.assertEqual(grad_input.shape, input.shape)
4022        self.assertEqual(grad_weight.shape, weight.shape)
4023
4024    @onlyCUDA
4025    @largeTensorTest("40GB")
4026    @largeTensorTest("24GB", "cpu")
4027    def test_conv3d_64bit_indexing(self, device):
4028        x = torch.rand(1, 32, 512, 512, 256)
4029        m = torch.nn.Conv3d(32, 1, kernel_size=1, padding=0, stride=1, bias=False)
4030        yref = m(x)
4031        y = m.to(device=device)(x.to(device=device))
4032        self.assertEqual(yref, y)
4033
4034
4035instantiate_device_type_tests(TestConvolutionNNDeviceType, globals())
4036instantiate_parametrized_tests(TestConvolutionNN)
4037
4038if __name__ == "__main__":
4039    run_tests()
4040