# Owner(s): ["module: intel"]

import itertools
import math
import unittest
from itertools import product

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch._C._dynamo.guards import assert_size_stride
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import tf32_is_not_fp32
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyXPU,
)
from torch.testing._internal.common_dtype import floating_types_and
from torch.testing._internal.common_nn import _test_module_empty_input, NNTestCase
from torch.testing._internal.common_utils import (
    dtype2prec_DONTUSE,
    gradcheck,
    gradgradcheck,
    parametrize as parametrize_test,
    run_tests,
    set_default_dtype,
    TEST_SCIPY,
    TEST_WITH_ROCM,
)


AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
if TEST_SCIPY:
    import scipy.ndimage
    import scipy.signal


class TestConvolutionNNDeviceType(NNTestCase):
    def run_conv_double_back_test(
        self,
        kern,
        stride,
        padding,
        chan_in,
        chan_out,
        batch_size,
        inp_size,
        dilation,
        no_weight,
        groups=1,
        use_xpu=False,
        use_bias=True,
        dtype=torch.double,
    ):
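        """Build a conv2d graph and verify double backward via gradgradcheck.

        For float32 inputs it only checks that the first-order gradient is
        itself differentiable (returns ``g.requires_grad``) instead of running
        the costly second-order numerical check.
        """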
        device = torch.device("xpu" if use_xpu else "cpu")
        x = torch.randn(
            batch_size,
            chan_in,
            inp_size,
            inp_size,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        weight = torch.randn(
            chan_out,
            chan_in // groups,
            kern,
            kern,
            device=device,
            dtype=dtype,
            requires_grad=not no_weight,
        )
        if use_bias:
            bias = torch.randn(chan_out, device=device, dtype=dtype, requires_grad=True)
        else:
            bias = None

        def func(*inputs):
            if use_bias:
                lx, lweight, lbias = inputs
            else:
                lx, lweight = inputs
                lbias = None
            out = F.conv2d(lx, lweight, lbias, stride, padding, dilation, groups)
            return out

        if use_bias:
            inputs = x, weight, bias
        else:
            inputs = x, weight

        dummy_out = func(*inputs)
        grad_y = torch.randn_like(
            dummy_out, device=device, dtype=dtype, requires_grad=True
        )

        if dtype == torch.float:
            (g,) = torch.autograd.grad(dummy_out.sum(), x, create_graph=True)
            return g.requires_grad

        return gradgradcheck(func, inputs, (grad_y,))

    @dtypes(*floating_types_and(torch.half, torch.bfloat16))
    def test_Conv2d_large_workspace(self, device, dtype):
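        """Run conv2d forward/backward on shapes that stress a large
        convolution workspace. Note: ``run_test`` accepts a ``benchmark``
        flag but does not currently use it in the body.
        """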
        sizes = [
            (1, 256, 109, 175),
            (1, 256, 80, 128),
            (1, 256, 120, 192),
        ]

        def run_test(benchmark):
            conv = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1).to(device, dtype)
            for size in sizes:
                x = torch.randn(size, device=device, dtype=dtype)
                out = conv(x.detach().clone().requires_grad_())
                out.backward(torch.ones_like(out))

        run_test(benchmark=False)
        run_test(benchmark=True)

    @dtypes(torch.half, torch.float)
    def test_ConvTranspose2d_large_output_padding(self, device, dtype):
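        """Chain three ConvTranspose2d layers with output_padding=1 and make
        sure the upsampled result is differentiable end to end."""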
        net1 = torch.nn.ConvTranspose2d(
            128, 64, kernel_size=3, stride=2, padding=1, output_padding=1
        ).to(device=device, dtype=dtype)
        net2 = torch.nn.ConvTranspose2d(
            64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
        ).to(device=device, dtype=dtype)
        net3 = torch.nn.ConvTranspose2d(
            32, 3, kernel_size=3, stride=2, padding=1, output_padding=1
        ).to(device=device, dtype=dtype)
        x = torch.rand(1, 128, 6, 6, device=device, dtype=dtype, requires_grad=True)
        x = net1(x)
        x = net2(x)
        x = net3(x)
        x.backward(torch.randn_like(x))

    @dtypes(torch.float, torch.double, torch.half)
    def test_Conv2d_depthwise_naive_groups(self, device, dtype):
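        """Check a grouped (depthwise) Conv2d against two independent
        single-group convolutions: outputs, input gradients, and weight/bias
        gradients must all match the concatenated per-group results."""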
        if dtype == torch.half and "xpu" in device:
            self.skipTest(
                "The fp16 accuracy issue will be fixed in oneDNN v3.4"
            )
        for depth_multiplier in [1, 2]:
            m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(
                device, dtype
            )
            i = (
                torch.randn(2, 2, 6, 6, device=device, dtype=dtype)
                .div_(2)
                .requires_grad_()
            )
            output = m(i)
            grad_output = (
                torch.randn(2, 2 * depth_multiplier, 4, 4, device=device, dtype=dtype)
                / 2
            )
            output.backward(grad_output)

            offset = 1 * depth_multiplier

            m1 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m1.weight.data = m.weight.data[:offset].clone()
            m1.bias.data = m.bias.data[:offset].clone()
            i1 = i.detach()[:, :1].clone().requires_grad_()
            output1 = m1(i1)
            output1.backward(grad_output[:, :offset].contiguous())

            m2 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m2.weight.data.copy_(m.weight.data[offset:])
            m2.bias.data.copy_(m.bias.data[offset:])
            i2 = i.detach()[:, 1:].clone().requires_grad_()
            output2 = m2(i2)
            output2.backward(grad_output[:, offset:].contiguous())

            self.assertEqual(
                output,
                torch.cat([output1, output2], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                i.grad.data,
                torch.cat([i1.grad.data, i2.grad.data], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.bias.grad.data,
                torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.weight.grad.data,
                torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )

    @dtypes(torch.float, torch.double, torch.half)
    def test_Conv3d_depthwise_naive_groups(self, device, dtype):
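        """Same group-vs-naive comparison as the Conv2d variant above, but for
        Conv3d, with looser tolerances for the output and weight gradients."""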
        if dtype == torch.half and "xpu" in device:
            self.skipTest(
                "The fp16 accuracy issue will be fixed in oneDNN v3.4"
            )
        for depth_multiplier in [1, 2]:
            m = nn.Conv3d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(
                device, dtype
            )
            i = (
                torch.randn(2, 2, 6, 6, 6, device=device, dtype=dtype)
                .div_(2)
                .requires_grad_()
            )
            output = m(i)
            grad_output = (
                torch.randn(
                    2, 2 * depth_multiplier, 4, 4, 4, device=device, dtype=dtype
                )
                / 2
            )
            output.backward(grad_output)

            offset = 1 * depth_multiplier

            m1 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m1.weight.data = m.weight.data[:offset].clone()
            m1.bias.data = m.bias.data[:offset].clone()
            i1 = i.detach()[:, :1].clone().requires_grad_()
            output1 = m1(i1)
            output1.backward(grad_output[:, :offset].contiguous())

            m2 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m2.weight.data.copy_(m.weight.data[offset:])
            m2.bias.data.copy_(m.bias.data[offset:])
            i2 = i.detach()[:, 1:].clone().requires_grad_()
            output2 = m2(i2)
            output2.backward(grad_output[:, offset:].contiguous())
            atol, rtol = (3e-4, 3e-2)

            self.assertEqual(
                output, torch.cat([output1, output2], 1), atol=atol, rtol=rtol
            )
            self.assertEqual(
                i.grad.data,
                torch.cat([i1.grad.data, i2.grad.data], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.bias.grad.data,
                torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.weight.grad.data,
                torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
                atol=atol,
                rtol=rtol,
            )

    @dtypes(torch.float, torch.double, torch.half)
    def test_noncontig_conv_grad(self, device, dtype):
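        """Backward through Conv2d with a non-contiguous grad_output must give
        the same input gradient as a contiguous copy of that grad_output."""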
        module = nn.Conv2d(3, 5, kernel_size=3, padding=1).to(device, dtype)
        input = torch.randn(
            2, 3, 10, 10, dtype=dtype, device=device, requires_grad=True
        )
        output = module(input)

        grad = torch.randn(2, 2, 5, 10, 10, dtype=dtype, device=device)[:, 1]
        assert not grad.is_contiguous()
        output.backward(grad, retain_graph=True)
        self.assertIsNotNone(input.grad)
        result = input.grad.data.clone()
        input.grad.data.zero_()

        output.backward(grad.contiguous())
        self.assertEqual(
            result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0
        )

    @dtypes(torch.double)
    def test_conv_double_backward(self, device, dtype):
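        """Sweep kernel size, input size, and dilation through
        run_conv_double_back_test on the XPU device."""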
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            batch_size = 1
            for kern, inp_size, dilations in [(3, 5, [1, 2]), (4, 9, [1])]:
                for stride, padding, chan_in, chan_out, dilation in product(
                    [1], [2], [2], [3], dilations
                ):
                    no_weight = stride == 2
                    result = self.run_conv_double_back_test(
                        kern,
                        stride,
                        padding,
                        chan_in,
                        chan_out,
                        batch_size,
                        inp_size,
                        dilation,
                        no_weight,
                        use_xpu=True,
                        dtype=dtype,
                    )
                    self.assertTrue(result, "Conv double backward test failed")

    def test_conv_double_backward_no_bias(self):
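        """Double backward on a strided conv; note that despite the name this
        test runs with use_bias=True."""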
        kern, stride = 3, 2
        chan_in, chan_out = 2, 4
        batch_size, inp_size = 2, 5
        padding, dilation = 1, 1
        no_weight, use_bias = False, True
        result = self.run_conv_double_back_test(
            kern,
            stride,
            padding,
            chan_in,
            chan_out,
            batch_size,
            inp_size,
            dilation,
            no_weight,
            use_bias=use_bias,
        )
        self.assertTrue(result, "Conv double backward test failed")

    def test_conv_double_backward_groups(self):
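        """Double backward for a grouped convolution (groups=2)."""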
        kern, stride, padding = 3, 1, 2
        chan_in, chan_out = 2, 4
        batch_size, inp_size, dilation = 2, 6, 1
        no_weight = False
        groups = 2
        result = self.run_conv_double_back_test(
            kern,
            stride,
            padding,
            chan_in * groups,
            chan_out * groups,
            batch_size,
            inp_size,
            dilation,
            no_weight,
            groups=groups,
        )
        self.assertTrue(result, "Conv double backward test failed")

    def test_conv_double_backward_stride(self):
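        """Double backward with stride=2 over several padding/dilation
        combinations (the return value is not asserted)."""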
        batch_size = 2
        for kern, inp_size, dilations in [(3, 5, [1, 2]), (3, 7, [1])]:
            for stride, padding, chan_in, chan_out, dilation in product(
                [2], [0, 1], [1], [2], dilations
            ):
                no_weight = False
                self.run_conv_double_back_test(
                    kern,
                    stride,
                    padding,
                    chan_in,
                    chan_out,
                    batch_size,
                    inp_size,
                    dilation,
                    no_weight,
                )

    @dtypes(torch.float)
    def test_conv1d_same_padding(self, device, dtype):
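        """padding="same" must preserve the input length (ceil(in/stride))
        and agree with the equivalent explicit padding; in the asymmetric
        case the extra explicitly padded column is sliced off first."""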
        test_args = [
            range(50, 55),
            [1, 2, 3, 8],
            range(1, 4),
            [1],
        ]
        for in_size, k_size, dilation, stride in itertools.product(*test_args):
            x = torch.rand(1, 1, in_size, device=device, dtype=dtype)
            y = torch.rand(1, 1, k_size, device=device, dtype=dtype)
            z = F.conv1d(x, y, padding="same", dilation=dilation, stride=stride)
            self.assertEqual(z.size(2), int(math.ceil(in_size / stride)))

        x = torch.rand(1, 1, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 3, device=device, dtype=dtype)
        expect = F.conv1d(x, y, padding=1)
        actual = F.conv1d(x, y, padding="same")
        self.assertEqual(expect, actual)

        x = torch.rand(1, 1, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 4, device=device, dtype=dtype)
        expect = F.conv1d(x, y, padding=3, dilation=2)
        actual = F.conv1d(x, y, padding="same", dilation=2)
        self.assertEqual(expect, actual)

        expect = F.conv1d(x, y, padding=5, dilation=3)[..., 1:]
        actual = F.conv1d(x, y, padding="same", dilation=3)
        self.assertEqual(expect, actual)

    @dtypes(torch.float)
    def test_conv3d_same_padding(self, device, dtype):
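        """conv3d with padding="same" must match explicit per-dimension
        padding, including the asymmetric cases that need slicing."""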
        rtol, atol = None, None
        x = torch.rand(1, 1, 10, 11, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 1, 2, 5, device=device, dtype=dtype)
        expect = F.conv3d(x, y, padding=(0, 1, 2))[..., :, 1:, :]
        actual = F.conv3d(x, y, padding="same")
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

        expect = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
        actual = F.conv3d(x, y, padding="same", dilation=2)
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

        y = torch.rand(1, 1, 4, 4, 4, device=device, dtype=dtype)
        expect = F.conv3d(x, y, padding=5, dilation=3)[..., 1:, 1:, 1:]
        actual = F.conv3d(x, y, padding="same", dilation=3)
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

    @dtypes(torch.float)
    def test_conv1d_valid_padding(self, device, dtype):
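        """padding="valid" must be equivalent to no padding at all."""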
        x = torch.rand(1, 1, 10, device=device, dtype=dtype)
        y = torch.rand(1, 1, 4, device=device, dtype=dtype)
        expect = F.conv1d(x, y)
        actual = F.conv1d(x, y, padding="valid")
        self.assertEqual(expect, actual)

    @dtypes(torch.float)
    def test_conv2d_valid_padding(self, device, dtype):
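        """Same as the conv1d case: padding="valid" equals padding=0."""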
        x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype)
        y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype)
        expect = F.conv2d(x, y)
        actual = F.conv2d(x, y, padding="valid")
        self.assertEqual(expect, actual)

    @dtypes(torch.float)
    def test_conv3d_valid_padding(self, device, dtype):
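        """Same check for conv3d: padding="valid" equals padding=0."""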
        x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device)
        y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device)
        expect = F.conv3d(x, y)
        actual = F.conv3d(x, y, padding="valid")
        self.assertEqual(expect, actual)

    @dtypes(torch.float)
    def test_conv1d_same_padding_backward(self, device, dtype):
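        """Gradients with padding="same" must match gradients from the
        equivalent explicit padding (with slicing in the asymmetric case)."""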
        x = torch.rand(1, 1, 12, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)

        z = F.conv1d(x, y, padding=3, dilation=2)
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv1d(x, y, padding="same", dilation=2)
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        x.grad, y.grad = None, None

        z = F.conv1d(x, y, padding=2)[..., 1:]
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv1d(x, y, padding="same")
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)

    @dtypes(torch.float)
    def test_conv2d_same_padding_backward(self, device, dtype):
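        """2D version of the padding="same" gradient-equivalence check."""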
        x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype, requires_grad=True)
        y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype, requires_grad=True)

        z = F.conv2d(x, y, padding=(3, 4), dilation=2)
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv2d(x, y, padding="same", dilation=2)
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        x.grad, y.grad = None, None

        y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype, requires_grad=True)
        z = F.conv2d(x, y, padding=2)[..., 1:, 1:]
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv2d(x, y, padding="same")
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)

    @dtypes(torch.double)
    def test_conv3d_same_padding_backward(self, device, dtype):
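        """3D version of the padding="same" gradient check, plus full
        gradcheck/gradgradcheck (including forward-mode AD) in double
        precision."""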
        x = torch.rand(1, 1, 1, 11, 12, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 1, 2, 5, dtype=dtype, device=device, requires_grad=True)
        z = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv3d(x, y, padding="same", dilation=2)
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        x.grad, y.grad = None, None
        gradcheck(
            lambda x, y: F.conv3d(x, y, padding="same", dilation=2),
            (x, y),
            check_forward_ad=True,
            nondet_tol=1e-5,
        )
        gradgradcheck(
            lambda x, y: F.conv3d(x, y, padding="same", dilation=2),
            (x, y),
            check_fwd_over_rev=True,
        )

        y = torch.rand(1, 1, 1, 4, 4, dtype=dtype, device=device, requires_grad=True)
        z = F.conv3d(x, y, padding=2)[..., 1:, 1:]
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv3d(x, y, padding="same")
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        gradcheck(
            lambda x, y: F.conv3d(x, y, padding="same"),
            (x, y),
            check_forward_ad=True,
            nondet_tol=1e-5,
        )
        gradgradcheck(
            lambda x, y: F.conv3d(x, y, padding="same"),
            (x, y),
            check_fwd_over_rev=True,
        )

    @dtypes(torch.float)
    def test_conv1d_valid_padding_backward(self, device, dtype):
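        """Gradients with padding="valid" must match gradients with padding=0."""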
        x = torch.rand(1, 1, 10, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)
        F.conv1d(x, y, padding=0).sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None
        F.conv1d(x, y, padding="valid").sum().abs().backward()
        gx_actual, gy_actual = x.grad, y.grad
        self.assertEqual(gx_expect, gx_actual)
        self.assertEqual(gy_expect, gy_actual)

    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(torch.float)
    @parametrize_test("mode", ("valid", "same"))
    def test_conv1d_vs_scipy(self, device, dtype, mode):
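        """Cross-check conv1d against scipy.signal.convolve. The weight is
        flipped because scipy computes a true convolution while conv1d is a
        cross-correlation."""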
        t = make_tensor((1, 10), device=device, dtype=dtype)
        feat_dim = t.shape[1]
        weight_even = make_tensor((1, 1, 4), device=device, dtype=dtype)
        weight_odd = make_tensor((1, 1, 5), device=device, dtype=dtype)

        def _test(t, weight, mode):
            t_a = t.view(-1).cpu().numpy()
            w_a = weight.view(-1).cpu().numpy()
            expected = scipy.signal.convolve(t_a, w_a, mode=mode)

            kwargs = {"padding": mode}
            if mode == "same":
                p = weight.shape[2] // 2
                t = torch.nn.functional.pad(t, (p, p))
                kwargs.pop("padding")

            weight_flipped = torch.flip(weight, (2,))
            actual = torch.nn.functional.conv1d(t, weight_flipped, **kwargs).squeeze(0)
            if mode == "same":
                actual = actual[:feat_dim]

            self.assertEqual(actual, expected, atol=2e-5, rtol=2e-5)

        with set_default_dtype(torch.float):
            _test(t, weight_even, mode)
            _test(t, weight_odd, mode)

    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(torch.float)
    @parametrize_test("mode", ("valid", "same"))
    def test_conv2d_vs_scipy(self, device, dtype, mode):
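        """Cross-check conv2d against scipy.signal.convolve2d (weight flipped
        on both spatial dims for the correlation/convolution mismatch)."""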
        t = make_tensor((1, 5, 10), device=device, dtype=dtype)
        weight_even = make_tensor((1, 1, 2, 4), device=device, dtype=dtype)
        weight_odd = make_tensor((1, 1, 3, 5), device=device, dtype=dtype)

        def _test(t, weight, mode):
            t_a = t.squeeze(0).cpu().numpy()
            w_a = weight.squeeze(0).squeeze(0).cpu().numpy()
            expected = scipy.signal.convolve2d(t_a, w_a, mode=mode)

            kwargs = {"padding": mode}
            if mode == "same":
                left_right_pad = weight.shape[3] // 2
                top_bottom_pad = weight.shape[2] // 2
                p = (left_right_pad, left_right_pad, top_bottom_pad, top_bottom_pad)
                t = torch.nn.functional.pad(t, p)
                kwargs.pop("padding")

            weight_flipped = torch.flip(weight, (2, 3))
            actual = torch.nn.functional.conv2d(t, weight_flipped, **kwargs).squeeze(0)
            if mode == "same":
                actual = actual[:5, :10]

            self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6)

        with set_default_dtype(torch.float):
            _test(t, weight_even, mode)
            _test(t, weight_odd, mode)

    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(torch.float)
    @parametrize_test("mode", ("valid", "same"))
    def test_conv3d_vs_scipy(self, device, dtype, mode):
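        """Cross-check conv3d against scipy.signal.convolve on a volume
        (weight flipped on all three spatial dims)."""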
        t = make_tensor((1, 5, 5, 10), device=device, dtype=dtype)
        weight_even = make_tensor((1, 1, 2, 2, 4), device=device, dtype=dtype)
        weight_odd = make_tensor((1, 1, 2, 3, 5), device=device, dtype=dtype)

        def _test(t, weight, mode):
            t_a = t.squeeze(0).cpu().numpy()
            w_a = weight.squeeze(0).squeeze(0).cpu().numpy()
            expected = scipy.signal.convolve(t_a, w_a, mode=mode)
            kwargs = {"padding": mode}
            if mode == "same":
                left_right_pad = weight.shape[4] // 2
                top_bottom_pad = weight.shape[3] // 2
                front_back_pad = weight.shape[2] // 2
                p = (
                    left_right_pad,
                    left_right_pad,
                    top_bottom_pad,
                    top_bottom_pad,
                    front_back_pad,
                    front_back_pad,
                )
                t = torch.nn.functional.pad(t, p)
                kwargs.pop("padding")
            weight_flipped = torch.flip(weight, (2, 3, 4))
            actual = torch.nn.functional.conv3d(t, weight_flipped, **kwargs).squeeze(0)
            if mode == "same":
                actual = actual[:5, :5, :10]
            self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6)

        with set_default_dtype(torch.float):
            _test(t, weight_even, mode)
            _test(t, weight_odd, mode)

    @dtypes(torch.float)
    def test_conv2d_valid_padding_backward(self, device, dtype):
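        """2D version of the padding="valid" gradient-equivalence check."""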
        x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype, requires_grad=True)
        y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype, requires_grad=True)
        F.conv2d(x, y, padding=0).sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None
        F.conv2d(x, y, padding="valid").sum().abs().backward()
        gx_actual, gy_actual = x.grad, y.grad
        self.assertEqual(gx_expect, gx_actual)
        self.assertEqual(gy_expect, gy_actual)

    @dtypes(torch.double)
    def test_conv3d_valid_padding_backward(self, device, dtype):
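        """3D version of the padding="valid" gradient check, plus
        gradcheck/gradgradcheck with forward-mode AD."""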
        x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device, requires_grad=True)
        F.conv3d(x, y, padding=0).sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        F.conv3d(x, y, padding="valid").sum().abs().backward()
        gx_actual, gy_actual = x.grad, y.grad
        self.assertEqual(gx_expect, gx_actual)
        self.assertEqual(gy_expect, gy_actual)
        gradcheck(
            lambda x, y: F.conv3d(x, y, padding="valid"),
            (x, y),
            check_forward_ad=True,
        )
        gradgradcheck(
            lambda x, y: F.conv3d(x, y, padding="valid"),
            (x, y),
            check_fwd_over_rev=True,
        )

    @parametrize_test("N", range(2, 4), name_fn=lambda N: f"ConvTranspose{N}d")
    def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N):
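        """ConvTransposeNd must honor an explicit output_size when the input
        has no batch dimension."""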
692*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn((1, 15, 13) if N == 2 else (1, 15, 13, 13), device=device)
693*da0073e9SAndroid Build Coastguard Worker        output_size = (1, 240, 200) if N == 2 else (1, 240, 200, 200)
694*da0073e9SAndroid Build Coastguard Worker        ConvTransposeNd = getattr(nn, f"ConvTranspose{N}d")
695*da0073e9SAndroid Build Coastguard Worker        m = ConvTransposeNd(
696*da0073e9SAndroid Build Coastguard Worker            1, 1, kernel_size=16, stride=16, padding=7, bias=False, device=device
697*da0073e9SAndroid Build Coastguard Worker        )
698*da0073e9SAndroid Build Coastguard Worker        output = m(inp, output_size=output_size)
699*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.shape, output_size)
700*da0073e9SAndroid Build Coastguard Worker
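    # A conv module with in_channels=0 should accept an input whose channel
    # dimension is 0, while an input with a zero-sized spatial dimension (but
    # nonzero channels) should still raise the usual shape-mismatch error.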
    @dtypes(torch.float)
    def test_conv_empty_channel(self, device, dtype):
        in_channels = 0
        mod = torch.nn.Conv1d(in_channels, 8, 2, stride=2, dtype=dtype).to(device)
        inp = torch.randn(2, 0, 15, device=device, dtype=dtype)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
            inp = torch.randn(2, 1, 0, device=device, dtype=dtype)
            mod(inp)

        mod = torch.nn.Conv2d(in_channels, 33, 3, stride=2, dtype=dtype).to(device)
        inp = torch.randn(2, 0, 50, 100, device=device, dtype=dtype)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
            inp = torch.randn(2, 1, 40, 0, device=device, dtype=dtype)
            mod(inp)

        mod = torch.nn.Conv3d(in_channels, 33, 3, stride=2, dtype=dtype).to(device)
        inp = torch.randn(2, 0, 50, 20, 40, device=device, dtype=dtype)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
            inp = torch.randn(2, 1, 50, 0, 40, device=device, dtype=dtype)
            mod(inp)

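    # The next three tests feed batch-size-0 inputs through grouped, grouped
    # transposed, and plain transposed convolutions; via the shared
    # _test_module_empty_input helper they should run forward (and presumably
    # backward) without erroring.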
    def test_group_conv_empty(self, device):
        mod = torch.nn.Conv2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(
            device
        )
        inp = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

    def test_group_convTranspose_empty(self, device):
        mod = torch.nn.ConvTranspose2d(
            4, 4, stride=2, kernel_size=3, padding=1, groups=4
        ).to(device)
        inp = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

    def test_convTranspose_empty(self, device):
        mod = torch.nn.ConvTranspose2d(4, 4, stride=2, kernel_size=3, padding=1).to(
            device
        )
        inp = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

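    # The first input below has exactly 2 * 1024 * 1024 * 1024 = 2**31
    # elements, so this exercises convolutions large enough to overflow 32-bit
    # indexing, presumably checking that the backend can handle them in one
    # shot ("nosplit") rather than by splitting the input.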
    def test_conv_large_nosplit(self, device):
        dtype = torch.half
        conv1 = nn.Conv2d(2, 2, 8, 8).to(device).to(dtype)
        input_large = torch.randn(1, 2, 1024, 1024 * 1024, dtype=dtype, device=device)
        conv1(input_large)
        conv2 = torch.nn.Conv2d(1, 1024, 1, 1).to(device).to(dtype)
        input_large = torch.randn(1, 1, 2048, 1024, dtype=dtype, device=device)
        conv2(input_large)

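    # conv{1,2,3}d and conv_transpose{1,2,3}d should accept non-contiguous
    # weights (produced here via expand) and still compute gradients, with and
    # without groups.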
    def test_conv_noncontig_weights(self, device):
        for dim in (1, 2, 3):
            for grouped in (False, True):
                nc = 3
                groups = 3 if grouped else 1
                w = torch.randn([3] * dim, device=device)
                w = w.expand([nc, int(nc / groups)] + list(w.shape))
                w = w.detach().requires_grad_()
                x = torch.randn(
                    [1, nc] + ([5] * dim), device=device, requires_grad=True
                )
                y = getattr(F, f"conv{dim}d")(x, w, groups=groups)
                y.sum().backward()
                y = getattr(F, f"conv_transpose{dim}d")(x, w, groups=groups)
                y.sum().backward()

    def test_conv_noncontig_weights_and_bias(self, device):
        for bias in [True, False]:
            conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=bias).to(
                device, torch.float
            )
            input_nc = torch.randn(
                (1, 3, 224, 224, 2), device=device, dtype=torch.float
            )[:, :, :, :, 1]
            input_c = input_nc.contiguous()
            weight_nc = torch.randn((64, 3, 7, 7, 2), device=device, dtype=torch.float)[
                :, :, :, :, 1
            ]
            conv1.weight = nn.Parameter(weight_nc)
            weight_c = conv1.weight.contiguous()
            if bias:
                bias_nc = torch.randn((64, 2), device=device, dtype=torch.float)[:, 1]
                conv1.bias = nn.Parameter(bias_nc)
                bias_c = conv1.bias.contiguous()
            out1 = conv1(input_nc)
            conv1.weight = nn.Parameter(weight_c)
            if bias:
                conv1.bias = nn.Parameter(bias_c)
            out2 = conv1(input_c)
            self.assertEqual(out1, out2)

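    # A 1x1, stride-1 transposed conv over a 4096-sample batch is compared
    # slice-by-slice against the same conv applied to 1024-sample chunks;
    # since each output element depends on a single input element, the results
    # are expected to match exactly (maxdiff == 0).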
    def test_conv_transposed_large(self, device):
        dtype = torch.half if self.device_type == "cuda" else torch.float
        conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype)
        input_large = torch.randn(4096, 1, 512, 1024, dtype=dtype, device=device)
        ret = conv(input_large)
        maxdiff0 = (
            (ret.narrow(0, 0, 1024) - conv(input_large.narrow(0, 0, 1024)))
            .abs_()
            .max()
            .item()
        )
        maxdiff1 = (
            (ret.narrow(0, 1024, 1024) - conv(input_large.narrow(0, 1024, 1024)))
            .abs_()
            .max()
            .item()
        )
        maxdiff2 = (
            (ret.narrow(0, 2048, 1024) - conv(input_large.narrow(0, 2048, 1024)))
            .abs_()
            .max()
            .item()
        )
        maxdiff3 = (
            (ret.narrow(0, 3072, 1024) - conv(input_large.narrow(0, 3072, 1024)))
            .abs_()
            .max()
            .item()
        )
        self.assertEqual(maxdiff0, 0)
        self.assertEqual(maxdiff1, 0)
        self.assertEqual(maxdiff2, 0)
        self.assertEqual(maxdiff3, 0)

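    # Same chunk-consistency idea for a regular conv over a 4097-sample batch:
    # forward results must match the chunked runs, and the weight gradient
    # accumulated across chunks is compared under a loose tolerance after
    # normalizing by its mean magnitude.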
    def test_conv_large(self, device):
        dtype = torch.half if self.device_type == "cuda" else torch.float
        conv = nn.Conv2d(2, 2, 8, 8, bias=False).to(device).to(dtype)
        input_large = torch.randn(4097, 2, 512, 512, dtype=dtype, device=device)
        ret = conv(input_large)
        self.assertEqual(ret[:2048], conv(input_large[:2048]))
        self.assertEqual(ret[2048:4096], conv(input_large[2048:4096]))
        self.assertEqual(ret[4096:], conv(input_large[4096:]))

        conv.zero_grad()
        ret.view(4097, -1).max(dim=1).values.sum().backward()
        del ret
        grad1 = conv.weight.grad.detach().clone()
        conv.zero_grad()
        conv(input_large[:2048]).view(2048, -1).max(dim=1).values.sum().backward()
        conv(input_large[2048:4096]).view(2048, -1).max(dim=1).values.sum().backward()
        conv(input_large[4096:]).view(1, -1).max(dim=1).values.sum().backward()
        grad2 = conv.weight.grad.detach().clone()
        scale = 1 / grad2.abs().mean()
        grad1 = grad1 * scale
        grad2 = grad2 * scale
        self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3)

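    # Size-1 kernels often take a different code path than general
    # convolutions; the next three tests compare device results (with cudnn
    # disabled where relevant) against a CPU reference for Conv2d,
    # ConvTranspose2d, and ConvTranspose3d, for outputs and gradients alike.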
    def test_Conv2d_size_1_kernel(self, device):
        x_cpu = torch.randn(2, 3, 5, 5)
        conv_cpu = torch.nn.Conv2d(3, 3, kernel_size=1)
        y_cpu = conv_cpu(x_cpu)
        y = torch.rand_like(y_cpu)
        y_cpu.backward(y)

        with cudnn.flags(enabled=False):
            conv_cuda = torch.nn.Conv2d(3, 3, kernel_size=1).to(device)
            conv_cuda.bias.data.copy_(conv_cpu.bias.data)
            conv_cuda.weight.data.copy_(conv_cpu.weight.data)
            y_cuda = conv_cuda(x_cpu.to(device))
            y_cuda.backward(y.to(device))

        self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
        self.assertEqual(
            conv_cpu.bias.grad.data,
            conv_cuda.bias.grad.data,
            atol=1e-5,
            rtol=0,
            exact_device=False,
        )
        self.assertEqual(
            conv_cpu.weight.grad.data,
            conv_cuda.weight.grad.data,
            atol=1e-5,
            rtol=0,
            exact_device=False,
        )

    def test_ConvTranspose2d_size_1_kernel(self, device):
        x_cpu = torch.randn(2, 3, 5, 5)
        conv_cpu = torch.nn.ConvTranspose2d(3, 3, kernel_size=1)
        y_cpu = conv_cpu(x_cpu)
        y = torch.rand_like(y_cpu)
        y_cpu.backward(y)
        conv_cuda = torch.nn.ConvTranspose2d(3, 3, kernel_size=1).to(device)
        conv_cuda.bias.data.copy_(conv_cpu.bias.data)
        conv_cuda.weight.data.copy_(conv_cpu.weight.data)
        y_cuda = conv_cuda(x_cpu.to(device))
        y_cuda.backward(y.to(device))

        self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
        self.assertEqual(
            conv_cpu.bias.grad.data,
            conv_cuda.bias.grad.data,
            atol=1e-5,
            rtol=0,
            exact_device=False,
        )
        self.assertEqual(
            conv_cpu.weight.grad.data,
            conv_cuda.weight.grad.data,
            atol=1e-5,
            rtol=0,
            exact_device=False,
        )

    def test_ConvTranspose3d_size_1_kernel(self, device):
        with set_default_dtype(torch.double):
            x_cpu = torch.randn(2, 3, 3, 5, 5)
            conv_cpu = torch.nn.ConvTranspose3d(3, 3, kernel_size=1)
            y_cpu = conv_cpu(x_cpu)
            y = torch.rand_like(y_cpu)
            y_cpu.backward(y)
            conv_cuda = torch.nn.ConvTranspose3d(3, 3, kernel_size=1).to(device)
            conv_cuda.bias.data.copy_(conv_cpu.bias.data)
            conv_cuda.weight.data.copy_(conv_cpu.weight.data)
            y_cuda = conv_cuda(x_cpu.to(device))
            y_cuda.backward(y.to(device))

            self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
            self.assertEqual(
                conv_cpu.bias.grad.data,
                conv_cuda.bias.grad.data,
                atol=1e-5,
                rtol=0,
                exact_device=False,
            )
            self.assertEqual(
                conv_cpu.weight.grad.data,
                conv_cuda.weight.grad.data,
                atol=1e-5,
                rtol=0,
                exact_device=False,
            )

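    # A groups=2 convolution should be equivalent to two independent
    # convolutions over the channel halves: outputs, input gradients, and
    # parameter gradients must all concatenate back to the grouped results.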
    @dtypes(torch.float)
    def test_Conv2d_naive_groups(self, device, dtype):
        m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype)
        i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
        output = m(i)
        grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype)
        output.backward(grad_output)

        m1 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype)
        m1.weight.data.copy_(m.weight.data[:2])
        m1.bias.data.copy_(m.bias.data[:2])
        i1 = i.data[:, :2].contiguous().requires_grad_(True)
        output1 = m1(i1)
        output1.backward(grad_output[:, :2].contiguous())

        m2 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype)
        m2.weight.data.copy_(m.weight.data[2:])
        m2.bias.data.copy_(m.bias.data[2:])
        i2 = i.data[:, 2:].contiguous().requires_grad_(True)
        output2 = m2(i2)
        output2.backward(grad_output[:, 2:].contiguous())

        self.assertEqual(output, torch.cat([output1, output2], 1))
        self.assertEqual(
            i.grad.data,
            torch.cat([i1.grad.data, i2.grad.data], 1),
            atol=dtype2prec_DONTUSE[dtype],
            rtol=0,
        )
        self.assertEqual(
            m.bias.grad.data,
            torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
            atol=dtype2prec_DONTUSE[dtype],
            rtol=0,
        )
        self.assertEqual(
            m.weight.grad.data,
            torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
            atol=dtype2prec_DONTUSE[dtype],
            rtol=0,
        )

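    # Depthwise case (groups == in_channels) with an asymmetric stride of
    # (1, 10); gradcheck in double precision validates the backward formula.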
    @dtypes(torch.double)
    def test_Conv2d_backward_depthwise(self, device, dtype):
        x = torch.randn(2, 2, 4, 20, device=device, dtype=dtype, requires_grad=True)
        weight = torch.randn(2, 1, 3, 5, device=device, dtype=dtype, requires_grad=True)

        def conv2d_depthwise(x, weight):
            return torch.nn.functional.conv2d(
                x, weight, bias=None, stride=(1, 10), groups=2
            )

        torch.autograd.gradcheck(conv2d_depthwise, (x, weight))

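    # Channels-last (NHWC) conv against a contiguous double-precision
    # reference. Inputs, weights, and incoming gradients are small integers,
    # presumably so every intermediate is exactly representable and the
    # half/float results can be compared to the double reference without
    # rounding slack; the checks also pin down the memory format of the output
    # and of every gradient.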
    @dtypes(torch.half, torch.float)
    def test_conv_cudnn_nhwc(self, device, dtype):
        def helper(n, c, h, w, out_channels, kernel_size, groups):
            input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to(
                memory_format=torch.channels_last
            )
            input.requires_grad_()
            conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to(
                device=device, dtype=dtype, memory_format=torch.channels_last
            )
            for p in conv.parameters():
                p.data = torch.randint_like(p, -3, 3)

            ref_input = input.detach().clone().contiguous().double().requires_grad_()
            ref_conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups)
            ref_conv.load_state_dict(conv.state_dict())
            ref_conv = ref_conv.to(
                device=device, dtype=torch.double, memory_format=torch.contiguous_format
            )

            out = conv(input)
            ref_out = ref_conv(ref_input)

            grad = torch.randint_like(out, -3, 3)
            ref_grad = grad.detach().clone().double().contiguous()

            out.backward(grad)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(input.grad.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(
                conv.weight.grad.is_contiguous(memory_format=torch.channels_last)
            )

            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ref_input.grad.is_contiguous())
            self.assertTrue(ref_conv.weight.grad.is_contiguous())

            self.assertEqual(out, ref_out, exact_dtype=False)
            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
            self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)

        helper(2, 8, 4, 4, out_channels=4, kernel_size=3, groups=1)
        helper(2, 8, 4, 4, out_channels=8, kernel_size=3, groups=8)
        helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=1)
        helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16)

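    # Same scheme as test_conv_cudnn_nhwc, but for 3d convolutions in the
    # channels_last_3d (NDHWC) memory format.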
    @dtypes(torch.half, torch.float)
    def test_conv_cudnn_ndhwc(self, device, dtype):
        def helper(n, c, d, h, w, out_channels, kernel_size, groups):
            input = torch.randint(
                -2, 2, (n, c, d, h, w), dtype=dtype, device=device
            ).to(memory_format=torch.channels_last_3d)
            input.requires_grad_()
            conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups).to(
                device=device, dtype=dtype, memory_format=torch.channels_last_3d
            )
            for p in conv.parameters():
                p.data = torch.randint_like(p, -2, 2)

            ref_input = input.detach().clone().contiguous().double().requires_grad_()
            ref_conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups)
            ref_conv.load_state_dict(conv.state_dict())
            ref_conv = ref_conv.to(
                device=device, dtype=torch.double, memory_format=torch.contiguous_format
            )

            out = conv(input)
            ref_out = ref_conv(ref_input)

            grad = torch.randint_like(out, -2, 2)
            ref_grad = grad.detach().clone().double().contiguous()

            out.backward(grad)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last_3d))
            self.assertTrue(
                input.grad.is_contiguous(memory_format=torch.channels_last_3d)
            )
            self.assertTrue(
                conv.weight.grad.is_contiguous(memory_format=torch.channels_last_3d)
            )

            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ref_input.grad.is_contiguous())
            self.assertTrue(ref_conv.weight.grad.is_contiguous())

            self.assertEqual(out, ref_out, exact_dtype=False)
            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
            self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)

        helper(2, 8, 4, 4, 4, out_channels=4, kernel_size=3, groups=1)
        helper(2, 8, 4, 4, 4, out_channels=8, kernel_size=3, groups=8)
        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=1)
        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=16)

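    # Helper: re-run a reference convolution with the given input/weight/grad
    # memory formats and check that the output lands in the expected format
    # while still matching the reference outputs and gradients.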
    def _run_conv(
        self,
        layer,
        device,
        inp,
        grad,
        ref_conv,
        ref_input,
        ref_out,
        input_format,
        weight_format,
        grad_format,
        output_format,
    ):
        conv = (
            layer(inp.size(1), grad.size(1), ref_conv.weight.size(2)).float().to(device)
        )
        conv.load_state_dict(ref_conv.state_dict())
        weight_data = (
            conv.weight.detach().clone().contiguous(memory_format=weight_format)
        )
        conv.weight.data = weight_data.resize_(
            weight_data.size(), memory_format=weight_format
        )
        input = inp.clone().contiguous(memory_format=input_format)
        input.resize_(input.size(), memory_format=input_format)
        input = input.requires_grad_()
        grad = grad.contiguous(memory_format=grad_format)
        grad.resize_(grad.size(), memory_format=grad_format)
        out = conv(input)
        out.backward(grad)
        self.assertTrue(out.is_contiguous(memory_format=output_format))
        self.assertEqual(out, ref_out)
        self.assertEqual(conv.weight.grad, ref_conv.weight.grad)
        self.assertEqual(conv.bias.grad, ref_conv.bias.grad)
        self.assertEqual(input.grad, ref_input.grad)

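    # Drives _run_conv over every combination of contiguous/channels_last
    # formats for the weight, gradient, and input; the output is expected to
    # be channels_last whenever the input or the weight is.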
    def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device):
        data = torch.randint(1, 10, (n, c, h, w), dtype=torch.float32, device=device)
        ref_input = data.clone().contiguous().requires_grad_(True)
        ref_conv = layer(c, k, filter_size).float().to(device)
        ref_out = ref_conv(ref_input)
        grad = torch.randint(1, 10, ref_out.size(), dtype=torch.float32, device=device)
        ref_out.backward(grad)

        for w_f in [torch.contiguous_format, torch.channels_last]:
            for g_f in [torch.contiguous_format, torch.channels_last]:
                for input_format in [torch.contiguous_format, torch.channels_last]:
                    output_format = torch.contiguous_format
                    if input_format == torch.channels_last:
                        output_format = torch.channels_last
                    if w_f == torch.channels_last:
                        output_format = torch.channels_last
                    self._run_conv(
                        layer,
                        device,
                        data,
                        grad,
                        ref_conv,
                        ref_input,
                        ref_out,
                        input_format,
                        w_f,
                        g_f,
                        output_format,
                    )

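    # A contiguous input with a channels_last weight should still produce a
    # channels_last output, and backward should succeed.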
    @dtypes(torch.float, torch.double)
    def test_conv_cudnn_nhwc_support(self, device, dtype):
        input = torch.randn(
            (1, 16, 1, 1), dtype=dtype, device=device, requires_grad=True
        )
        weight = torch.randn(
            (8, 16, 3, 3), dtype=dtype, device=device, requires_grad=True
        )
        weight = weight.to(memory_format=torch.channels_last)
        o = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1)
        self.assertTrue(o.is_contiguous(memory_format=torch.channels_last))
        o.sum().backward()

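    # Running the same module under torch.no_grad() must give the same result
    # as the grad-enabled forward, across batch sizes and group counts.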
    @dtypes(torch.float)
    def test_conv2d_no_grad(self, device, dtype):
        for batch in [1, 2, 3]:
            for groups in [1, 2, 4]:
                input = torch.rand(batch, groups, 8, 8, dtype=dtype, device=device)
                m = nn.Conv2d(
                    groups,
                    8,
                    kernel_size=(3, 3),
                    groups=groups,
                    dtype=dtype,
                    device=device,
                )
                with torch.no_grad():
                    output_ng = m(input)
                output = m(input)
                self.assertEqual(output, output_ng, rtol=1e-2, atol=1e-5)

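    # Calls aten::_convolution_double_backward directly on conv1d-shaped (3D)
    # input and weight with a non-unit stride, and checks only that the
    # returned gradients have the expected shapes.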
    def test_conv_double_backward_strided_with_3D_input_and_weight(self, device):
        input = torch.randn(2, 3, 6, device=device)
        weight = torch.randn(3, 3, 3, device=device)
        bias = torch.randn(3, device=device)
        stride = (2,)
        padding = (1,)
        dilation = (1,)
        transposed = False
        output_padding = (0,)
        groups = 1
        output = torch.ops.aten.convolution(
            input,
            weight,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        )

        ggI = torch.randn(input.shape, device=device)
        ggW = torch.randn(weight.shape, device=device)
        ggB = torch.randn(bias.shape, device=device)
        gO = torch.randn(output.shape, device=device)
        output_mask = [True, True, True]
        (
            grad_grad_output,
            grad_input,
            grad_weight,
        ) = torch.ops.aten._convolution_double_backward(
            ggI,
            ggW,
            ggB,
            gO,
            weight,
            input,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            output_mask,
        )

        self.assertEqual(grad_grad_output.shape, gO.shape)
        self.assertEqual(grad_input.shape, input.shape)
        self.assertEqual(grad_weight.shape, weight.shape)

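    # Stride layout of channels_last conv output on XPU: float64 falls back to
    # a contiguous (NCHW) output, while every other dtype stays channels_last.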
    @onlyXPU
    @dtypes(torch.float16, torch.bfloat16, torch.float32, torch.float64)
    def test_channels_last_output_stride(self, device, dtype):
        input = torch.randn(
            (2, 3, 16, 16), device=device, dtype=dtype, requires_grad=True
        )
        weight = torch.randn(
            (512, 3, 3, 3), device=device, dtype=dtype, requires_grad=True
        )
        input = input.to(memory_format=torch.channels_last)
        weight = weight.to(memory_format=torch.channels_last)
        out = torch.conv2d(input, weight, None, (2, 2), (0, 0), (1, 1), 1)

        if dtype is torch.float64:
            # Like most conv backends, XPU does not support float64 for
            # channels-last conv.
            # input NHWC, output NCHW
            assert_size_stride(out, (2, 512, 7, 7), (25088, 49, 7, 1))
        else:
            # input NHWC, output NHWC
            assert_size_stride(out, (2, 512, 7, 7), (25088, 1, 3584, 512))


instantiate_device_type_tests(
    TestConvolutionNNDeviceType, globals(), only_for="xpu", allow_xpu=True
)

if __name__ == "__main__":
    run_tests()