xref: /aosp_15_r20/external/pytorch/torch/nn/grad.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""Gradient interface."""
3
4import torch
5from torch.nn.modules.utils import _pair, _single, _triple
6
7
def conv1d_input(
    input_size,
    weight,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv1d with respect to the input of the convolution.

    This is same as the 1D transposed convolution operator under the hood but requires
    the shape of the gradient w.r.t. input to be specified explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weight tensor (out_channels x in_channels/groups x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, requires_grad=True)
        >>> output = F.conv1d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv1d_input(input.shape, weight, grad_output)

    """
    # The backward kernel only consumes the input's shape/dtype/device, so a
    # zero-storage expanded placeholder is enough — its values are never read.
    shape_only_input = grad_output.new_empty(1).expand(input_size)

    grad_input, _, _ = torch.ops.aten.convolution_backward(
        grad_output,
        shape_only_input,
        weight,
        None,  # no bias shape: bias gradient is not requested
        _single(stride),
        _single(padding),
        _single(dilation),
        False,  # transposed=False: plain convolution
        [0],  # output_padding: unused for non-transposed conv
        groups,
        (True, False, False),  # output_mask: compute only grad w.r.t. input
    )
    return grad_input
56
57
def conv1d_weight(
    input,
    weight_size,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv1d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, requires_grad=True)
        >>> output = F.conv1d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv1d_weight(input, weight.shape, grad_output)

    """
    # Placeholder carrying only shape metadata: convolution_backward needs the
    # weight's shape to size the gradient but never reads its values here.
    weight = grad_output.new_empty(1).expand(weight_size)

    return torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        None,  # no bias shape: bias gradient is not requested
        _single(stride),
        _single(padding),
        _single(dilation),
        False,  # transposed=False: plain convolution
        [0],  # output_padding: unused for non-transposed conv
        groups,
        (False, True, False),  # output_mask: compute only grad w.r.t. weight
    )[1]
104
105
def conv2d_input(
    input_size,
    weight,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv2d with respect to the input of the convolution.

    This is same as the 2D transposed convolution operator under the hood but requires
    the shape of the gradient w.r.t. input to be specified explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weight tensor (out_channels x in_channels/groups x kH x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
        >>> output = F.conv2d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv2d_input(input.shape, weight, grad_output)

    """
    # The backward kernel only consumes the input's shape/dtype/device, so a
    # zero-storage expanded placeholder is enough — its values are never read.
    shape_only_input = grad_output.new_empty(1).expand(input_size)

    grad_input, _, _ = torch.ops.aten.convolution_backward(
        grad_output,
        shape_only_input,
        weight,
        None,  # no bias shape: bias gradient is not requested
        _pair(stride),
        _pair(padding),
        _pair(dilation),
        False,  # transposed=False: plain convolution
        [0],  # output_padding: unused for non-transposed conv
        groups,
        (True, False, False),  # output_mask: compute only grad w.r.t. input
    )
    return grad_input
154
155
def conv2d_weight(
    input,
    weight_size,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv2d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iH x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
        >>> output = F.conv2d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv2d_weight(input, weight.shape, grad_output)

    """
    # Placeholder carrying only shape metadata: convolution_backward needs the
    # weight's shape to size the gradient but never reads its values here.
    weight = grad_output.new_empty(1).expand(weight_size)

    return torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        None,  # no bias shape: bias gradient is not requested
        _pair(stride),
        _pair(padding),
        _pair(dilation),
        False,  # transposed=False: plain convolution
        [0],  # output_padding: unused for non-transposed conv
        groups,
        (False, True, False),  # output_mask: compute only grad w.r.t. weight
    )[1]
202
203
def conv3d_input(
    input_size,
    weight,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv3d with respect to the input of the convolution.

    This is same as the 3D transposed convolution operator under the hood but requires
    the shape of the gradient w.r.t. input to be specified explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weights tensor (out_channels x in_channels/groups x kT x kH x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
        >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
        >>> output = F.conv3d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv3d_input(input.shape, weight, grad_output)

    """
    # The backward kernel only consumes the input's shape/dtype/device, so a
    # zero-storage expanded placeholder is enough — its values are never read.
    shape_only_input = grad_output.new_empty(1).expand(input_size)

    grad_input, _, _ = torch.ops.aten.convolution_backward(
        grad_output,
        shape_only_input,
        weight,
        None,  # no bias shape: bias gradient is not requested
        _triple(stride),
        _triple(padding),
        _triple(dilation),
        False,  # transposed=False: plain convolution
        [0],  # output_padding: unused for non-transposed conv
        groups,
        (True, False, False),  # output_mask: compute only grad w.r.t. input
    )
    return grad_input
252
253
def conv3d_weight(
    input,
    weight_size,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv3d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iT x iH x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
        >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
        >>> output = F.conv3d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv3d_weight(input, weight.shape, grad_output)

    """
    # Placeholder carrying only shape metadata: convolution_backward needs the
    # weight's shape to size the gradient but never reads its values here.
    shape_only_weight = grad_output.new_empty(1).expand(weight_size)

    _, grad_weight, _ = torch.ops.aten.convolution_backward(
        grad_output,
        input,
        shape_only_weight,
        None,  # no bias shape: bias gradient is not requested
        _triple(stride),
        _triple(padding),
        _triple(dilation),
        False,  # transposed=False: plain convolution
        [0],  # output_padding: unused for non-transposed conv
        groups,
        (False, True, False),  # output_mask: compute only grad w.r.t. weight
    )
    return grad_weight
299