# mypy: allow-untyped-defs
from typing import List, Optional

import numpy as np

import torch
import torch.nn.functional as F

from .expanded_weights_utils import (
    set_grad_sample_if_exists,
    unpack_expanded_weight_or_tensor,
)


THRESHOLD = 32


def conv_picker(func, conv1dOpt, conv2dOpt, conv3dOpt):
    if func == F.conv1d:
        return conv1dOpt
    if func == F.conv2d:
        return conv2dOpt
    else:
        assert func == F.conv3d
        return conv3dOpt


def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs):
    args = expanded_args_and_kwargs[: len(expanded_args_and_kwargs) - len(kwarg_names)]
    kwargs = expanded_args_and_kwargs[
        len(expanded_args_and_kwargs) - len(kwarg_names) :
    ]
    kwargs = dict(zip(kwarg_names, kwargs))

    return conv_normalizer(*args, **kwargs)


def conv_normalizer(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    return (input, weight), {
        "bias": bias,
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "groups": groups,
    }


def conv_input_for_string_padding(func, padding_style, input, dilation, kernel_size):
    if padding_style == "valid":
        return input
    else:
        padding = int_padding_for_string_padding(
            func, padding_style, dilation, kernel_size
        )
        return F.pad(input, padding)


def int_padding_for_string_padding(func, padding_style, dilation, kernel_size):
    def get_dilation(i):
        return dilation[i] if isinstance(dilation, tuple) else dilation

    if padding_style == "same":
        padding: List[int] = []
        # F.pad needs the padding in reverse order from what conv expects
        for i in range(conv_picker(func, 0, 1, 2), -1, -1):
            padding += conv_padding_for_same(get_dilation(i), kernel_size[i])
        return padding
    elif padding_style == "valid":
        return conv_picker(func, 2, 4, 6) * (0,)
    else:
        raise RuntimeError(
            f"got padding style {padding_style}; only 'same' and 'valid' are accepted"
        )


def conv_padding_for_same(dilation, kernel_size):
    total_pad = dilation * (kernel_size - 1)
    left_pad = total_pad // 2
    right_pad = total_pad - left_pad
    return left_pad, right_pad
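
# For illustration, a worked example of the "same"-padding helpers above (the
# numbers are hypothetical, not from any caller): for F.conv2d with
# kernel_size=(3, 5) and dilation=2, each dim needs dilation * (kernel_size - 1)
# total padding, split as evenly as possible:
#
#   conv_padding_for_same(2, 3) == (2, 2)
#   conv_padding_for_same(2, 5) == (4, 4)
#
# int_padding_for_string_padding(F.conv2d, "same", 2, (3, 5)) then returns
# [4, 4, 2, 2]: W before H, because F.pad takes padding for the last dimension
# first.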


def conv_backward(func, ctx, grad_output):
    def weight_grad_sample(weight):
        if batch_size < THRESHOLD and groups == 1:
            return conv_group_weight_grad_sample(
                ctx.input,
                grad_output,
                weight_shape,
                stride,
                padding,
                dilation,
                batch_size,
                func,
            )
        else:
            return conv_unfold_weight_grad_sample(
                ctx.input,
                grad_output,
                weight_shape,
                kernel_size,
                stride,
                padding,
                dilation,
                groups,
                func,
            )

    def expand(param):
        if isinstance(param, int):
            return conv_picker(func, (param,), (param, param), (param, param, param))
        else:
            return param

    def calc_total_padding(func, was_same, padding, dilation, kernel_size):
        if was_same:
            all_padding = int_padding_for_string_padding(
                func, "same", dilation, kernel_size
            )
            # F.pad needs the padding in reverse order from what conv expects
            total_padding = tuple(
                all_padding[i] + all_padding[i - 1]
                for i in range(len(all_padding) - 1, -1, -2)
            )
            return total_padding
        else:
            return tuple(2 * pad for pad in padding)

    weight_shape = ctx.weight.shape
    stride, padding, dilation, groups = (
        expand(ctx.stride),
        expand(ctx.padding),
        expand(ctx.dilation),
        ctx.groups,
    )

    kernel_size = []
    for i in range(2, conv_picker(func, 3, 4, 5)):
        kernel_size.append(weight_shape[i])

    batch_size = ctx.batch_size
    results: List[Optional[torch.Tensor]] = []
    results.append(None)  # for kwarg names
    results.append(None)  # for op reference

    # "same" padding may give uneven padding on either side, so we need to
    # track the total padding separately from the "padding" attr
    total_padding = calc_total_padding(
        func, ctx.was_same_padding, padding, dilation, kernel_size
    )

    if ctx.input_required_grad:
        output_padding = []
        input_dims = conv_picker(func, 1, 2, 3)
        for i in range(input_dims):
            input_dim = ctx.orig_input_shape[2 + i]
            output_padding.append(
                (
                    total_padding[i]
                    + input_dim
                    - (kernel_size[i] * dilation[i] - dilation[i] + 1)
                )
                % stride[i]
            )
        weight_ = unpack_expanded_weight_or_tensor(ctx.weight)
        transpose_func = conv_picker(
            func, F.conv_transpose1d, F.conv_transpose2d, F.conv_transpose3d
        )
        out = transpose_func(
            grad_output,
            weight_,
            None,
            stride,
            padding,
            tuple(output_padding),
            groups,
            dilation,
        )

        if ctx.was_same_padding:
            for i in range(len(total_padding)):
                out = torch.narrow(
                    out, 2 + i, total_padding[i] // 2, ctx.orig_input_shape[2 + i]
                )

        results.append(out)
    else:
        results.append(None)
    # weight and bias don't compute batched gradients; no other arguments are differentiable
    results = results + [None] * 6

    # set grad_sample field for weight and bias with per-sample gradients
    set_grad_sample_if_exists(ctx.weight, weight_grad_sample)
    set_grad_sample_if_exists(
        ctx.bias, lambda _: grad_output.reshape(*grad_output.shape[:2], -1).sum(dim=2)
    )
    return tuple(results)
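
# A quick sanity check for the per-sample bias gradient computed above, written
# against a plain F.conv2d call rather than the ExpandedWeights machinery
# (shapes are hypothetical):
#
#   x = torch.randn(4, 3, 8, 8)
#   w = torch.randn(5, 3, 3, 3)
#   b = torch.randn(5, requires_grad=True)
#   out = F.conv2d(x, w, b)
#   go = torch.randn_like(out)
#   out.backward(go)
#   per_sample = go.reshape(*go.shape[:2], -1).sum(dim=2)  # (N, O) = (4, 5)
#   assert torch.allclose(per_sample.sum(dim=0), b.grad)
#
# Summing the per-sample gradients over the batch recovers the usual .grad.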


def conv_unfold_weight_grad_sample(
    input,
    grad_output,
    weight_shape,
    kernel_size,
    stride,
    padding,
    dilation,
    groups,
    func,
):
    n = input.shape[0]
    in_channels = input.shape[1]

    unfold_func = conv_picker(
        func,
        lambda: F.unfold(
            input.unsqueeze(-2),
            kernel_size=(1, kernel_size[0]),
            dilation=(1, dilation[0]),
            padding=(0, padding[0]),
            stride=(1, stride[0]),
        ),
        lambda: F.unfold(
            input, kernel_size, dilation=dilation, padding=padding, stride=stride
        ),
        lambda: unfold3d(input, kernel_size, padding, stride, dilation),
    )

    input = unfold_func()
    grad_output = grad_output.reshape(n, -1, input.shape[-1])

    # n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz;
    # q=num_unfold_locations
    weight_grad_sample = torch.einsum("noq,npq->nop", grad_output, input)
    # rearrange the above tensor and extract the diagonal over the two group
    # dimensions
    weight_grad_sample = weight_grad_sample.view(
        n,
        groups,
        -1,
        groups,
        int(in_channels / groups),
        np.prod(kernel_size),
    )
    weight_grad_sample = torch.einsum(
        "ngrg...->ngr...", weight_grad_sample
    ).contiguous()
    shape = [n] + list(weight_shape)
    weight_grad_sample = weight_grad_sample.view(shape)
    return weight_grad_sample


def conv_group_weight_grad_sample(
    input,
    grad_output,
    weight_shape,
    stride,
    padding,
    dilation,
    batch_size,
    func,
):
    I = input.shape[1]
    O = grad_output.shape[1]

    input_ = input.transpose(0, 1)
    grad_output_ = grad_output.view(
        grad_output.shape[0] * grad_output.shape[1], 1, *grad_output.shape[2:]
    )

    # grad_output acts as the convolution kernel here, so stride and dilation
    # swap roles relative to the forward convolution
    weight_grad_sample = func(
        input_,
        grad_output_,
        None,
        stride=dilation,
        padding=padding,
        dilation=stride,
        groups=batch_size,
    )
    input_dims = conv_picker(func, 3, 4, 5)
    for i in range(2, input_dims):
        weight_grad_sample = weight_grad_sample.narrow(i, 0, weight_shape[i])
    weight_grad_sample = weight_grad_sample.view(
        I, batch_size, O, *weight_grad_sample.shape[2:]
    )
    weight_grad_sample = weight_grad_sample.movedim(0, 2)
    return weight_grad_sample
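
# Shape walk-through for conv_group_weight_grad_sample with F.conv2d and
# hypothetical sizes input=(N, I, H, W)=(2, 3, 8, 8),
# grad_output=(N, O, H', W')=(2, 5, 6, 6), i.e. a 3x3 kernel with unit
# stride/padding/dilation:
#
#   input_        (I, N, H, W)     = (3, 2, 8, 8)    # I acts as the batch dim
#   grad_output_  (N*O, 1, H', W') = (10, 1, 6, 6)   # used as the conv kernel
#   func(...)     (I, N*O, 3, 3)   = (3, 10, 3, 3)   # conv with groups=N
#   view/movedim  (N, O, I, 3, 3)  = (2, 5, 3, 3, 3) # per-sample weight grads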


def unfold3d(
    tensor,
    kernel_size,
    padding,
    stride,
    dilation,
):
    r"""
    Extract sliding local blocks from a batched input tensor.

    :class:`torch.nn.Unfold` only supports 4D inputs (batched image-like tensors).
    This method implements the same action for 5D inputs.
    Args:
        tensor: An input tensor of shape ``(B, C, D, H, W)``.
        kernel_size: the size of the sliding blocks
        padding: implicit zero padding to be added on both sides of input
        stride: the stride of the sliding blocks in the input spatial dimensions
        dilation: the spacing between the kernel points.
    Returns:
        A tensor of shape ``(B, C * np.prod(kernel_size), L)``, where L is the total
        number of output spatial locations.
        See :class:`torch.nn.Unfold` for more details
    Example:
        >>> # xdoctest: +SKIP
        >>> B, C, D, H, W = 3, 4, 5, 6, 7
        >>> tensor = torch.arange(1, B * C * D * H * W + 1.).view(B, C, D, H, W)
        >>> unfold3d(tensor, kernel_size=(2, 2, 2), padding=(0, 0, 0), stride=(1, 1, 1), dilation=(1, 1, 1)).shape
        torch.Size([3, 32, 120])
    """
    if len(tensor.shape) != 5:
        raise ValueError(
            f"Input tensor must be of the shape [B, C, D, H, W]. Got {tensor.shape}"
        )

    if dilation != (1, 1, 1):
        raise NotImplementedError(f"dilation={dilation} not supported.")

    batch_size, channels, _, _, _ = tensor.shape

    # Input shape: (B, C, D, H, W)
    tensor = F.pad(
        tensor, (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0])
    )
    # Output shape: (B, C, D+2*padding[0], H+2*padding[1], W+2*padding[2])

    tensor = tensor.unfold(dimension=2, size=kernel_size[0], step=stride[0])
    tensor = tensor.unfold(dimension=3, size=kernel_size[1], step=stride[1])
    tensor = tensor.unfold(dimension=4, size=kernel_size[2], step=stride[2])
    # Output shape: (B, C, D_out, H_out, W_out, kernel_size[0], kernel_size[1], kernel_size[2])
    # For D_out, H_out, W_out definitions see :class:`torch.nn.Unfold`

    tensor = tensor.permute(0, 2, 3, 4, 1, 5, 6, 7)
    # Output shape: (B, D_out, H_out, W_out, C, kernel_size[0], kernel_size[1], kernel_size[2])

    tensor = tensor.reshape(batch_size, -1, channels * np.prod(kernel_size)).transpose(
        1, 2
    )
    # Output shape: (B, C * kernel_size[0] * kernel_size[1] * kernel_size[2], D_out * H_out * W_out)

    return tensor
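
# unfold3d should agree with F.unfold on inputs whose depth dimension is
# trivial; a hypothetical consistency check:
#
#   x = torch.randn(2, 3, 1, 8, 8)
#   a = unfold3d(x, (1, 2, 2), (0, 0, 0), (1, 1, 1), (1, 1, 1))
#   b = F.unfold(x.squeeze(2), kernel_size=(2, 2))
#   assert torch.allclose(a, b)  # both are (2, 12, 49)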