"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import torch
import torch.nn.functional as F


# helper for broadcasting: returns a shape list of the given length with -1 at
# position index and 1 elsewhere, e.g. view_one_hot(1, 3) == [1, -1, 1]
def view_one_hot(index, length):
    vec = length * [1]
    vec[index] = -1
    return vec

def create_smoothing_kernel(widths, gamma=1.5):
    """ creates a truncated gaussian smoothing kernel for the given widths

        Parameters:
        -----------
        widths: list[Int] or torch.LongTensor
            specifies the shape of the smoothing kernel, entries must be > 0.

        gamma: float, optional
            decay factor for gaussian relative to kernel size

        Returns:
        --------
        kernel: torch.FloatTensor
    """

    widths = torch.LongTensor(widths)
    num_dims = len(widths)

    assert widths.min() > 0

    centers = widths.float() / 2 - 0.5
    sigmas  = gamma * (centers + 1)

    # per-dimension squared distances from the kernel center, broadcast to the full kernel shape
    vals = [((torch.arange(widths[i]) - centers[i]) / sigmas[i]) ** 2 for i in range(num_dims)]
    vals = sum([vals[i].view(view_one_hot(i, num_dims)) for i in range(num_dims)])

    kernel = torch.exp(- vals)
    kernel = kernel / kernel.sum()

    return kernel

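# Illustrative usage sketch (added for clarity; not part of the original module). The
# helper name _example_smoothing_kernel is hypothetical and only demonstrates the
# expected shape and normalization of the returned kernel.
def _example_smoothing_kernel():
    kernel = create_smoothing_kernel([5, 3])
    # a 5x3 truncated Gaussian with strictly positive entries that sum to one
    assert kernel.shape == (5, 3)
    assert torch.isclose(kernel.sum(), torch.tensor(1.0))
    return kernel
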
def create_partition_kernel(widths, strides):
    """ creates a partition kernel for mapping a convolutional network output back to the input domain

        Given a fully convolutional network with receptive field of shape widths and the given strides, this
        function constructs an interpolation kernel whose translations by multiples of the given strides form
        a partition of one on the input domain.

        Parameters:
        -----------
        widths: list[Int] or torch.LongTensor
            shape of receptive field

        strides: list[Int] or torch.LongTensor
            total strides of convolutional network

        Returns:
        --------
        kernel: torch.FloatTensor
    """

    num_dims = len(widths)
    assert num_dims == len(strides) and num_dims in {1, 2, 3}

    convs = {1 : F.conv1d, 2 : F.conv2d, 3 : F.conv3d}

    widths = torch.LongTensor(widths)
    strides = torch.LongTensor(strides)

    proto_kernel = torch.ones(torch.minimum(strides, widths).tolist())

    # create interpolation kernel eta
    eta_widths = widths - strides + 1
    if eta_widths.min() <= 0:
        print("[create_partition_kernel] warning: receptive field does not cover input domain")
        eta_widths = torch.maximum(eta_widths, torch.ones_like(eta_widths))

    eta = create_smoothing_kernel(eta_widths).view(1, 1, *eta_widths.tolist())

    padding = torch.repeat_interleave(eta_widths - 1, 2, 0).tolist()[::-1] # ordering of dimensions for padding and convolution functions is reversed in torch
    padded_proto_kernel = F.pad(proto_kernel, padding)
    padded_proto_kernel = padded_proto_kernel.view(1, 1, *padded_proto_kernel.shape)
    kernel = convs[num_dims](padded_proto_kernel, eta)

    return kernel


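# Illustrative usage sketch (added; not part of the original module): translates of the
# partition kernel by the given strides sum to one away from the boundaries, which is
# the property used when mapping relevance back to the input domain.
def _example_partition_kernel():
    kernel = create_partition_kernel([4], [2])           # shape (1, 1, 4)
    ones = torch.ones((1, 1, 8))
    cover = F.conv_transpose1d(ones, kernel, stride=2)   # overlap-add of translated kernels
    # interior samples are covered with total weight one (partition of unity)
    assert torch.allclose(cover[..., 2:-2], torch.ones_like(cover[..., 2:-2]), atol=1e-6)
    return cover
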
def receptive_field(conv_model, input_shape, output_position):
    """ estimates boundaries of receptive field connected to output_position via autograd

        Parameters:
        -----------
        conv_model: nn.Module or autograd function
            function or model implementing fully convolutional model

        input_shape: List[Int]
            input shape ignoring batch dimension, i.e. [num_channels, dim1, dim2, ...]

        output_position: List[Int]
            output position for which the receptive field is determined; the function raises an exception
            if output_position is out of bounds for the given input_shape.

        Returns:
        --------
        low: List[Int]
            start indices of receptive field

        high: List[Int]
            stop indices of receptive field

    """

    x = torch.randn((1,) + tuple(input_shape), requires_grad=True)
    y = conv_model(x)

    # collapse channels and remove batch dimension
    y = torch.sum(y, 1)[0]

    # create mask
    mask = torch.zeros_like(y)
    index = [torch.tensor(i) for i in output_position]
    try:
        mask.index_put_(index, torch.tensor(1, dtype=mask.dtype))
    except IndexError:
        raise ValueError('output_position out of bounds')

    (mask * y).sum().backward()

    # sum over channels and remove batch dimension
    grad = torch.sum(x.grad, dim=1)[0]
    tmp = torch.nonzero(grad, as_tuple=True)
    low  = [t.min().item() for t in tmp]
    high = [t.max().item() for t in tmp]

    return low, high

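# Illustrative usage sketch (added; not part of the original module): for a single Conv1d
# with kernel_size 5 and stride 2, output position t is connected to input samples
# 2*t ... 2*t + 4, which receptive_field recovers from the gradient.
def _example_receptive_field():
    conv = torch.nn.Conv1d(1, 1, kernel_size=5, stride=2)
    low, high = receptive_field(conv, input_shape=[1, 32], output_position=[4])
    assert low == [8] and high == [12]
    return low, high
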
def estimate_conv_parameters(model, num_channels, num_dims, width, max_stride=10):
    """ attempts to estimate receptive field size, strides and left paddings for given model

        Parameters:
        -----------
        model: nn.Module or autograd function
            fully convolutional model for which parameters are estimated

        num_channels: Int
            number of input channels for model

        num_dims: Int
            number of input dimensions for model (without channel dimension)

        width: Int
            width of the input tensor (a hyper-square) on which the receptive fields are derived via autograd

        max_stride: Int, optional
            assumed maximal stride of the model in any dimension; when set too low, the function may fail for
            any value of width

        Returns:
        --------
        receptive_field_size: List[Int]
            receptive field size in all dimensions

        strides: List[Int]
            stride in all dimensions

        left_paddings: List[Int]
            left padding in all dimensions; this is relevant for aligning the receptive field on the input plane

        Raises:
        -------
        ValueError, KeyError

    """

    input_shape = [num_channels] + num_dims * [width]
    output_position1 = num_dims * [width // (2 * max_stride)]
    output_position2 = num_dims * [width // (2 * max_stride) + 1]

    low1, high1 = receptive_field(model, input_shape, output_position1)
    low2, high2 = receptive_field(model, input_shape, output_position2)

    widths1 = [h - l + 1 for l, h in zip(low1, high1)]
    widths2 = [h - l + 1 for l, h in zip(low2, high2)]

    if not all([w1 - w2 == 0 for w1, w2 in zip(widths1, widths2)]) or not all([l1 != l2 for l1, l2 in zip(low1, low2)]):
        raise ValueError("[estimate_conv_parameters]: width too small to determine strides")

    receptive_field_size = widths1
    strides              = [l2 - l1 for l1, l2 in zip(low1, low2)]
    left_paddings        = [s * p - l for l, s, p in zip(low1, strides, output_position1)]

    return receptive_field_size, strides, left_paddings

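# Illustrative usage sketch (added; not part of the original module): for a padded Conv1d
# the estimated parameters match the constructor arguments.
def _example_estimate_conv_parameters():
    conv = torch.nn.Conv1d(1, 1, kernel_size=5, stride=2, padding=2)
    rf_size, strides, left_paddings = estimate_conv_parameters(conv, num_channels=1, num_dims=1, width=100)
    # kernel_size 5, stride 2 and left padding 2 are recovered from gradients alone
    assert rf_size == [5] and strides == [2] and left_paddings == [2]
    return rf_size, strides, left_paddings
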
def inspect_conv_model(model, num_channels, num_dims, max_width=10000, width_hint=None, stride_hint=None, verbose=False):
    """ determines size of receptive field, strides and padding probabilistically

        Parameters:
        -----------
        model: nn.Module or autograd function
            fully convolutional model for which parameters are estimated

        num_channels: Int
            number of input channels for model

        num_dims: Int
            number of input dimensions for model (without channel dimension)

        max_width: Int
            maximum width of the input tensor (a hyper-square) on which the receptive fields are derived via autograd

        width_hint: Int, optional
            optional hint at the maximal width of the receptive field

        stride_hint: Int, optional
            optional hint at the maximal total stride

        verbose: bool, optional
            if true, the function prints parameters for individual trials

        Returns:
        --------
        receptive_field_size: List[Int]
            receptive field size in all dimensions

        strides: List[Int]
            stride in all dimensions

        left_paddings: List[Int]
            left padding in all dimensions; this is relevant for aligning the receptive field on the input plane

        Raises:
        -------
        ValueError

    """

    max_stride = max_width // 2
    stride = max_stride // 100
    width = max_width // 100

    if width_hint is not None: width = 2 * width_hint
    if stride_hint is not None: stride = stride_hint

    did_it = False
    while width < max_width and stride < max_stride:
        try:
            if verbose: print(f"[inspect_conv_model] trying parameters {width=}, {stride=}")
            receptive_field_size, strides, left_paddings = estimate_conv_parameters(model, num_channels, num_dims, width, stride)
            did_it = True
        except Exception:
            pass

        if did_it: break

        width *= 2
        if width >= max_width and stride < max_stride:
            stride *= 2
            width = 2 * stride

    if not did_it:
        raise ValueError(f'could not determine conv parameters with given max_width={max_width}')

    return receptive_field_size, strides, left_paddings


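# Illustrative usage sketch (added; not part of the original module): for a small stack of
# strided convolutions, inspect_conv_model recovers the composite receptive field and stride.
def _example_inspect_conv_model():
    model = torch.nn.Sequential(
        torch.nn.Conv1d(2, 4, kernel_size=5, stride=2, padding=2),
        torch.nn.Conv1d(4, 1, kernel_size=3, stride=2, padding=1),
    )
    rf_size, strides, left_paddings = inspect_conv_model(model, num_channels=2, num_dims=1)
    # total stride is 2 * 2 = 4 and the receptive field spans 5 + 2 * (3 - 1) = 9 samples
    assert strides == [4] and rf_size == [9]
    return rf_size, strides, left_paddings
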
class GradWeight(torch.autograd.Function):
    """ autograd function that returns its input unchanged and scales the incoming gradient by weight """

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(weight)
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        weight, = ctx.saved_tensors

        grad_input = grad_output * weight

        return grad_input, None


# API

def relegance_gradient_weighting(x, weight):
    """ returns x unchanged in the forward pass and scales its gradient by weight in the backward pass

    Args:
        x (torch.tensor): input tensor
        weight (torch.tensor or None): weight tensor for gradients of x; if None, no gradient weighting will be applied in backward pass

    Returns:
        torch.tensor: the unmodified input tensor x
    """
    if weight is None:
        return x
    else:
        return GradWeight.apply(x, weight)


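# Illustrative usage sketch (added; not part of the original module): the forward pass is an
# identity while gradients are scaled element-wise by the weight tensor.
def _example_gradient_weighting():
    x = torch.ones((2, 3), requires_grad=True)
    weight = torch.full((2, 3), 0.5)
    y = relegance_gradient_weighting(x, weight)
    y.sum().backward()
    assert torch.equal(y, x) and torch.allclose(x.grad, weight)
    return x.grad
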
def relegance_create_tconv_kernel(model, num_channels, num_dims, width_hint=None, stride_hint=None, verbose=False):
    """ creates parameters for mapping output-domain relevance back to the input domain

    Args:
        model (nn.Module or autograd.Function): fully convolutional model
        num_channels (int): number of input channels to model
        num_dims (int): number of input dimensions of model (without channel and batch dimension)
        width_hint (int or None): optional hint at maximal width of receptive field
        stride_hint (int or None): optional hint at maximal stride
        verbose (bool): if True, parameters of individual trials are printed

    Returns:
        dict: contains kernel, receptive field shape, strides, left paddings and number of dimensions for the transposed convolution

    Raises:
        RuntimeError: if estimation of parameters fails due to exceeded compute budget
    """

    max_width = int(100000 / (10 ** num_dims))

    did_it = False
    try:
        receptive_field_size, strides, left_paddings = inspect_conv_model(model, num_channels, num_dims, max_width=max_width, width_hint=width_hint, stride_hint=stride_hint, verbose=verbose)
        did_it = True
    except Exception:
        # try once again with larger max_width
        max_width *= 10

    # raise a RuntimeError if the second attempt fails as well
    try:
        if not did_it: receptive_field_size, strides, left_paddings = inspect_conv_model(model, num_channels, num_dims, max_width=max_width, width_hint=width_hint, stride_hint=stride_hint, verbose=verbose)
    except Exception:
        raise RuntimeError("could not determine parameters within given compute budget")

    partition_kernel = create_partition_kernel(receptive_field_size, strides)
    partition_kernel = torch.repeat_interleave(partition_kernel, num_channels, 1)

    tconv_parameters = {
        'kernel': partition_kernel,
        'receptive_field_shape': receptive_field_size,
        'stride': strides,
        'left_padding': left_paddings,
        'num_dims': num_dims
    }

    return tconv_parameters


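# Illustrative usage sketch (added; not part of the original module): the returned dictionary
# bundles everything relegance_map_relevance_to_input_domain needs; the toy discriminator
# below is hypothetical.
def _example_create_tconv_kernel():
    disc = torch.nn.Conv1d(1, 1, kernel_size=5, stride=2, padding=2)
    tconv_parameters = relegance_create_tconv_kernel(disc, num_channels=1, num_dims=1)
    # for this 1D model the kernel has shape (1, num_channels, receptive_field_size)
    assert tconv_parameters['stride'] == [2]
    assert tconv_parameters['kernel'].shape == (1, 1, tconv_parameters['receptive_field_shape'][0])
    return tconv_parameters
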
def relegance_map_relevance_to_input_domain(od_relevance, tconv_parameters):
    """ maps output-domain relevance to input-domain relevance via transposed convolution

    Args:
        od_relevance (torch.tensor): output-domain relevance
        tconv_parameters (dict): parameter dict as created by relegance_create_tconv_kernel

    Returns:
        torch.tensor: input-domain relevance. The tensor is left aligned, i.e. the all-zero index of the output corresponds to the all-zero index of the discriminator input.
                      However, the size of the output tensor need not match the size of the discriminator input. Use relegance_resize_relevance_to_input_size for a
                      convenient way to adjust the output to the correct size.

    Raises:
        ValueError: if number of dimensions is not supported
    """

    kernel       = tconv_parameters['kernel'].to(od_relevance.device)
    rf_shape     = tconv_parameters['receptive_field_shape']
    stride       = tconv_parameters['stride']
    left_padding = tconv_parameters['left_padding']

    num_dims = len(kernel.shape) - 2

    # repeat boundary values
    od_padding = [rf_shape[i//2] // stride[i//2] + 1 for i in range(2 * num_dims)]
    padded_od_relevance = F.pad(od_relevance, od_padding[::-1], mode='replicate')
    od_padding = od_padding[::2]

    # apply mapping and left trimming
    if num_dims == 1:
        id_relevance = F.conv_transpose1d(padded_od_relevance, kernel, stride=stride)
        id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :]
    elif num_dims == 2:
        id_relevance = F.conv_transpose2d(padded_od_relevance, kernel, stride=stride)
        id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :, left_padding[1] + stride[1] * od_padding[1]:]
    elif num_dims == 3:
        id_relevance = F.conv_transpose3d(padded_od_relevance, kernel, stride=stride)
        id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :, left_padding[1] + stride[1] * od_padding[1]:, left_padding[2] + stride[2] * od_padding[2] :]
    else:
        raise ValueError(f'[relegance_map_relevance_to_input_domain] error: num_dims = {num_dims} not supported')

    return id_relevance


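# Illustrative usage sketch (added; not part of the original module): output-domain relevance
# is spread over the receptive field of every output sample; the result is left aligned and
# in general longer than the discriminator input.
def _example_map_relevance():
    disc = torch.nn.Conv1d(1, 1, kernel_size=5, stride=2, padding=2)
    tconv_parameters = relegance_create_tconv_kernel(disc, num_channels=1, num_dims=1)
    x = torch.randn((1, 1, 64))
    od_relevance = torch.rand_like(disc(x))   # one relevance value per discriminator output
    id_relevance = relegance_map_relevance_to_input_domain(od_relevance, tconv_parameters)
    assert id_relevance.size(-1) >= x.size(-1)
    return id_relevance
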
def relegance_resize_relevance_to_input_size(reference_input, relevance):
    """ adjusts size of relevance tensor to reference input size

    Args:
        reference_input (torch.tensor): discriminator input tensor for reference
        relevance (torch.tensor): input-domain relevance corresponding to input tensor reference_input

    Returns:
        torch.tensor: resized relevance

    Raises:
        ValueError: if number of dimensions is not supported
    """
    resized_relevance = torch.zeros_like(reference_input)

    num_dims = len(reference_input.shape) - 2
    with torch.no_grad():
        if num_dims == 1:
            resized_relevance[:] = relevance[..., : min(reference_input.size(-1), relevance.size(-1))]
        elif num_dims == 2:
            resized_relevance[:] = relevance[..., : min(reference_input.size(-2), relevance.size(-2)), : min(reference_input.size(-1), relevance.size(-1))]
        elif num_dims == 3:
            resized_relevance[:] = relevance[..., : min(reference_input.size(-3), relevance.size(-3)), : min(reference_input.size(-2), relevance.size(-2)), : min(reference_input.size(-1), relevance.size(-1))]
        else:
            raise ValueError(f'[relegance_resize_relevance_to_input_size] error: num_dims = {num_dims} not supported')

    return resized_relevance
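

# Illustrative end-to-end sketch (added; not part of the original module): map per-output
# relevance of a toy discriminator onto its input and use it to weight the gradients that
# flow back into the input. The helper name and the toy model are hypothetical.
def _example_relegance_pipeline():
    disc = torch.nn.Conv1d(1, 1, kernel_size=5, stride=2, padding=2)
    tconv_parameters = relegance_create_tconv_kernel(disc, num_channels=1, num_dims=1)

    x = torch.randn((1, 1, 64), requires_grad=True)
    od_relevance = torch.sigmoid(disc(x)).detach()   # toy per-output relevance scores
    id_relevance = relegance_map_relevance_to_input_domain(od_relevance, tconv_parameters)
    id_relevance = relegance_resize_relevance_to_input_size(x, id_relevance)

    weighted = relegance_gradient_weighting(x, id_relevance)
    disc(weighted).sum().backward()                  # gradients of x are scaled by the relevance
    assert x.grad.shape == x.shape
    return x.grad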