# mypy: allow-untyped-defs
"""Gradient interface."""

import torch
from torch.nn.modules.utils import _pair, _single, _triple


def conv1d_input(
    input_size,
    weight,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv1d with respect to the input of the convolution.

    This is the same as the 1D transposed convolution operator under the hood but requires
    the shape of the gradient w.r.t. input to be specified explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weight tensor (out_channels x in_channels/groups x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, requires_grad=True)
        >>> output = F.conv1d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv1d_input(input.shape, weight, grad_output)

    """
    input = grad_output.new_empty(1).expand(input_size)

    return torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        None,
        _single(stride),
        _single(padding),
        _single(dilation),
        False,
        [0],
        groups,
        (True, False, False),
    )[0]


def conv1d_weight(
    input,
    weight_size,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv1d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, requires_grad=True)
        >>> output = F.conv1d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv1d_weight(input, weight.shape, grad_output)

    """
    weight = grad_output.new_empty(1).expand(weight_size)

    return torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        None,
        _single(stride),
        _single(padding),
        _single(dilation),
        False,
        [0],
        groups,
        (False, True, False),
    )[1]


def conv2d_input(
    input_size,
    weight,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv2d with respect to the input of the convolution.

    This is the same as the 2D transposed convolution operator under the hood but requires
    the shape of the gradient w.r.t. input to be specified explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weight tensor (out_channels x in_channels/groups x kH x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
        >>> output = F.conv2d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv2d_input(input.shape, weight, grad_output)

    """
    input = grad_output.new_empty(1).expand(input_size)

    return torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        None,
        _pair(stride),
        _pair(padding),
        _pair(dilation),
        False,
        [0],
        groups,
        (True, False, False),
    )[0]


def conv2d_weight(
    input,
    weight_size,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv2d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iH x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
        >>> output = F.conv2d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv2d_weight(input, weight.shape, grad_output)

    """
    weight = grad_output.new_empty(1).expand(weight_size)

    return torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        None,
        _pair(stride),
        _pair(padding),
        _pair(dilation),
        False,
        [0],
        groups,
        (False, True, False),
    )[1]


def conv3d_input(
    input_size,
    weight,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv3d with respect to the input of the convolution.

    This is the same as the 3D transposed convolution operator under the hood but requires
    the shape of the gradient w.r.t. input to be specified explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weight tensor (out_channels x in_channels/groups x kT x kH x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
        >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
        >>> output = F.conv3d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv3d_input(input.shape, weight, grad_output)

    """
    input = grad_output.new_empty(1).expand(input_size)

    return torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        None,
        _triple(stride),
        _triple(padding),
        _triple(dilation),
        False,
        [0],
        groups,
        (True, False, False),
    )[0]


def conv3d_weight(
    input,
    weight_size,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv3d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iT x iH x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
        >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
        >>> output = F.conv3d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv3d_weight(input, weight.shape, grad_output)

    """
    weight = grad_output.new_empty(1).expand(weight_size)

    return torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        None,
        _triple(stride),
        _triple(padding),
        _triple(dilation),
        False,
        [0],
        groups,
        (False, True, False),
    )[1]
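

if __name__ == "__main__":
    # Minimal sanity-check sketch, not part of the public interface: for a
    # small conv2d example, the explicit gradient helpers above should agree
    # with the gradients computed by autograd. The tensor shapes, stride, and
    # padding below are arbitrary choices made only for this check.
    import torch.nn.functional as F

    inp = torch.randn(1, 2, 6, 6, requires_grad=True)
    w = torch.randn(3, 2, 3, 3, requires_grad=True)
    out = F.conv2d(inp, w, stride=1, padding=1)
    grad_out = torch.randn_like(out)

    # Reference gradients from autograd.
    grad_inp_ref, grad_w_ref = torch.autograd.grad(out, (inp, w), grad_out)

    # Gradients from the explicit helpers, given the target shapes.
    grad_inp = conv2d_input(inp.shape, w, grad_out, stride=1, padding=1)
    grad_w = conv2d_weight(inp, w.shape, grad_out, stride=1, padding=1)

    torch.testing.assert_close(grad_inp, grad_inp_ref)
    torch.testing.assert_close(grad_w, grad_w_ref)
    print("conv2d_input / conv2d_weight match autograd gradients")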