1""" 2/* Copyright (c) 2023 Amazon 3 Written by Jan Buethe */ 4/* 5 Redistribution and use in source and binary forms, with or without 6 modification, are permitted provided that the following conditions 7 are met: 8 9 - Redistributions of source code must retain the above copyright 10 notice, this list of conditions and the following disclaimer. 11 12 - Redistributions in binary form must reproduce the above copyright 13 notice, this list of conditions and the following disclaimer in the 14 documentation and/or other materials provided with the distribution. 15 16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27*/ 28""" 29 30import torch 31import torch.nn.functional as F 32 33 34def view_one_hot(index, length): 35 vec = length * [1] 36 vec[index] = -1 37 return vec 38 39def create_smoothing_kernel(widths, gamma=1.5): 40 """ creates a truncated gaussian smoothing kernel for the given widths 41 42 Parameters: 43 ----------- 44 widths: list[Int] or torch.LongTensor 45 specifies the shape of the smoothing kernel, entries must be > 0. 46 47 gamma: float, optional 48 decay factor for gaussian relative to kernel size 49 50 Returns: 51 -------- 52 kernel: torch.FloatTensor 53 """ 54 55 widths = torch.LongTensor(widths) 56 num_dims = len(widths) 57 58 assert(widths.min() > 0) 59 60 centers = widths.float() / 2 - 0.5 61 sigmas = gamma * (centers + 1) 62 63 vals = [] 64 65 vals= [((torch.arange(widths[i]) - centers[i]) / sigmas[i]) ** 2 for i in range(num_dims)] 66 vals = sum([vals[i].view(view_one_hot(i, num_dims)) for i in range(num_dims)]) 67 68 kernel = torch.exp(- vals) 69 kernel = kernel / kernel.sum() 70 71 return kernel 72 73 74def create_partition_kernel(widths, strides): 75 """ creates a partition kernel for mapping a convolutional network output back to the input domain 76 77 Given a fully convolutional network with receptive field of shape widths and the given strides, this 78 function construncts an intorpolation kernel whose tranlations by multiples of the given strides form 79 a partition of one on the input domain. 80 81 Parameter: 82 ---------- 83 widths: list[Int] or torch.LongTensor 84 shape of receptive field 85 86 strides: list[Int] or torch.LongTensor 87 total strides of convolutional network 88 89 Returns: 90 kernel: torch.FloatTensor 91 """ 92 93 num_dims = len(widths) 94 assert num_dims == len(strides) and num_dims in {1, 2, 3} 95 96 convs = {1 : F.conv1d, 2 : F.conv2d, 3 : F.conv3d} 97 98 widths = torch.LongTensor(widths) 99 strides = torch.LongTensor(strides) 100 101 proto_kernel = torch.ones(torch.minimum(strides, widths).tolist()) 102 103 # create interpolation kernel eta 104 eta_widths = widths - strides + 1 105 if eta_widths.min() <= 0: 106 print("[create_partition_kernel] warning: receptive field does not cover input domain") 107 eta_widths = torch.maximum(eta_widths, torch.ones_like(eta_widths)) 108 109 110 eta = create_smoothing_kernel(eta_widths).view(1, 1, *eta_widths.tolist()) 111 112 padding = torch.repeat_interleave(eta_widths - 1, 2, 0).tolist()[::-1] # ordering of dimensions for padding and convolution functions is reversed in torch 113 padded_proto_kernel = F.pad(proto_kernel, padding) 114 padded_proto_kernel = padded_proto_kernel.view(1, 1, *padded_proto_kernel.shape) 115 kernel = convs[num_dims](padded_proto_kernel, eta) 116 117 return kernel 118 119 120def receptive_field(conv_model, input_shape, output_position): 121 """ estimates boundaries of receptive field connected to output_position via autograd 122 123 Parameters: 124 ----------- 125 conv_model: nn.Module or autograd function 126 function or model implementing fully convolutional model 127 128 input_shape: List[Int] 129 input shape ignoring batch dimension, i.e. [num_channels, dim1, dim2, ...] 130 131 output_position: List[Int] 132 output position for which the receptive field is determined; the function raises an exception 133 if output_position is out of bounds for the given input_shape. 134 135 Returns: 136 -------- 137 low: List[Int] 138 start indices of receptive field 139 140 high: List[Int] 141 stop indices of receptive field 142 143 """ 144 145 x = torch.randn((1,) + tuple(input_shape), requires_grad=True) 146 y = conv_model(x) 147 148 # collapse channels and remove batch dimension 149 y = torch.sum(y, 1)[0] 150 151 # create mask 152 mask = torch.zeros_like(y) 153 index = [torch.tensor(i) for i in output_position] 154 try: 155 mask.index_put_(index, torch.tensor(1, dtype=mask.dtype)) 156 except IndexError: 157 raise ValueError('output_position out of bounds') 158 159 (mask * y).sum().backward() 160 161 # sum over channels and remove batch dimension 162 grad = torch.sum(x.grad, dim=1)[0] 163 tmp = torch.nonzero(grad, as_tuple=True) 164 low = [t.min().item() for t in tmp] 165 high = [t.max().item() for t in tmp] 166 167 return low, high 168 169def estimate_conv_parameters(model, num_channels, num_dims, width, max_stride=10): 170 """ attempts to estimate receptive field size, strides and left paddings for given model 171 172 173 Parameters: 174 ----------- 175 model: nn.Module or autograd function 176 fully convolutional model for which parameters are estimated 177 178 num_channels: Int 179 number of input channels for model 180 181 num_dims: Int 182 number of input dimensions for model (without channel dimension) 183 184 width: Int 185 width of the input tensor (a hyper-square) on which the receptive fields are derived via autograd 186 187 max_stride: Int, optional 188 assumed maximal stride of the model for any dimension, when set too low the function may fail for 189 any value of width 190 191 Returns: 192 -------- 193 receptive_field_size: List[Int] 194 receptive field size in all dimension 195 196 strides: List[Int] 197 stride in all dimensions 198 199 left_paddings: List[Int] 200 left padding in all dimensions; this is relevant for aligning the receptive field on the input plane 201 202 Raises: 203 ------- 204 ValueError, KeyError 205 206 """ 207 208 input_shape = [num_channels] + num_dims * [width] 209 output_position1 = num_dims * [width // (2 * max_stride)] 210 output_position2 = num_dims * [width // (2 * max_stride) + 1] 211 212 low1, high1 = receptive_field(model, input_shape, output_position1) 213 low2, high2 = receptive_field(model, input_shape, output_position2) 214 215 widths1 = [h - l + 1 for l, h in zip(low1, high1)] 216 widths2 = [h - l + 1 for l, h in zip(low2, high2)] 217 218 if not all([w1 - w2 == 0 for w1, w2 in zip(widths1, widths2)]) or not all([l1 != l2 for l1, l2 in zip(low1, low2)]): 219 raise ValueError("[estimate_strides]: widths to small to determine strides") 220 221 receptive_field_size = widths1 222 strides = [l2 - l1 for l1, l2 in zip(low1, low2)] 223 left_paddings = [s * p - l for l, s, p in zip(low1, strides, output_position1)] 224 225 return receptive_field_size, strides, left_paddings 226 227def inspect_conv_model(model, num_channels, num_dims, max_width=10000, width_hint=None, stride_hint=None, verbose=False): 228 """ determines size of receptive field, strides and padding probabilistically 229 230 231 Parameters: 232 ----------- 233 model: nn.Module or autograd function 234 fully convolutional model for which parameters are estimated 235 236 num_channels: Int 237 number of input channels for model 238 239 num_dims: Int 240 number of input dimensions for model (without channel dimension) 241 242 max_width: Int 243 maximum width of the input tensor (a hyper-square) on which the receptive fields are derived via autograd 244 245 verbose: bool, optional 246 if true, the function prints parameters for individual trials 247 248 Returns: 249 -------- 250 receptive_field_size: List[Int] 251 receptive field size in all dimension 252 253 strides: List[Int] 254 stride in all dimensions 255 256 left_paddings: List[Int] 257 left padding in all dimensions; this is relevant for aligning the receptive field on the input plane 258 259 Raises: 260 ------- 261 ValueError 262 263 """ 264 265 max_stride = max_width // 2 266 stride = max_stride // 100 267 width = max_width // 100 268 269 if width_hint is not None: width = 2 * width_hint 270 if stride_hint is not None: stride = stride_hint 271 272 did_it = False 273 while width < max_width and stride < max_stride: 274 try: 275 if verbose: print(f"[inspect_conv_model] trying parameters {width=}, {stride=}") 276 receptive_field_size, strides, left_paddings = estimate_conv_parameters(model, num_channels, num_dims, width, stride) 277 did_it = True 278 except: 279 pass 280 281 if did_it: break 282 283 width *= 2 284 if width >= max_width and stride < max_stride: 285 stride *= 2 286 width = 2 * stride 287 288 if not did_it: 289 raise ValueError(f'could not determine conv parameter with given max_width={max_width}') 290 291 return receptive_field_size, strides, left_paddings 292 293 294class GradWeight(torch.autograd.Function): 295 def __init__(self): 296 super().__init__() 297 298 @staticmethod 299 def forward(ctx, x, weight): 300 ctx.save_for_backward(weight) 301 return x.clone() 302 303 @staticmethod 304 def backward(ctx, grad_output): 305 weight, = ctx.saved_tensors 306 307 grad_input = grad_output * weight 308 309 return grad_input, None 310 311 312# API 313 314def relegance_gradient_weighting(x, weight): 315 """ 316 317 Args: 318 x (torch.tensor): input tensor 319 weight (torch.tensor or None): weight tensor for gradients of x; if None, no gradient weighting will be applied in backward pass 320 321 Returns: 322 torch.tensor: the unmodified input tensor x 323 324 Raises: 325 RuntimeError: if estimation of parameters fails due to exceeded compute budget 326 """ 327 if weight is None: 328 return x 329 else: 330 return GradWeight.apply(x, weight) 331 332 333 334def relegance_create_tconv_kernel(model, num_channels, num_dims, width_hint=None, stride_hint=None, verbose=False): 335 """ creates parameters for mapping back output domain relevance to input tomain 336 337 Args: 338 model (nn.Module or autograd.Function): fully convolutional model 339 num_channels (int): number of input channels to model 340 num_dims (int): number of input dimensions of model (without channel and batch dimension) 341 width_hint(int or None): optional hint at maximal width of receptive field 342 stride_hint(int or None): optional hint at maximal stride 343 344 Returns: 345 dict: contains kernel, kernel dimensions, strides and left paddings for transposed convolution 346 """ 347 348 max_width = int(100000 / (10 ** num_dims)) 349 350 did_it = False 351 try: 352 receptive_field_size, strides, left_paddings = inspect_conv_model(model, num_channels, num_dims, max_width=max_width, width_hint=width_hint, stride_hint=stride_hint, verbose=verbose) 353 did_it = True 354 except: 355 # try once again with larger max_width 356 max_width *= 10 357 358 # crash if exception is raised 359 try: 360 if not did_it: receptive_field_size, strides, left_paddings = inspect_conv_model(model, num_channels, num_dims, max_width=max_width, width_hint=width_hint, stride_hint=stride_hint, verbose=verbose) 361 except: 362 raise RuntimeError("could not determine parameters within given compute budget") 363 364 partition_kernel = create_partition_kernel(receptive_field_size, strides) 365 partition_kernel = torch.repeat_interleave(partition_kernel, num_channels, 1) 366 367 tconv_parameters = { 368 'kernel': partition_kernel, 369 'receptive_field_shape': receptive_field_size, 370 'stride': strides, 371 'left_padding': left_paddings, 372 'num_dims': num_dims 373 } 374 375 return tconv_parameters 376 377 378 379def relegance_map_relevance_to_input_domain(od_relevance, tconv_parameters): 380 """ maps output-domain relevance to input-domain relevance via transpose convolution 381 382 Args: 383 od_relevance (torch.tensor): output-domain relevance 384 tconv_parameters (dict): parameter dict as created by relegance_create_tconv_kernel 385 386 Returns: 387 torch.tensor: input-domain relevance. The tensor is left aligned, i.e. the all-zero index of the output corresponds to the all-zero index of the discriminator input. 388 Otherwise, the size of the output tensor does not need to match the size of the discriminator input. Use relegance_resize_relevance_to_input_size for a 389 convenient way to adjust the output to the correct size. 390 391 Raises: 392 ValueError: if number of dimensions is not supported 393 """ 394 395 kernel = tconv_parameters['kernel'].to(od_relevance.device) 396 rf_shape = tconv_parameters['receptive_field_shape'] 397 stride = tconv_parameters['stride'] 398 left_padding = tconv_parameters['left_padding'] 399 400 num_dims = len(kernel.shape) - 2 401 402 # repeat boundary values 403 od_padding = [rf_shape[i//2] // stride[i//2] + 1 for i in range(2 * num_dims)] 404 padded_od_relevance = F.pad(od_relevance, od_padding[::-1], mode='replicate') 405 od_padding = od_padding[::2] 406 407 # apply mapping and left trimming 408 if num_dims == 1: 409 id_relevance = F.conv_transpose1d(padded_od_relevance, kernel, stride=stride) 410 id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :] 411 elif num_dims == 2: 412 id_relevance = F.conv_transpose2d(padded_od_relevance, kernel, stride=stride) 413 id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :, left_padding[1] + stride[1] * od_padding[1]:] 414 elif num_dims == 3: 415 id_relevance = F.conv_transpose2d(padded_od_relevance, kernel, stride=stride) 416 id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :, left_padding[1] + stride[1] * od_padding[1]:, left_padding[2] + stride[2] * od_padding[2] :] 417 else: 418 raise ValueError(f'[relegance_map_to_input_domain] error: num_dims = {num_dims} not supported') 419 420 return id_relevance 421 422 423def relegance_resize_relevance_to_input_size(reference_input, relevance): 424 """ adjusts size of relevance tensor to reference input size 425 426 Args: 427 reference_input (torch.tensor): discriminator input tensor for reference 428 relevance (torch.tensor): input-domain relevance corresponding to input tensor reference_input 429 430 Returns: 431 torch.tensor: resized relevance 432 433 Raises: 434 ValueError: if number of dimensions is not supported 435 """ 436 resized_relevance = torch.zeros_like(reference_input) 437 438 num_dims = len(reference_input.shape) - 2 439 with torch.no_grad(): 440 if num_dims == 1: 441 resized_relevance[:] = relevance[..., : min(reference_input.size(-1), relevance.size(-1))] 442 elif num_dims == 2: 443 resized_relevance[:] = relevance[..., : min(reference_input.size(-2), relevance.size(-2)), : min(reference_input.size(-1), relevance.size(-1))] 444 elif num_dims == 3: 445 resized_relevance[:] = relevance[..., : min(reference_input.size(-3), relevance.size(-3)), : min(reference_input.size(-2), relevance.size(-2)), : min(reference_input.size(-1), relevance.size(-1))] 446 else: 447 raise ValueError(f'[relegance_map_to_input_domain] error: num_dims = {num_dims} not supported') 448 449 return resized_relevance