# Owner(s): ["module: intel"]

import itertools
import math
import unittest
from itertools import product

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch._C._dynamo.guards import assert_size_stride
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import tf32_is_not_fp32
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyXPU,
)
from torch.testing._internal.common_dtype import floating_types_and
from torch.testing._internal.common_nn import _test_module_empty_input, NNTestCase
from torch.testing._internal.common_utils import (
    dtype2prec_DONTUSE,
    gradcheck,
    gradgradcheck,
    parametrize as parametrize_test,
    run_tests,
    set_default_dtype,
    TEST_SCIPY,
    TEST_WITH_ROCM,
)


AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
if TEST_SCIPY:
    import scipy.ndimage
    import scipy.signal


class TestConvolutionNNDeviceType(NNTestCase):
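    # Helper shared by the double-backward tests below: builds a conv2d call
    # with the given configuration and checks its second derivatives via
    # gradgradcheck. For float32 it only checks that the first derivative is
    # itself differentiable, since gradgradcheck is not numerically reliable
    # in single precision.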
    def run_conv_double_back_test(
        self,
        kern,
        stride,
        padding,
        chan_in,
        chan_out,
        batch_size,
        inp_size,
        dilation,
        no_weight,
        groups=1,
        use_xpu=False,
        use_bias=True,
        dtype=torch.double,
    ):
        device = torch.device("xpu" if use_xpu else "cpu")
        x = torch.randn(
            batch_size,
            chan_in,
            inp_size,
            inp_size,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        weight = torch.randn(
            chan_out,
            chan_in // groups,
            kern,
            kern,
            device=device,
            dtype=dtype,
            requires_grad=not no_weight,
        )
        if use_bias:
            bias = torch.randn(chan_out, device=device, dtype=dtype, requires_grad=True)
        else:
            bias = None

        def func(*inputs):
            if use_bias:
                lx, lweight, lbias = inputs
            else:
                lx, lweight = inputs
                lbias = None
            out = F.conv2d(lx, lweight, lbias, stride, padding, dilation, groups)
            return out

        if use_bias:
            inputs = x, weight, bias
        else:
            inputs = x, weight

        dummy_out = func(*inputs)
        grad_y = torch.randn_like(
            dummy_out, device=device, dtype=dtype, requires_grad=True
        )

        if dtype == torch.float:
            (g,) = torch.autograd.grad(dummy_out.sum(), x, create_graph=True)
            return g.requires_grad

        return gradgradcheck(func, inputs, (grad_y,))

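    # Runs forward+backward over activation shapes that require a large
    # temporary workspace. `benchmark` is accepted but unused here, apparently
    # kept for parity with the upstream CUDA test, where it toggles cudnn
    # benchmarking.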
    @dtypes(*floating_types_and(torch.half, torch.bfloat16))
    def test_Conv2d_large_workspace(self, device, dtype):
        sizes = [
            (1, 256, 109, 175),
            (1, 256, 80, 128),
            (1, 256, 120, 192),
        ]

        def run_test(benchmark):
            conv = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1).to(device, dtype)
            for size in sizes:
                x = torch.randn(size, device=device, dtype=dtype)
                out = conv(x.detach().clone().requires_grad_())
                out.backward(torch.ones_like(out))

        run_test(benchmark=False)
        run_test(benchmark=True)

    @dtypes(torch.half, torch.float)
    def test_ConvTranspose2d_large_output_padding(self, device, dtype):
        net1 = torch.nn.ConvTranspose2d(
            128, 64, kernel_size=3, stride=2, padding=1, output_padding=1
        ).to(device=device, dtype=dtype)
        net2 = torch.nn.ConvTranspose2d(
            64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
        ).to(device=device, dtype=dtype)
        net3 = torch.nn.ConvTranspose2d(
            32, 3, kernel_size=3, stride=2, padding=1, output_padding=1
        ).to(device=device, dtype=dtype)
        x = torch.rand(1, 128, 6, 6, device=device, dtype=dtype, requires_grad=True)
        x = net1(x)
        x = net2(x)
        x = net3(x)
        x.backward(torch.randn_like(x))

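    # Depthwise/grouped convolution check: a groups=2 convolution must match
    # two independent single-group convolutions applied to the corresponding
    # channel slices, for the output and for input/weight/bias gradients.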
    @dtypes(torch.float, torch.double, torch.half)
    def test_Conv2d_depthwise_naive_groups(self, device, dtype):
        if dtype == torch.half and "xpu" in device:
            self.skipTest(
                "The accuracy issue of dtype fp16 would be fixed in oneDNN v3.4"
            )
        for depth_multiplier in [1, 2]:
            m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(
                device, dtype
            )
            i = (
                torch.randn(2, 2, 6, 6, device=device, dtype=dtype)
                .div_(2)
                .requires_grad_()
            )
            output = m(i)
            grad_output = (
                torch.randn(2, 2 * depth_multiplier, 4, 4, device=device, dtype=dtype)
                / 2
            )
            output.backward(grad_output)

            offset = 1 * depth_multiplier

            m1 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m1.weight.data = m.weight.data[:offset].clone()
            m1.bias.data = m.bias.data[:offset].clone()
            i1 = i.detach()[:, :1].clone().requires_grad_()
            output1 = m1(i1)
            output1.backward(grad_output[:, :offset].contiguous())

            m2 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m2.weight.data.copy_(m.weight.data[offset:])
            m2.bias.data.copy_(m.bias.data[offset:])
            i2 = i.detach()[:, 1:].clone().requires_grad_()
            output2 = m2(i2)
            output2.backward(grad_output[:, offset:].contiguous())

            self.assertEqual(
                output,
                torch.cat([output1, output2], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                i.grad.data,
                torch.cat([i1.grad.data, i2.grad.data], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.bias.grad.data,
                torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.weight.grad.data,
                torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )

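    # Same decomposition check for Conv3d. The output and weight-gradient
    # comparisons use fixed tolerances (3e-4, 3e-2) instead of
    # dtype2prec_DONTUSE, presumably to allow for the extra accumulation
    # error of the 3d reduction.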
    @dtypes(torch.float, torch.double, torch.half)
    def test_Conv3d_depthwise_naive_groups(self, device, dtype):
        if dtype == torch.half and "xpu" in device:
            self.skipTest(
                "The accuracy issue of dtype fp16 would be fixed in oneDNN v3.4"
            )
        for depth_multiplier in [1, 2]:
            m = nn.Conv3d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(
                device, dtype
            )
            i = (
                torch.randn(2, 2, 6, 6, 6, device=device, dtype=dtype)
                .div_(2)
                .requires_grad_()
            )
            output = m(i)
            grad_output = (
                torch.randn(
                    2, 2 * depth_multiplier, 4, 4, 4, device=device, dtype=dtype
                )
                / 2
            )
            output.backward(grad_output)

            offset = 1 * depth_multiplier

            m1 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m1.weight.data = m.weight.data[:offset].clone()
            m1.bias.data = m.bias.data[:offset].clone()
            i1 = i.detach()[:, :1].clone().requires_grad_()
            output1 = m1(i1)
            output1.backward(grad_output[:, :offset].contiguous())

            m2 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m2.weight.data.copy_(m.weight.data[offset:])
            m2.bias.data.copy_(m.bias.data[offset:])
            i2 = i.detach()[:, 1:].clone().requires_grad_()
            output2 = m2(i2)
            output2.backward(grad_output[:, offset:].contiguous())
            atol, rtol = (3e-4, 3e-2)

            self.assertEqual(
                output, torch.cat([output1, output2], 1), atol=atol, rtol=rtol
            )
            self.assertEqual(
                i.grad.data,
                torch.cat([i1.grad.data, i2.grad.data], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.bias.grad.data,
                torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.weight.grad.data,
                torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
                atol=atol,
                rtol=rtol,
            )

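    # Backward with a non-contiguous grad_output must produce the same input
    # gradient as the contiguous copy of that grad_output.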
    @dtypes(torch.float, torch.double, torch.half)
    def test_noncontig_conv_grad(self, device, dtype):
        module = nn.Conv2d(3, 5, kernel_size=3, padding=1).to(device, dtype)
        input = torch.randn(
            2, 3, 10, 10, dtype=dtype, device=device, requires_grad=True
        )
        output = module(input)

        grad = torch.randn(2, 2, 5, 10, 10, dtype=dtype, device=device)[:, 1]
        assert not grad.is_contiguous()
        output.backward(grad, retain_graph=True)
        self.assertIsNotNone(input.grad)
        result = input.grad.data.clone()
        input.grad.data.zero_()

        output.backward(grad.contiguous())
        self.assertEqual(
            result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0
        )

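    # The next four tests drive run_conv_double_back_test over assorted
    # kernel/stride/padding/dilation/groups configurations.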
    @dtypes(torch.double)
    def test_conv_double_backward(self, device, dtype):
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            batch_size = 1
            for kern, inp_size, dilations in [(3, 5, [1, 2]), (4, 9, [1])]:
                for stride, padding, chan_in, chan_out, dilation in product(
                    [1], [2], [2], [3], dilations
                ):
                    no_weight = stride == 2
                    result = self.run_conv_double_back_test(
                        kern,
                        stride,
                        padding,
                        chan_in,
                        chan_out,
                        batch_size,
                        inp_size,
                        dilation,
                        no_weight,
                        use_xpu=True,
                        dtype=dtype,
                    )
                    self.assertTrue(result, "Conv double backward test failed")

    def test_conv_double_backward_no_bias(self):
        kern, stride = 3, 2
        chan_in, chan_out = 2, 4
        batch_size, inp_size = 2, 5
        padding, dilation = 1, 1
        no_weight, use_bias = False, True
        result = self.run_conv_double_back_test(
            kern,
            stride,
            padding,
            chan_in,
            chan_out,
            batch_size,
            inp_size,
            dilation,
            no_weight,
            use_bias=use_bias,
        )
        self.assertTrue(result, "Conv double backward test failed")

    def test_conv_double_backward_groups(self):
        kern, stride, padding = 3, 1, 2
        chan_in, chan_out = 2, 4
        batch_size, inp_size, dilation = 2, 6, 1
        no_weight = False
        groups = 2
        result = self.run_conv_double_back_test(
            kern,
            stride,
            padding,
            chan_in * groups,
            chan_out * groups,
            batch_size,
            inp_size,
            dilation,
            no_weight,
            groups=groups,
        )
        self.assertTrue(result, "Conv double backward test failed")

    def test_conv_double_backward_stride(self):
        batch_size = 2
        for kern, inp_size, dilations in [(3, 5, [1, 2]), (3, 7, [1])]:
            for stride, padding, chan_in, chan_out, dilation in product(
                [2], [0, 1], [1], [2], dilations
            ):
                no_weight = False
                self.run_conv_double_back_test(
                    kern,
                    stride,
                    padding,
                    chan_in,
                    chan_out,
                    batch_size,
                    inp_size,
                    dilation,
                    no_weight,
                )

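    # String padding modes. padding="same" must match an explicit symmetric
    # padding (with one element trimmed per dimension whenever the total
    # required padding is odd), and padding="valid" must match padding=0.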
    @dtypes(torch.float)
    def test_conv1d_same_padding(self, device, dtype):
        # in_size, kernel_size, dilation, stride
        test_args = [
            range(50, 55),
            [1, 2, 3, 8],
            range(1, 4),
            [1],
        ]
        for in_size, k_size, dilation, stride in itertools.product(*test_args):
            x = torch.rand(1, 1, in_size, device=device, dtype=dtype)
            y = torch.rand(1, 1, k_size, device=device, dtype=dtype)
            z = F.conv1d(x, y, padding="same", dilation=dilation, stride=stride)
            self.assertEqual(z.size(2), int(math.ceil(in_size / stride)))

        x = torch.rand(1, 1, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 3, device=device, dtype=dtype)
        expect = F.conv1d(x, y, padding=1)
        actual = F.conv1d(x, y, padding="same")
        self.assertEqual(expect, actual)

        x = torch.rand(1, 1, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 4, device=device, dtype=dtype)
        expect = F.conv1d(x, y, padding=3, dilation=2)
        actual = F.conv1d(x, y, padding="same", dilation=2)
        self.assertEqual(expect, actual)

        expect = F.conv1d(x, y, padding=5, dilation=3)[..., 1:]
        actual = F.conv1d(x, y, padding="same", dilation=3)
        self.assertEqual(expect, actual)

    @dtypes(torch.float)
    def test_conv3d_same_padding(self, device, dtype):
        rtol, atol = None, None
        x = torch.rand(1, 1, 10, 11, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 1, 2, 5, device=device, dtype=dtype)
        expect = F.conv3d(x, y, padding=(0, 1, 2))[..., :, 1:, :]
        actual = F.conv3d(x, y, padding="same")
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

        expect = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
        actual = F.conv3d(x, y, padding="same", dilation=2)
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

        y = torch.rand(1, 1, 4, 4, 4, device=device, dtype=dtype)
        expect = F.conv3d(x, y, padding=5, dilation=3)[..., 1:, 1:, 1:]
        actual = F.conv3d(x, y, padding="same", dilation=3)
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

    @dtypes(torch.float)
    def test_conv1d_valid_padding(self, device, dtype):
        x = torch.rand(1, 1, 10, device=device, dtype=dtype)
        y = torch.rand(1, 1, 4, device=device, dtype=dtype)
        expect = F.conv1d(x, y)
        actual = F.conv1d(x, y, padding="valid")
        self.assertEqual(expect, actual)

    @dtypes(torch.float)
    def test_conv2d_valid_padding(self, device, dtype):
        x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype)
        y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype)
        expect = F.conv2d(x, y)
        actual = F.conv2d(x, y, padding="valid")
        self.assertEqual(expect, actual)

    @dtypes(torch.float)
    def test_conv3d_valid_padding(self, device, dtype):
        x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device)
        y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device)
        expect = F.conv3d(x, y)
        actual = F.conv3d(x, y, padding="valid")
        self.assertEqual(expect, actual)

    @dtypes(torch.float)
    def test_conv1d_same_padding_backward(self, device, dtype):
        x = torch.rand(1, 1, 12, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)

        z = F.conv1d(x, y, padding=3, dilation=2)
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv1d(x, y, padding="same", dilation=2)
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        x.grad, y.grad = None, None

        z = F.conv1d(x, y, padding=2)[..., 1:]
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv1d(x, y, padding="same")
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)

    @dtypes(torch.float)
    def test_conv2d_same_padding_backward(self, device, dtype):
        x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype, requires_grad=True)
        y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype, requires_grad=True)

        z = F.conv2d(x, y, padding=(3, 4), dilation=2)
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv2d(x, y, padding="same", dilation=2)
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        x.grad, y.grad = None, None

        y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype, requires_grad=True)
        z = F.conv2d(x, y, padding=2)[..., 1:, 1:]
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv2d(x, y, padding="same")
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)

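    # Same checks for conv3d (in double precision), plus gradcheck and
    # gradgradcheck with forward-mode AD (check_forward_ad /
    # check_fwd_over_rev) to cover the JVP rules for string padding.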
    @dtypes(torch.double)
    def test_conv3d_same_padding_backward(self, device, dtype):
        x = torch.rand(1, 1, 1, 11, 12, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 1, 2, 5, dtype=dtype, device=device, requires_grad=True)
        z = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv3d(x, y, padding="same", dilation=2)
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        x.grad, y.grad = None, None
        gradcheck(
            lambda x, y: F.conv3d(x, y, padding="same", dilation=2),
            (x, y),
            check_forward_ad=True,
            nondet_tol=1e-5,
        )
        gradgradcheck(
            lambda x, y: F.conv3d(x, y, padding="same", dilation=2),
            (x, y),
            check_fwd_over_rev=True,
        )

        y = torch.rand(1, 1, 1, 4, 4, dtype=dtype, device=device, requires_grad=True)
        z = F.conv3d(x, y, padding=2)[..., 1:, 1:]
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv3d(x, y, padding="same")
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        gradcheck(
            lambda x, y: F.conv3d(x, y, padding="same"),
            (x, y),
            check_forward_ad=True,
            nondet_tol=1e-5,
        )
        gradgradcheck(
            lambda x, y: F.conv3d(x, y, padding="same"),
            (x, y),
            check_fwd_over_rev=True,
        )

    @dtypes(torch.float)
    def test_conv1d_valid_padding_backward(self, device, dtype):
        x = torch.rand(1, 1, 10, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)
        F.conv1d(x, y, padding=0).sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None
        F.conv1d(x, y, padding="valid").sum().abs().backward()
        gx_actual, gy_actual = x.grad, y.grad
        self.assertEqual(gx_expect, gx_actual)
        self.assertEqual(gy_expect, gy_actual)

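    # Cross-checks against SciPy. torch's convNd is a cross-correlation while
    # scipy.signal.convolve flips the kernel, so the weight is flipped before
    # the call; for mode="same" the input is padded symmetrically and the
    # output sliced to emulate SciPy's centering.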
    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(torch.float)
    @parametrize_test("mode", ("valid", "same"))
    def test_conv1d_vs_scipy(self, device, dtype, mode):
        t = make_tensor((1, 10), device=device, dtype=dtype)
        feat_dim = t.shape[1]
        weight_even = make_tensor((1, 1, 4), device=device, dtype=dtype)
        weight_odd = make_tensor((1, 1, 5), device=device, dtype=dtype)

        def _test(t, weight, mode):
            t_a = t.view(-1).cpu().numpy()
            w_a = weight.view(-1).cpu().numpy()
            expected = scipy.signal.convolve(t_a, w_a, mode=mode)

            kwargs = {"padding": mode}
            if mode == "same":
                p = weight.shape[2] // 2
                t = torch.nn.functional.pad(t, (p, p))
                kwargs.pop("padding")

            weight_flipped = torch.flip(weight, (2,))
            actual = torch.nn.functional.conv1d(t, weight_flipped, **kwargs).squeeze(0)
            if mode == "same":
                actual = actual[:feat_dim]

            self.assertEqual(actual, expected, atol=2e-5, rtol=2e-5)

        with set_default_dtype(torch.float):
            _test(t, weight_even, mode)
            _test(t, weight_odd, mode)

    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(torch.float)
    @parametrize_test("mode", ("valid", "same"))
    def test_conv2d_vs_scipy(self, device, dtype, mode):
        t = make_tensor((1, 5, 10), device=device, dtype=dtype)
        weight_even = make_tensor((1, 1, 2, 4), device=device, dtype=dtype)
        weight_odd = make_tensor((1, 1, 3, 5), device=device, dtype=dtype)

        def _test(t, weight, mode):
            t_a = t.squeeze(0).cpu().numpy()
            w_a = weight.squeeze(0).squeeze(0).cpu().numpy()
            expected = scipy.signal.convolve2d(t_a, w_a, mode=mode)

            kwargs = {"padding": mode}
            if mode == "same":
                left_right_pad = weight.shape[3] // 2
                top_bottom_pad = weight.shape[2] // 2
                p = (left_right_pad, left_right_pad, top_bottom_pad, top_bottom_pad)
                t = torch.nn.functional.pad(t, p)
                kwargs.pop("padding")

            weight_flipped = torch.flip(weight, (2, 3))
            actual = torch.nn.functional.conv2d(t, weight_flipped, **kwargs).squeeze(0)
            if mode == "same":
                actual = actual[:5, :10]

            self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6)

        with set_default_dtype(torch.float):
            _test(t, weight_even, mode)
            _test(t, weight_odd, mode)

    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(torch.float)
    @parametrize_test("mode", ("valid", "same"))
    def test_conv3d_vs_scipy(self, device, dtype, mode):
        t = make_tensor((1, 5, 5, 10), device=device, dtype=dtype)
        weight_even = make_tensor((1, 1, 2, 2, 4), device=device, dtype=dtype)
        weight_odd = make_tensor((1, 1, 2, 3, 5), device=device, dtype=dtype)

        def _test(t, weight, mode):
            t_a = t.squeeze(0).cpu().numpy()
            w_a = weight.squeeze(0).squeeze(0).cpu().numpy()
            expected = scipy.signal.convolve(t_a, w_a, mode=mode)
            kwargs = {"padding": mode}
            if mode == "same":
                left_right_pad = weight.shape[4] // 2
                top_bottom_pad = weight.shape[3] // 2
                front_back_pad = weight.shape[2] // 2
                p = (
                    left_right_pad,
                    left_right_pad,
                    top_bottom_pad,
                    top_bottom_pad,
                    front_back_pad,
                    front_back_pad,
                )
                t = torch.nn.functional.pad(t, p)
                kwargs.pop("padding")
            weight_flipped = torch.flip(weight, (2, 3, 4))
            actual = torch.nn.functional.conv3d(t, weight_flipped, **kwargs).squeeze(0)
            if mode == "same":
                actual = actual[:5, :5, :10]
            self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6)

        with set_default_dtype(torch.float):
            _test(t, weight_even, mode)
            _test(t, weight_odd, mode)

    @dtypes(torch.float)
    def test_conv2d_valid_padding_backward(self, device, dtype):
        x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype, requires_grad=True)
        y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype, requires_grad=True)
        F.conv2d(x, y, padding=0).sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None
        F.conv2d(x, y, padding="valid").sum().abs().backward()
        gx_actual, gy_actual = x.grad, y.grad
        self.assertEqual(gx_expect, gx_actual)
        self.assertEqual(gy_expect, gy_actual)

    @dtypes(torch.double)
    def test_conv3d_valid_padding_backward(self, device, dtype):
        x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device, requires_grad=True)
        F.conv3d(x, y, padding=0).sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        F.conv3d(x, y, padding="valid").sum().abs().backward()
        gx_actual, gy_actual = x.grad, y.grad
        self.assertEqual(gx_expect, gx_actual)
        self.assertEqual(gy_expect, gy_actual)
        gradcheck(
            lambda x, y: F.conv3d(x, y, padding="valid"),
            (x, y),
            check_forward_ad=True,
        )
        gradgradcheck(
            lambda x, y: F.conv3d(x, y, padding="valid"),
            (x, y),
            check_fwd_over_rev=True,
        )

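    # ConvTranspose{2,3}d given an explicit output_size and an unbatched input
    # (no leading batch dimension) must produce exactly the requested shape.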
    @parametrize_test("N", range(2, 4), name_fn=lambda N: f"ConvTranspose{N}d")
    def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N):
        inp = torch.randn((1, 15, 13) if N == 2 else (1, 15, 13, 13), device=device)
        output_size = (1, 240, 200) if N == 2 else (1, 240, 200, 200)
        ConvTransposeNd = getattr(nn, f"ConvTranspose{N}d")
        m = ConvTransposeNd(
            1, 1, kernel_size=16, stride=16, padding=7, bias=False, device=device
        )
        output = m(inp, output_size=output_size)
        self.assertEqual(output.shape, output_size)

    @dtypes(torch.float)
    def test_conv_empty_channel(self, device, dtype):
        in_channels = 0
        mod = torch.nn.Conv1d(in_channels, 8, 2, stride=2, dtype=dtype).to(device)
        inp = torch.randn(2, 0, 15, device=device, dtype=dtype)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
            inp = torch.randn(2, 1, 0, device=device, dtype=dtype)
            mod(inp)

        mod = torch.nn.Conv2d(in_channels, 33, 3, stride=2, dtype=dtype).to(device)
        inp = torch.randn(2, 0, 50, 100, device=device, dtype=dtype)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
            inp = torch.randn(2, 1, 40, 0, device=device, dtype=dtype)
            mod(inp)

        mod = torch.nn.Conv3d(in_channels, 33, 3, stride=2, dtype=dtype).to(device)
        inp = torch.randn(2, 0, 50, 20, 40, device=device, dtype=dtype)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
            inp = torch.randn(2, 1, 50, 0, 40, device=device, dtype=dtype)
            mod(inp)

    def test_group_conv_empty(self, device):
        mod = torch.nn.Conv2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(
            device
        )
        inp = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

    def test_group_convTranspose_empty(self, device):
        mod = torch.nn.ConvTranspose2d(
            4, 4, stride=2, kernel_size=3, padding=1, groups=4
        ).to(device)
        inp = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

    def test_convTranspose_empty(self, device):
        mod = torch.nn.ConvTranspose2d(4, 4, stride=2, kernel_size=3, padding=1).to(
            device
        )
        inp = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

    def test_conv_large_nosplit(self, device):
        dtype = torch.half
        conv1 = nn.Conv2d(2, 2, 8, 8).to(device).to(dtype)
        input_large = torch.randn(1, 2, 1024, 1024 * 1024, dtype=dtype, device=device)
        conv1(input_large)
        conv2 = torch.nn.Conv2d(1, 1024, 1, 1).to(device).to(dtype)
        input_large = torch.randn(1, 1, 2048, 1024, dtype=dtype, device=device)
        conv2(input_large)

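    # Convolutions must accept non-contiguous (here: expanded or sliced)
    # weights and biases, in both convNd and conv_transposeNd, forward and
    # backward.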
    def test_conv_noncontig_weights(self, device):
        for dim in (1, 2, 3):
            for grouped in (False, True):
                nc = 3
                groups = 3 if grouped else 1
                w = torch.randn([3] * dim, device=device)
                w = w.expand([nc, int(nc / groups)] + list(w.shape))
                w = w.detach().requires_grad_()
                x = torch.randn(
                    [1, nc] + ([5] * dim), device=device, requires_grad=True
                )
                y = getattr(F, f"conv{dim}d")(x, w, groups=groups)
                y.sum().backward()
                y = getattr(F, f"conv_transpose{dim}d")(x, w, groups=groups)
                y.sum().backward()

    def test_conv_noncontig_weights_and_bias(self, device):
        for bias in [True, False]:
            conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=bias).to(
                device, torch.float
            )
            input_nc = torch.randn(
                (1, 3, 224, 224, 2), device=device, dtype=torch.float
            )[:, :, :, :, 1]
            input_c = input_nc.contiguous()
            weight_nc = torch.randn((64, 3, 7, 7, 2), device=device, dtype=torch.float)[
                :, :, :, :, 1
            ]
            conv1.weight = nn.Parameter(weight_nc)
            weight_c = conv1.weight.contiguous()
            if bias:
                bias_nc = torch.randn((64, 2), device=device, dtype=torch.float)[:, 1]
                conv1.bias = nn.Parameter(bias_nc)
                bias_c = conv1.bias.contiguous()
            out1 = conv1(input_nc)
            conv1.weight = nn.Parameter(weight_c)
            if bias:
                conv1.bias = nn.Parameter(bias_c)
            out2 = conv1(input_c)
            self.assertEqual(out1, out2)

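    # Large-batch consistency: running the transposed convolution on the full
    # 4096-sample batch must agree exactly with running it on 1024-sample
    # narrow()ed slices, presumably guarding against errors in any internal
    # batch splitting.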
    def test_conv_transposed_large(self, device):
        dtype = torch.half if self.device_type == "cuda" else torch.float
        conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype)
        input_large = torch.randn(4096, 1, 512, 1024, dtype=dtype, device=device)
        ret = conv(input_large)
        maxdiff0 = (
            (ret.narrow(0, 0, 1024) - conv(input_large.narrow(0, 0, 1024)))
            .abs_()
            .max()
            .item()
        )
        maxdiff1 = (
            (ret.narrow(0, 1024, 1024) - conv(input_large.narrow(0, 1024, 1024)))
            .abs_()
            .max()
            .item()
        )
        maxdiff2 = (
            (ret.narrow(0, 2048, 1024) - conv(input_large.narrow(0, 2048, 1024)))
            .abs_()
            .max()
            .item()
        )
        maxdiff3 = (
            (ret.narrow(0, 3072, 1024) - conv(input_large.narrow(0, 3072, 1024)))
            .abs_()
            .max()
            .item()
        )
        self.assertEqual(maxdiff0, 0)
        self.assertEqual(maxdiff1, 0)
        self.assertEqual(maxdiff2, 0)
        self.assertEqual(maxdiff3, 0)

    def test_conv_large(self, device):
        dtype = torch.half if self.device_type == "cuda" else torch.float
        conv = nn.Conv2d(2, 2, 8, 8, bias=False).to(device).to(dtype)
        input_large = torch.randn(4097, 2, 512, 512, dtype=dtype, device=device)
        ret = conv(input_large)
        self.assertEqual(ret[:2048], conv(input_large[:2048]))
        self.assertEqual(ret[2048:4096], conv(input_large[2048:4096]))
        self.assertEqual(ret[4096:], conv(input_large[4096:]))

        conv.zero_grad()
        ret.view(4097, -1).max(dim=1).values.sum().backward()
        del ret
        grad1 = conv.weight.grad.detach().clone()
        conv.zero_grad()
        conv(input_large[:2048]).view(2048, -1).max(dim=1).values.sum().backward()
        conv(input_large[2048:4096]).view(2048, -1).max(dim=1).values.sum().backward()
        conv(input_large[4096:]).view(1, -1).max(dim=1).values.sum().backward()
        grad2 = conv.weight.grad.detach().clone()
        scale = 1 / grad2.abs().mean()
        grad1 = grad1 * scale
        grad2 = grad2 * scale
        self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3)

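    # 1x1-kernel convolutions can take a specialized (matmul-like) path in the
    # backend, so the size-1 kernel tests below check forward outputs and
    # weight/bias gradients against a CPU reference within a tight tolerance.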
    def test_Conv2d_size_1_kernel(self, device):
        x_cpu = torch.randn(2, 3, 5, 5)
        conv_cpu = torch.nn.Conv2d(3, 3, kernel_size=1)
        y_cpu = conv_cpu(x_cpu)
        y = torch.rand_like(y_cpu)
        y_cpu.backward(y)

        with cudnn.flags(enabled=False):
            conv_cuda = torch.nn.Conv2d(3, 3, kernel_size=1).to(device)
            conv_cuda.bias.data.copy_(conv_cpu.bias.data)
            conv_cuda.weight.data.copy_(conv_cpu.weight.data)
            y_cuda = conv_cuda(x_cpu.to(device))
            y_cuda.backward(y.to(device))

        self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
        self.assertEqual(
            conv_cpu.bias.grad.data,
            conv_cuda.bias.grad.data,
            atol=1e-5,
            rtol=0,
            exact_device=False,
        )
        self.assertEqual(
            conv_cpu.weight.grad.data,
            conv_cuda.weight.grad.data,
            atol=1e-5,
            rtol=0,
            exact_device=False,
        )

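    # Same size-1 kernel check for the transposed variants: ConvTranspose2d in
    # the default dtype and ConvTranspose3d under torch.double, each compared
    # against a CPU reference module with copied weights and biases.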
    def test_ConvTranspose2d_size_1_kernel(self, device):
        x_cpu = torch.randn(2, 3, 5, 5)
        conv_cpu = torch.nn.ConvTranspose2d(3, 3, kernel_size=1)
        y_cpu = conv_cpu(x_cpu)
        y = torch.rand_like(y_cpu)
        y_cpu.backward(y)
        conv_cuda = torch.nn.ConvTranspose2d(3, 3, kernel_size=1).to(device)
        conv_cuda.bias.data.copy_(conv_cpu.bias.data)
        conv_cuda.weight.data.copy_(conv_cpu.weight.data)
        y_cuda = conv_cuda(x_cpu.to(device))
        y_cuda.backward(y.to(device))

        self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
        self.assertEqual(
            conv_cpu.bias.grad.data,
            conv_cuda.bias.grad.data,
            atol=1e-5,
            rtol=0,
            exact_device=False,
        )
        self.assertEqual(
            conv_cpu.weight.grad.data,
            conv_cuda.weight.grad.data,
            atol=1e-5,
            rtol=0,
            exact_device=False,
        )

    def test_ConvTranspose3d_size_1_kernel(self, device):
        with set_default_dtype(torch.double):
            x_cpu = torch.randn(2, 3, 3, 5, 5)
            conv_cpu = torch.nn.ConvTranspose3d(3, 3, kernel_size=1)
            y_cpu = conv_cpu(x_cpu)
            y = torch.rand_like(y_cpu)
            y_cpu.backward(y)
            conv_cuda = torch.nn.ConvTranspose3d(3, 3, kernel_size=1).to(device)
            conv_cuda.bias.data.copy_(conv_cpu.bias.data)
            conv_cuda.weight.data.copy_(conv_cpu.weight.data)
            y_cuda = conv_cuda(x_cpu.to(device))
            y_cuda.backward(y.to(device))

            self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
            self.assertEqual(
                conv_cpu.bias.grad.data,
                conv_cuda.bias.grad.data,
                atol=1e-5,
                rtol=0,
                exact_device=False,
            )
            self.assertEqual(
                conv_cpu.weight.grad.data,
                conv_cuda.weight.grad.data,
                atol=1e-5,
                rtol=0,
                exact_device=False,
            )

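    # A grouped convolution with groups=2 is mathematically equivalent to two
    # independent convolutions over the channel halves, i.e.
    #   m(x) == cat([m1(x[:, :2]), m2(x[:, 2:])], dim=1)
    # when (m1, m2) hold the two halves of m's weights and biases. The test
    # below builds that naive decomposition explicitly and checks outputs and
    # gradients of the grouped module against the concatenated per-group
    # results.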
    @dtypes(torch.float)
    def test_Conv2d_naive_groups(self, device, dtype):
        m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype)
        i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
        output = m(i)
        grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype)
        output.backward(grad_output)

        m1 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype)
        m1.weight.data.copy_(m.weight.data[:2])
        m1.bias.data.copy_(m.bias.data[:2])
        i1 = i.data[:, :2].contiguous().requires_grad_(True)
        output1 = m1(i1)
        output1.backward(grad_output[:, :2].contiguous())

        m2 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype)
        m2.weight.data.copy_(m.weight.data[2:])
        m2.bias.data.copy_(m.bias.data[2:])
        i2 = i.data[:, 2:].contiguous().requires_grad_(True)
        output2 = m2(i2)
        output2.backward(grad_output[:, 2:].contiguous())

        self.assertEqual(output, torch.cat([output1, output2], 1))
        self.assertEqual(
            i.grad.data,
            torch.cat([i1.grad.data, i2.grad.data], 1),
            atol=dtype2prec_DONTUSE[dtype],
            rtol=0,
        )
        self.assertEqual(
            m.bias.grad.data,
            torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
            atol=dtype2prec_DONTUSE[dtype],
            rtol=0,
        )
        self.assertEqual(
            m.weight.grad.data,
            torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
            atol=dtype2prec_DONTUSE[dtype],
            rtol=0,
        )

    @dtypes(torch.double)
    def test_Conv2d_backward_depthwise(self, device, dtype):
        x = torch.randn(2, 2, 4, 20, device=device, dtype=dtype, requires_grad=True)
        weight = torch.randn(2, 1, 3, 5, device=device, dtype=dtype, requires_grad=True)

        def conv2d_depthwise(x, weight):
            return torch.nn.functional.conv2d(
                x, weight, bias=None, stride=(1, 10), groups=2
            )

        torch.autograd.gradcheck(conv2d_depthwise, (x, weight))

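    # The channels-last tests below run the same convolution twice: once in
    # channels_last (NHWC) layout in the test dtype, and once in contiguous
    # (NCHW) layout in float64 as a reference. Inputs and weights are small
    # integers, so both paths should produce exactly representable values and
    # agree (exact_dtype=False only relaxes the dtype check, not the values).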
    @dtypes(torch.half, torch.float)
    def test_conv_cudnn_nhwc(self, device, dtype):
        def helper(n, c, h, w, out_channels, kernel_size, groups):
            input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to(
                memory_format=torch.channels_last
            )
            input.requires_grad_()
            conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to(
                device=device, dtype=dtype, memory_format=torch.channels_last
            )
            for p in conv.parameters():
                p.data = torch.randint_like(p, -3, 3)

            ref_input = input.detach().clone().contiguous().double().requires_grad_()
            ref_conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups)
            ref_conv.load_state_dict(conv.state_dict())
            ref_conv = ref_conv.to(
                device=device, dtype=torch.double, memory_format=torch.contiguous_format
            )

            out = conv(input)
            ref_out = ref_conv(ref_input)

            grad = torch.randint_like(out, -3, 3)
            ref_grad = grad.detach().clone().double().contiguous()

            out.backward(grad)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(input.grad.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(
                conv.weight.grad.is_contiguous(memory_format=torch.channels_last)
            )

            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ref_input.grad.is_contiguous())
            self.assertTrue(ref_conv.weight.grad.is_contiguous())

            self.assertEqual(out, ref_out, exact_dtype=False)
            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
            self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)

        helper(2, 8, 4, 4, out_channels=4, kernel_size=3, groups=1)
        helper(2, 8, 4, 4, out_channels=8, kernel_size=3, groups=8)
        helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=1)
        helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16)

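    # 3D analogue of the test above, using channels_last_3d (NDHWC) layout
    # with Conv3d and a contiguous double-precision reference.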
    @dtypes(torch.half, torch.float)
    def test_conv_cudnn_ndhwc(self, device, dtype):
        def helper(n, c, d, h, w, out_channels, kernel_size, groups):
            input = torch.randint(
                -2, 2, (n, c, d, h, w), dtype=dtype, device=device
            ).to(memory_format=torch.channels_last_3d)
            input.requires_grad_()
            conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups).to(
                device=device, dtype=dtype, memory_format=torch.channels_last_3d
            )
            for p in conv.parameters():
                p.data = torch.randint_like(p, -2, 2)

            ref_input = input.detach().clone().contiguous().double().requires_grad_()
            ref_conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups)
            ref_conv.load_state_dict(conv.state_dict())
            ref_conv = ref_conv.to(
                device=device, dtype=torch.double, memory_format=torch.contiguous_format
            )

            out = conv(input)
            ref_out = ref_conv(ref_input)

            grad = torch.randint_like(out, -2, 2)
            ref_grad = grad.detach().clone().double().contiguous()

            out.backward(grad)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last_3d))
            self.assertTrue(
                input.grad.is_contiguous(memory_format=torch.channels_last_3d)
            )
            self.assertTrue(
                conv.weight.grad.is_contiguous(memory_format=torch.channels_last_3d)
            )

            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ref_input.grad.is_contiguous())
            self.assertTrue(ref_conv.weight.grad.is_contiguous())

            self.assertEqual(out, ref_out, exact_dtype=False)
            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
            self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)

        helper(2, 8, 4, 4, 4, out_channels=4, kernel_size=3, groups=1)
        helper(2, 8, 4, 4, 4, out_channels=8, kernel_size=3, groups=8)
        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=1)
        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=16)

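    # Helpers for the memory-format matrix: _test_conv_cudnn_nhwc_nchw runs
    # one reference convolution, then _run_conv replays it for every
    # combination of input/weight/grad layouts. The expected rule is that the
    # output is channels_last whenever the input or the weight is
    # channels_last, and contiguous otherwise.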
    def _run_conv(
        self,
        layer,
        device,
        inp,
        grad,
        ref_conv,
        ref_input,
        ref_out,
        input_format,
        weight_format,
        grad_format,
        output_format,
    ):
        conv = (
            layer(inp.size(1), grad.size(1), ref_conv.weight.size(2)).float().to(device)
        )
        conv.load_state_dict(ref_conv.state_dict())
        weight_data = (
            conv.weight.detach().clone().contiguous(memory_format=weight_format)
        )
        conv.weight.data = weight_data.resize_(
            weight_data.size(), memory_format=weight_format
        )
        input = inp.clone().contiguous(memory_format=input_format)
        input.resize_(input.size(), memory_format=input_format)
        input = input.requires_grad_()
        grad = grad.contiguous(memory_format=grad_format)
        grad.resize_(grad.size(), memory_format=grad_format)
        out = conv(input)
        out.backward(grad)
        self.assertTrue(out.is_contiguous(memory_format=output_format))
        self.assertEqual(out, ref_out)
        self.assertEqual(conv.weight.grad, ref_conv.weight.grad)
        self.assertEqual(conv.bias.grad, ref_conv.bias.grad)
        self.assertEqual(input.grad, ref_input.grad)

    def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device):
        data = torch.randint(1, 10, (n, c, h, w), dtype=torch.float32, device=device)
        ref_input = data.clone().contiguous().requires_grad_(True)
        ref_conv = layer(c, k, filter_size).float().to(device)
        ref_out = ref_conv(ref_input)
        grad = torch.randint(1, 10, ref_out.size(), dtype=torch.float32, device=device)
        ref_out.backward(grad)

        for w_f in [torch.contiguous_format, torch.channels_last]:
            for g_f in [torch.contiguous_format, torch.channels_last]:
                for input_format in [torch.contiguous_format, torch.channels_last]:
                    output_format = torch.contiguous_format
                    if input_format == torch.channels_last:
                        output_format = torch.channels_last
                    if w_f == torch.channels_last:
                        output_format = torch.channels_last
                    self._run_conv(
                        layer,
                        device,
                        data,
                        grad,
                        ref_conv,
                        ref_input,
                        ref_out,
                        input_format,
                        w_f,
                        g_f,
                        output_format,
                    )

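    # Mixed layouts: a contiguous input with a channels_last weight should
    # still yield a channels_last output, and backward should succeed.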
    @dtypes(torch.float, torch.double)
    def test_conv_cudnn_nhwc_support(self, device, dtype):
        input = torch.randn(
            (1, 16, 1, 1), dtype=dtype, device=device, requires_grad=True
        )
        weight = torch.randn(
            (8, 16, 3, 3), dtype=dtype, device=device, requires_grad=True
        )
        weight = weight.to(memory_format=torch.channels_last)
        o = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1)
        self.assertTrue(o.is_contiguous(memory_format=torch.channels_last))
        o.sum().backward()

    @dtypes(torch.float)
    def test_conv2d_no_grad(self, device, dtype):
        for batch in [1, 2, 3]:
            for groups in [1, 2, 4]:
                input = torch.rand(batch, groups, 8, 8, dtype=dtype, device=device)
                m = nn.Conv2d(
                    groups,
                    8,
                    kernel_size=(3, 3),
                    groups=groups,
                    dtype=dtype,
                    device=device,
                )
                with torch.no_grad():
                    output_ng = m(input)
                output = m(input)
                self.assertEqual(output, output_ng, rtol=1e-2, atol=1e-5)

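    # Direct call into the aten double-backward kernel with a strided 1-D
    # (3-D tensor) convolution; only the gradient shapes are validated here.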
    def test_conv_double_backward_strided_with_3D_input_and_weight(self, device):
        input = torch.randn(2, 3, 6, device=device)
        weight = torch.randn(3, 3, 3, device=device)
        bias = torch.randn(3, device=device)
        stride = (2,)
        padding = (1,)
        dilation = (1,)
        transposed = False
        output_padding = (0,)
        groups = 1
        output = torch.ops.aten.convolution(
            input,
            weight,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        )

        ggI = torch.randn(input.shape, device=device)
        ggW = torch.randn(weight.shape, device=device)
        ggB = torch.randn(bias.shape, device=device)
        gO = torch.randn(output.shape, device=device)
        output_mask = [True, True, True]
        (
            grad_grad_output,
            grad_input,
            grad_weight,
        ) = torch.ops.aten._convolution_double_backward(
            ggI,
            ggW,
            ggB,
            gO,
            weight,
            input,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            output_mask,
        )

        self.assertEqual(grad_grad_output.shape, gO.shape)
        self.assertEqual(grad_input.shape, input.shape)
        self.assertEqual(grad_weight.shape, weight.shape)

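    # Stride-level check of the layout propagation rule on XPU, including the
    # float64 case where channels-last is not supported and the output falls
    # back to contiguous NCHW.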
    @onlyXPU
    @dtypes(torch.float16, torch.bfloat16, torch.float32, torch.float64)
    def test_channels_last_output_stride(self, device, dtype):
        input = torch.randn(
            (2, 3, 16, 16), device=device, dtype=dtype, requires_grad=True
        )
        weight = torch.randn(
            (512, 3, 3, 3), device=device, dtype=dtype, requires_grad=True
        )
        input = input.to(memory_format=torch.channels_last)
        weight = weight.to(memory_format=torch.channels_last)
        out = torch.conv2d(input, weight, None, (2, 2), (0, 0), (1, 1), 1)

        if dtype is torch.float64:
            # Like most conv backends, XPU does not support float64 for
            # channels-last conv: input is NHWC, output falls back to NCHW.
            assert_size_stride(out, (2, 512, 7, 7), (25088, 49, 7, 1))
        else:
            # input NHWC, output NHWC
            assert_size_stride(out, (2, 512, 7, 7), (25088, 1, 3584, 512))


instantiate_device_type_tests(
    TestConvolutionNNDeviceType, globals(), only_for="xpu", allow_xpu=True
)

if __name__ == "__main__":
    run_tests()