# mypy: allow-untyped-defs
from typing import List, Optional

import numpy as np

import torch
import torch.nn.functional as F

from .expanded_weights_utils import (
    set_grad_sample_if_exists,
    unpack_expanded_weight_or_tensor,
)


THRESHOLD = 32


def conv_picker(func, conv1dOpt, conv2dOpt, conv3dOpt):
    if func == F.conv1d:
        return conv1dOpt
    elif func == F.conv2d:
        return conv2dOpt
    else:
        assert func == F.conv3d
        return conv3dOpt


def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs):
    args = expanded_args_and_kwargs[: len(expanded_args_and_kwargs) - len(kwarg_names)]
    kwargs = expanded_args_and_kwargs[
        len(expanded_args_and_kwargs) - len(kwarg_names) :
    ]
    kwargs = dict(zip(kwarg_names, kwargs))

    return conv_normalizer(*args, **kwargs)


def conv_normalizer(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    return (input, weight), {
        "bias": bias,
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "groups": groups,
    }
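
# A hedged usage sketch (hypothetical values, not part of the original file):
# conv_normalizer maps conv's mixed positional/keyword arguments onto a fixed
# (args, kwargs) layout, e.g.
#
#     args, kwargs = conv_normalizer(input, weight, None, 2, 1)
#     # args == (input, weight)
#     # kwargs == {"bias": None, "stride": 2, "padding": 1, "dilation": 1, "groups": 1}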


def conv_input_for_string_padding(func, padding_style, input, dilation, kernel_size):
    if padding_style == "valid":
        return input
    else:
        padding = int_padding_for_string_padding(
            func, padding_style, dilation, kernel_size
        )
        return F.pad(input, padding)


def int_padding_for_string_padding(func, padding_style, dilation, kernel_size):
    def get_dilation(i):
        return dilation[i] if isinstance(dilation, tuple) else dilation

    if padding_style == "same":
        padding: List[int] = []
        # F.pad needs the padding in reverse order from what conv expects
        for i in range(conv_picker(func, 0, 1, 2), -1, -1):
            padding += conv_padding_for_same(get_dilation(i), kernel_size[i])
        return padding
    elif padding_style == "valid":
        return conv_picker(func, 2, 4, 6) * (0,)
    else:
        raise RuntimeError(
            f"got padding style {padding_style!r}; only 'same' and 'valid' are accepted"
        )
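
# A worked example (illustrative values, not from the original file): for
# F.conv2d with "same" padding, dilation=1 and kernel_size=(3, 5), the loop
# runs i = 1, 0 and emits the W-dim pair before the H-dim pair, i.e. the
# reversed order that F.pad expects:
#
#     int_padding_for_string_padding(F.conv2d, "same", 1, (3, 5))  # -> [2, 2, 1, 1]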


def conv_padding_for_same(dilation, kernel_size):
    total_pad = dilation * (kernel_size - 1)
    left_pad = total_pad // 2
    right_pad = total_pad - left_pad
    return left_pad, right_pad
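
# A quick worked example (illustrative values, not from the original file):
# an odd total pad splits unevenly, with the extra element on the right:
#
#     conv_padding_for_same(1, 4)  # total_pad = 3 -> (1, 2)
#     conv_padding_for_same(1, 3)  # total_pad = 2 -> (1, 1)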


def conv_backward(func, ctx, grad_output):
    def weight_grad_sample(weight):
        if batch_size < THRESHOLD and groups == 1:
            return conv_group_weight_grad_sample(
                ctx.input,
                grad_output,
                weight_shape,
                stride,
                padding,
                dilation,
                batch_size,
                func,
            )
        else:
            return conv_unfold_weight_grad_sample(
                ctx.input,
                grad_output,
                weight_shape,
                kernel_size,
                stride,
                padding,
                dilation,
                groups,
                func,
            )

    def expand(param):
        if isinstance(param, int):
            return conv_picker(func, (param,), (param, param), (param, param, param))
        else:
            return param

    def calc_total_padding(func, was_same, padding, dilation, kernel_size):
        if was_same:
            all_padding = int_padding_for_string_padding(
                func, "same", dilation, kernel_size
            )
            # F.pad needs the padding in reverse order from what conv expects
            total_padding = tuple(
                all_padding[i] + all_padding[i - 1]
                for i in range(len(all_padding) - 1, -1, -2)
            )
            return total_padding
        else:
            return tuple(2 * pad for pad in padding)

    weight_shape = ctx.weight.shape
    stride, padding, dilation, groups = (
        expand(ctx.stride),
        expand(ctx.padding),
        expand(ctx.dilation),
        ctx.groups,
    )

    kernel_size = []
    for i in range(2, conv_picker(func, 3, 4, 5)):
        kernel_size.append(weight_shape[i])

    batch_size = ctx.batch_size
    results: List[Optional[torch.Tensor]] = []
    results.append(None)  # for kwarg names
    results.append(None)  # for op reference

    # "same" padding may be uneven on either side, so the per-side "padding" attr
    # must be kept separate from the total padding used below
    total_padding = calc_total_padding(
        func, ctx.was_same_padding, padding, dilation, kernel_size
    )

    if ctx.input_required_grad:
        # output_padding makes the transposed conv reproduce the original input
        # size when stride > 1 leaves it ambiguous
        output_padding = []
        input_dims = conv_picker(func, 1, 2, 3)
        for i in range(input_dims):
            input_dim = ctx.orig_input_shape[2 + i]
            output_padding.append(
                (
                    total_padding[i]
                    + input_dim
                    - (kernel_size[i] * dilation[i] - dilation[i] + 1)
                )
                % stride[i]
            )
        weight_ = unpack_expanded_weight_or_tensor(ctx.weight)
        transpose_func = conv_picker(
            func, F.conv_transpose1d, F.conv_transpose2d, F.conv_transpose3d
        )
        out = transpose_func(
            grad_output,
            weight_,
            None,
            stride,
            padding,
            tuple(output_padding),
            groups,
            dilation,
        )

        if ctx.was_same_padding:
            for i in range(len(total_padding)):
                out = torch.narrow(
                    out, 2 + i, total_padding[i] // 2, ctx.orig_input_shape[2 + i]
                )

        results.append(out)
    else:
        results.append(None)
    # weight and bias don't compute batched gradients; no other arguments are differentiable
    results = results + [None] * 6

    # set grad_sample field for weight and bias with per sample gradients
    set_grad_sample_if_exists(ctx.weight, weight_grad_sample)
    set_grad_sample_if_exists(
        ctx.bias, lambda _: grad_output.reshape(*grad_output.shape[:2], -1).sum(dim=2)
    )
    return tuple(results)
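
# A minimal sketch (hypothetical values, not part of the original file) of the
# per-sample bias gradient rule used above: each sample's bias gradient is its
# grad_output summed over all spatial positions.
#
#     go = torch.randn(4, 6, 5, 5)                        # (batch, out_channels, H, W)
#     bias_grad_sample = go.reshape(4, 6, -1).sum(dim=2)  # shape (4, 6)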


def conv_unfold_weight_grad_sample(
    input,
    grad_output,
    weight_shape,
    kernel_size,
    stride,
    padding,
    dilation,
    groups,
    func,
):
    n = input.shape[0]
    in_channels = input.shape[1]

    unfold_func = conv_picker(
        func,
        lambda: F.unfold(
            input.unsqueeze(-2),
            kernel_size=(1, kernel_size[0]),
            dilation=(1, dilation[0]),
            padding=(0, padding[0]),
            stride=(1, stride[0]),
        ),
        lambda: F.unfold(
            input, kernel_size, dilation=dilation, padding=padding, stride=stride
        ),
        lambda: unfold3d(input, kernel_size, padding, stride, dilation),
    )
235
236    input = unfold_func()
237    grad_output = grad_output.reshape(n, -1, input.shape[-1])
238
239    # n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz
240    weight_grad_sample = torch.einsum("noq,npq->nop", grad_output, input)
241    # rearrange the above tensor and extract diagonals.
242    weight_grad_sample = weight_grad_sample.view(
243        n,
244        groups,
245        -1,
246        groups,
247        int(in_channels / groups),
248        np.prod(kernel_size),
249    )
250    weight_grad_sample = torch.einsum(
251        "ngrg...->ngr...", weight_grad_sample
252    ).contiguous()
253    shape = [n] + list(weight_shape)
254    weight_grad_sample = weight_grad_sample.view(shape)
255    return weight_grad_sample
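
# A hedged sanity-check sketch (hypothetical values, not part of the original
# file): the per-sample gradients above should match autograd run one sample
# at a time.
#
#     x = torch.randn(4, 3, 8, 8)
#     w = torch.randn(6, 3, 3, 3, requires_grad=True)
#     grad_out = torch.randn(4, 6, 8, 8)
#     per_sample = conv_unfold_weight_grad_sample(
#         x, grad_out, w.shape, (3, 3), (1, 1), (1, 1), (1, 1), 1, F.conv2d
#     )
#     for i in range(4):
#         (g,) = torch.autograd.grad(
#             F.conv2d(x[i : i + 1], w, padding=1), w, grad_out[i : i + 1]
#         )
#         torch.testing.assert_close(per_sample[i], g)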


def conv_group_weight_grad_sample(
    input,
    grad_output,
    weight_shape,
    stride,
    padding,
    dilation,
    batch_size,
    func,
):
    I = input.shape[1]
    O = grad_output.shape[1]

    # treat channels as the batch dim and each sample's grad_output as a filter,
    # so one grouped convolution computes every per-sample weight gradient
    input_ = input.transpose(0, 1)
    grad_output_ = grad_output.view(
        grad_output.shape[0] * grad_output.shape[1], 1, *grad_output.shape[2:]
    )

    # stride and dilation intentionally swap roles when the weight gradient is
    # computed as a convolution of the input with grad_output
    weight_grad_sample = func(
        input_,
        grad_output_,
        None,
        stride=dilation,
        padding=padding,
        dilation=stride,
        groups=batch_size,
    )
    input_dims = conv_picker(func, 3, 4, 5)
    for i in range(2, input_dims):
        weight_grad_sample = weight_grad_sample.narrow(i, 0, weight_shape[i])
    weight_grad_sample = weight_grad_sample.view(
        I, batch_size, O, *weight_grad_sample.shape[2:]
    )
    weight_grad_sample = weight_grad_sample.movedim(0, 2)
    return weight_grad_sample
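
# A hedged equivalence sketch (hypothetical values, not part of the original
# file): for groups == 1 the group-conv path should agree with the unfold path.
#
#     x = torch.randn(2, 3, 8, 8)
#     w_shape = torch.Size([6, 3, 3, 3])
#     go = torch.randn(2, 6, 6, 6)
#     a = conv_group_weight_grad_sample(
#         x, go, w_shape, (1, 1), (0, 0), (1, 1), 2, F.conv2d
#     )
#     b = conv_unfold_weight_grad_sample(
#         x, go, w_shape, (3, 3), (1, 1), (0, 0), (1, 1), 1, F.conv2d
#     )
#     torch.testing.assert_close(a, b)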


def unfold3d(
    tensor,
    kernel_size,
    padding,
    stride,
    dilation,
):
    r"""
    Extract sliding local blocks from a batched input tensor.

    :class:`torch.nn.Unfold` only supports 4D inputs (batched image-like tensors).
    This function implements the same operation for 5D inputs.
    Args:
        tensor: An input tensor of shape ``(B, C, D, H, W)``.
        kernel_size: the size of the sliding blocks
        padding: implicit zero padding to be added on both sides of input
        stride: the stride of the sliding blocks in the input spatial dimensions
        dilation: the spacing between the kernel points.
    Returns:
        A tensor of shape ``(B, C * np.prod(kernel_size), L)``, where ``L`` is the
        total number of blocks (the product of the output spatial dimensions).
        See :class:`torch.nn.Unfold` for more details.
    Example:
        >>> # xdoctest: +SKIP
        >>> B, C, D, H, W = 3, 4, 5, 6, 7
        >>> tensor = torch.arange(1, B * C * D * H * W + 1.0).view(B, C, D, H, W)
        >>> unfold3d(
        ...     tensor,
        ...     kernel_size=(2, 2, 2),
        ...     padding=(0, 0, 0),
        ...     stride=(1, 1, 1),
        ...     dilation=(1, 1, 1),
        ... ).shape
        torch.Size([3, 32, 120])
    """
    if len(tensor.shape) != 5:
        raise ValueError(
            f"Input tensor must be of the shape [B, C, D, H, W]. Got {tensor.shape}"
        )

    if dilation != (1, 1, 1):
        raise NotImplementedError(f"dilation={dilation} not supported.")

    batch_size, channels, _, _, _ = tensor.shape

    # Input shape: (B, C, D, H, W)
    tensor = F.pad(
        tensor, (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0])
    )
    # F.pad pads the last dim first, so this gives
    # output shape: (B, C, D+2*padding[0], H+2*padding[1], W+2*padding[2])

    tensor = tensor.unfold(dimension=2, size=kernel_size[0], step=stride[0])
    tensor = tensor.unfold(dimension=3, size=kernel_size[1], step=stride[1])
    tensor = tensor.unfold(dimension=4, size=kernel_size[2], step=stride[2])
    # Output shape: (B, C, D_out, H_out, W_out, kernel_size[0], kernel_size[1], kernel_size[2])
    # For D_out, H_out, W_out definitions see :class:`torch.nn.Unfold`

    tensor = tensor.permute(0, 2, 3, 4, 1, 5, 6, 7)
    # Output shape: (B, D_out, H_out, W_out, C, kernel_size[0], kernel_size[1], kernel_size[2])

    tensor = tensor.reshape(batch_size, -1, channels * np.prod(kernel_size)).transpose(
        1, 2
    )
    # Output shape: (B, C * kernel_size[0] * kernel_size[1] * kernel_size[2], D_out * H_out * W_out)

    return tensor
354