xref: /aosp_15_r20/external/pytorch/torch/nn/grad.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Worker"""Gradient interface."""
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Workerfrom torch.nn.modules.utils import _pair, _single, _triple
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker
def conv1d_input(
    input_size,
    weight,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv1d with respect to the input of the convolution.

    This is the same as the 1D transposed convolution operator under the hood, but
    the shape of the gradient w.r.t. the input must be specified explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weight tensor (out_channels x in_channels/groups x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Returns:
        The gradient of the conv1d output w.r.t. its input, of shape ``input_size``.

    Examples::

        >>> input = torch.randn(1, 1, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, requires_grad=True)
        >>> output = F.conv1d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv1d_input(input.shape, weight, grad_output)

    """
    # The input gradient depends on the input's shape only, never its values,
    # so an uninitialized one-element tensor broadcast (a view, via expand) to
    # ``input_size`` is enough to stand in for the real input.
    placeholder_input = grad_output.new_empty(1).expand(input_size)

    conv_stride = _single(stride)
    conv_padding = _single(padding)
    conv_dilation = _single(dilation)
    # Ask only for grad_input out of (grad_input, grad_weight, grad_bias).
    output_mask = (True, False, False)

    grad_input, _, _ = torch.ops.aten.convolution_backward(
        grad_output,
        placeholder_input,
        weight,
        None,  # no bias gradient requested
        conv_stride,
        conv_padding,
        conv_dilation,
        False,  # not a transposed convolution
        [0],  # output_padding
        groups,
        output_mask,
    )
    return grad_input
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker
def conv1d_weight(
    input,
    weight_size,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv1d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Returns:
        The gradient of the conv1d output w.r.t. its weight, of shape ``weight_size``.

    Examples::

        >>> input = torch.randn(1, 1, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, requires_grad=True)
        >>> output = F.conv1d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv1d_weight(input, weight.shape, grad_output)

    """
    # The weight gradient depends on the weight's shape only, not its values,
    # so an uninitialized one-element tensor broadcast (a view, via expand) to
    # ``weight_size`` stands in for the real weight.
    weight = grad_output.new_empty(1).expand(weight_size)

    return torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        None,  # no bias gradient requested
        _single(stride),
        _single(padding),
        _single(dilation),
        False,  # not a transposed convolution
        [0],  # output_padding
        groups,
        (False, True, False),  # compute only grad_weight
    )[1]
104*da0073e9SAndroid Build Coastguard Worker
105*da0073e9SAndroid Build Coastguard Worker
def conv2d_input(
    input_size,
    weight,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv2d with respect to the input of the convolution.

    This is the same as the 2D transposed convolution operator under the hood, but
    the shape of the gradient w.r.t. the input must be specified explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weight tensor (out_channels x in_channels/groups x kH x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Returns:
        The gradient of the conv2d output w.r.t. its input, of shape ``input_size``.

    Examples::

        >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
        >>> output = F.conv2d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv2d_input(input.shape, weight, grad_output)

    """
    # The input gradient depends on the input's shape only, never its values,
    # so an uninitialized one-element tensor broadcast (a view, via expand) to
    # ``input_size`` is enough to stand in for the real input.
    placeholder_input = grad_output.new_empty(1).expand(input_size)

    conv_stride = _pair(stride)
    conv_padding = _pair(padding)
    conv_dilation = _pair(dilation)
    # Ask only for grad_input out of (grad_input, grad_weight, grad_bias).
    output_mask = (True, False, False)

    grad_input, _, _ = torch.ops.aten.convolution_backward(
        grad_output,
        placeholder_input,
        weight,
        None,  # no bias gradient requested
        conv_stride,
        conv_padding,
        conv_dilation,
        False,  # not a transposed convolution
        [0],  # output_padding
        groups,
        output_mask,
    )
    return grad_input
154*da0073e9SAndroid Build Coastguard Worker
155*da0073e9SAndroid Build Coastguard Worker
def conv2d_weight(
    input,
    weight_size,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv2d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iH x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Returns:
        The gradient of the conv2d output w.r.t. its weight, of shape ``weight_size``.

    Examples::

        >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
        >>> output = F.conv2d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv2d_weight(input, weight.shape, grad_output)

    """
    # The weight gradient depends on the weight's shape only, not its values,
    # so an uninitialized one-element tensor broadcast (a view, via expand) to
    # ``weight_size`` stands in for the real weight.
    weight = grad_output.new_empty(1).expand(weight_size)

    return torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        None,  # no bias gradient requested
        _pair(stride),
        _pair(padding),
        _pair(dilation),
        False,  # not a transposed convolution
        [0],  # output_padding
        groups,
        (False, True, False),  # compute only grad_weight
    )[1]
202*da0073e9SAndroid Build Coastguard Worker
203*da0073e9SAndroid Build Coastguard Worker
def conv3d_input(
    input_size,
    weight,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv3d with respect to the input of the convolution.

    This is the same as the 3D transposed convolution operator under the hood, but
    the shape of the gradient w.r.t. the input must be specified explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weight tensor (out_channels x in_channels/groups x kT x kH x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Returns:
        The gradient of the conv3d output w.r.t. its input, of shape ``input_size``.

    Examples::

        >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
        >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
        >>> output = F.conv3d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv3d_input(input.shape, weight, grad_output)

    """
    # The input gradient depends on the input's shape only, never its values,
    # so an uninitialized one-element tensor broadcast (a view, via expand) to
    # ``input_size`` is enough to stand in for the real input.
    placeholder_input = grad_output.new_empty(1).expand(input_size)

    conv_stride = _triple(stride)
    conv_padding = _triple(padding)
    conv_dilation = _triple(dilation)
    # Ask only for grad_input out of (grad_input, grad_weight, grad_bias).
    output_mask = (True, False, False)

    grad_input, _, _ = torch.ops.aten.convolution_backward(
        grad_output,
        placeholder_input,
        weight,
        None,  # no bias gradient requested
        conv_stride,
        conv_padding,
        conv_dilation,
        False,  # not a transposed convolution
        [0],  # output_padding
        groups,
        output_mask,
    )
    return grad_input
252*da0073e9SAndroid Build Coastguard Worker
253*da0073e9SAndroid Build Coastguard Worker
def conv3d_weight(
    input,
    weight_size,
    grad_output,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    r"""Compute the gradient of conv3d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iT x iH x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Returns:
        The gradient of the conv3d output w.r.t. its weight, of shape ``weight_size``.

    Examples::

        >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
        >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
        >>> output = F.conv3d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv3d_weight(input, weight.shape, grad_output)

    """
    # The weight gradient depends on the weight's shape only, not its values,
    # so an uninitialized one-element tensor broadcast (a view, via expand) to
    # ``weight_size`` stands in for the real weight.
    placeholder_weight = grad_output.new_empty(1).expand(weight_size)

    conv_stride = _triple(stride)
    conv_padding = _triple(padding)
    conv_dilation = _triple(dilation)
    # Ask only for grad_weight out of (grad_input, grad_weight, grad_bias).
    output_mask = (False, True, False)

    _, grad_weight, _ = torch.ops.aten.convolution_backward(
        grad_output,
        input,
        placeholder_weight,
        None,  # no bias gradient requested
        conv_stride,
        conv_padding,
        conv_dilation,
        False,  # not a transposed convolution
        [0],  # output_padding
        groups,
        output_mask,
    )
    return grad_weight
299