# mypy: allow-untyped-defs
"""Gradient interface."""

import torch
from torch.nn.modules.utils import _pair, _single, _triple


def conv1d_input(
    input_size,
    weight,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv1d with respect to the input of the convolution.

    This is the same as the 1D transposed convolution operator under the hood but requires
    the shape of the gradient w.r.t. input to be specified explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weight tensor (out_channels x in_channels/groups x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, requires_grad=True)
        >>> output = F.conv1d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv1d_input(input.shape, weight, grad_output)

    """
    input = grad_output.new_empty(1).expand(input_size)

    return torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        None,
        _single(stride),
        _single(padding),
        _single(dilation),
        False,
        [0],
        groups,
        (True, False, False),
    )[0]
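
# A minimal cross-check sketch (comment only, not executed): it shows how the result of
# conv1d_input can be compared against the gradient computed by autograd. The tensor
# shapes below are illustrative assumptions, not values used anywhere in this module.
#
#   import torch
#   import torch.nn.functional as F
#
#   inp = torch.randn(1, 1, 3, requires_grad=True)
#   w = torch.randn(1, 1, 1)
#   out = F.conv1d(inp, w)
#   g_out = torch.randn(out.shape)
#   (g_autograd,) = torch.autograd.grad(out, inp, g_out)
#   g_manual = torch.nn.grad.conv1d_input(inp.shape, w, g_out)
#   torch.testing.assert_close(g_autograd, g_manual)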


def conv1d_weight(
    input,
    weight_size,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv1d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, requires_grad=True)
        >>> output = F.conv1d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> # xdoctest: +SKIP
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv1d_weight(input, weight.shape, grad_output)

    """
    weight = grad_output.new_empty(1).expand(weight_size)

    return torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        None,
        _single(stride),
        _single(padding),
        _single(dilation),
        False,
        [0],
        groups,
        (False, True, False),
    )[1]


def conv2d_input(
    input_size,
    weight,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv2d with respect to the input of the convolution.

    This is the same as the 2D transposed convolution operator under the hood but requires
    the shape of the gradient w.r.t. input to be specified explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weight tensor (out_channels x in_channels/groups x kH x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
        >>> output = F.conv2d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv2d_input(input.shape, weight, grad_output)

    """
    input = grad_output.new_empty(1).expand(input_size)

    return torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        None,
        _pair(stride),
        _pair(padding),
        _pair(dilation),
        False,
        [0],
        groups,
        (True, False, False),
    )[0]
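
# A hedged illustration (comment only, not executed) of the "transposed convolution under
# the hood" remark above: with the default stride/padding/dilation/groups, the input
# gradient should match F.conv_transpose2d applied to grad_output, up to floating-point
# tolerance. The shapes here are illustrative assumptions.
#
#   import torch
#   import torch.nn.functional as F
#
#   x = torch.randn(1, 3, 8, 8)
#   w = torch.randn(4, 3, 3, 3)
#   g = torch.randn(1, 4, 6, 6)  # shape of F.conv2d(x, w)
#   via_backward = torch.nn.grad.conv2d_input(x.shape, w, g)
#   via_transpose = F.conv_transpose2d(g, w)
#   torch.testing.assert_close(via_backward, via_transpose)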


def conv2d_weight(
    input,
    weight_size,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv2d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iH x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
        >>> output = F.conv2d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> # xdoctest: +SKIP
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv2d_weight(input, weight.shape, grad_output)

    """
    weight = grad_output.new_empty(1).expand(weight_size)

    return torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        None,
        _pair(stride),
        _pair(padding),
        _pair(dilation),
        False,
        [0],
        groups,
        (False, True, False),
    )[1]


def conv3d_input(
    input_size,
    weight,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv3d with respect to the input of the convolution.

    This is the same as the 3D transposed convolution operator under the hood but requires
    the shape of the gradient w.r.t. input to be specified explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weight tensor (out_channels x in_channels/groups x kT x kH x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
        >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
        >>> output = F.conv3d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv3d_input(input.shape, weight, grad_output)

    """
    input = grad_output.new_empty(1).expand(input_size)

    return torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        None,
        _triple(stride),
        _triple(padding),
        _triple(dilation),
        False,
        [0],
        groups,
        (True, False, False),
    )[0]


def conv3d_weight(
    input,
    weight_size,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv3d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iT x iH x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
        >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
        >>> output = F.conv3d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv3d_weight(input, weight.shape, grad_output)

    """
    weight = grad_output.new_empty(1).expand(weight_size)

    return torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        None,
        _triple(stride),
        _triple(padding),
        _triple(dilation),
        False,
        [0],
        groups,
        (False, True, False),
    )[1]
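
# A minimal sketch (comment only, not executed) of checking conv3d_weight against autograd,
# mirroring the docstring example above; the tensor shapes are illustrative assumptions.
#
#   import torch
#   import torch.nn.functional as F
#
#   inp = torch.randn(2, 8, 10, 10, 20)
#   w = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
#   out = F.conv3d(inp, w)
#   g_out = torch.randn(out.shape)
#   (g_autograd,) = torch.autograd.grad(out, w, g_out)
#   g_manual = torch.nn.grad.conv3d_weight(inp, w.shape, g_out)
#   torch.testing.assert_close(g_autograd, g_manual)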