xref: /aosp_15_r20/external/executorch/backends/qualcomm/builders/op_conv2d.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import warnings
8from typing import cast, Dict, List
9
10import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
11
12import numpy as np
13import torch
14from executorch.backends.qualcomm.utils.constants import QCOM_DATA
15
16from .node_visitor import NodeVisitor, register_node_visitor
17from .qnn_constants import (
18    OpConv2d,
19    OpDepthWiseConv2d,
20    OpExpandDims,
21    OpReshape,
22    OpTransposeConv2d,
23    QNN_OP_PACKAGE_NAME_QTI_AISW,
24)
25from .utils import get_parameter
26
27
@register_node_visitor
class Conv2d(NodeVisitor):
    """Op builder that lowers ``aten.convolution.default`` nodes to QNN ops.

    Depending on the node, one of the following is emitted:

    * a QNN Conv2d, DepthWiseConv2d or TransposeConv2d op (4-D filter), or
    * an ExpandDims -> Conv2d -> Reshape chain (3-D filter, i.e. Conv1d).

    The positional layout of ``node.args`` read by this class is:
    ``(input, weight, bias, stride, padding, dilation, transposed,
    output_padding, groups)``.
    """

    target = ["aten.convolution.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def _add_conv_op_parameter(
        self,
        OP,
        conv_op,
        conv_input_tensors,
        conv_output_tensors,
        stride,
        stride_shape,
        padding,
        padding_shape,
        dilation,
        dilation_shape,
        output_padding=None,
        output_padding_shape=None,
        transpose_conv=False,
        groups=None,
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """
        Attach the input/output tensors and convolution parameters to ``conv_op``.

        This function is shared among Conv1D, Conv2D, DepthWise Conv2D and
        Transpose Conv2D as most of the required parameters overlap.

        Args:
            OP: Op-definition class providing the parameter names
                (``param_stride``, ``param_pad_amount``, and, depending on the
                branch taken, ``param_dilation`` / ``param_output_padding`` /
                ``param_group``).
            conv_op: The op wrapper to populate.
            conv_input_tensors: Input tensor wrappers for the op.
            conv_output_tensors: Output tensor wrappers for the op.
            stride: Per-axis stride values.
            stride_shape: Dimensions of the stride tensor param.
            padding: Two-element padding ``[pad_0, pad_1]``; each value is
                applied symmetrically (before and after) on its axis.
            padding_shape: Dimensions of the pad-amount tensor param.
            dilation: Per-axis dilation values (ignored when ``transpose_conv``).
            dilation_shape: Dimensions of the dilation tensor param.
            output_padding: Per-axis output padding; only read when
                ``transpose_conv`` is True.
            output_padding_shape: Dimensions of the output-padding tensor
                param; only read when ``transpose_conv`` is True.
            transpose_conv: Emit ``output_padding`` instead of ``dilation``.
            groups: When not None, emitted as the scalar group param.

        Returns:
            The same ``conv_op`` that was passed in, now populated.
        """
        conv_op.AddInputTensors(conv_input_tensors)
        conv_op.AddOutputTensors(conv_output_tensors)
        conv_op.AddTensorParam(
            OP.param_stride,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(stride_shape),
            stride_shape,
            np.array(stride, dtype=np.uint32),
            True,
        )
        # Expand [pad_0, pad_1] into a 2x2 (before, after) table with
        # symmetric padding on each axis.
        conv_op.AddTensorParam(
            OP.param_pad_amount,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(padding_shape),
            padding_shape,
            np.array(
                [[padding[0], padding[0]], [padding[1], padding[1]]],
                dtype=np.uint32,
            ),
            True,
        )

        # The transposed op receives output_padding; every other op receives
        # dilation. The two are mutually exclusive here.
        if transpose_conv:
            conv_op.AddTensorParam(
                OP.param_output_padding,
                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
                len(output_padding_shape),
                output_padding_shape,
                np.array(output_padding, dtype=np.uint32),
                True,
            )
        else:
            conv_op.AddTensorParam(
                OP.param_dilation,
                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
                len(dilation_shape),
                dilation_shape,
                np.array(dilation, dtype=np.uint32),
                True,
            )

        if groups is not None:
            conv_op.AddScalarParam(
                OP.param_group,
                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
                {QCOM_DATA: np.uint32(groups)},
            )

        return conv_op

    def _define_conv1d(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """
        Conv1D is a special case for convolutional operation. QNN does not support Conv1D, therefore,
        we need to cast from input -> Conv1d -> output to input -> unsqueeze -> Conv2d -> squeeze -> output.

        Returns:
            A list of three op wrappers ``[ExpandDims, Conv2d, Reshape]``, or
            ``None`` (after emitting a warning) when the node is a transposed
            convolution or has a non-zero output padding.
        """
        # Reject unsupported configurations up front, before any tensor
        # wrapper or op is created for this node.
        # args[6] = transposed
        if cast(bool, node.args[6]):
            warnings.warn(
                "[QNN Delegate Op Builder]: ConvTranspose1d is not yet supported.",
                stacklevel=1,
            )
            return

        # args[7] = output padding
        if not all(out_pad == 0 for out_pad in cast(List[int], node.args[7])):
            warnings.warn(
                "[QNN Delegate Op Builder]: QNN does not support output padding.",
                stacklevel=1,
            )
            return

        op_wrapper_list = []  # op_wrapper to return

        # ---- 1) ExpandDims: insert a unit axis at position 1 of the input.
        unsqueeze_input_node = node.args[0]
        input_quant_encoding, input_quant_configs = self.get_quant_encoding_conf(
            unsqueeze_input_node,
        )

        unsqueeze_input_tensor = self.get_tensor(unsqueeze_input_node, node)
        unsqueeze_input_tensor_wrapper = self.define_tensor(
            unsqueeze_input_node,
            unsqueeze_input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )
        unsqueeze_output_tensor = unsqueeze_input_tensor.unsqueeze(1).contiguous()
        dtype = self.get_data_type(unsqueeze_output_tensor, input_quant_configs)
        unsqueeze_output_tensor_wrapper = self.define_custom_tensor_wrapper(
            node_name=node.name + "_unsqueeze",
            tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            dtype=dtype,
            quant_encoding=input_quant_encoding,
            quant_configs=input_quant_configs,
            dims=unsqueeze_output_tensor.size(),
            tensor=unsqueeze_output_tensor,
            is_fake_tensor=True,
            nodes_to_wrappers=nodes_to_wrappers,
        )
        unsqueeze_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name + "_unsqueeze",
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpExpandDims.op_name,
        )
        unsqueeze_op.AddInputTensors([unsqueeze_input_tensor_wrapper])
        unsqueeze_op.AddOutputTensors([unsqueeze_output_tensor_wrapper])
        unsqueeze_op.AddScalarParam(
            OpExpandDims.param_axis,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(1)},
        )
        op_wrapper_list.append(unsqueeze_op)

        # ---- 2) Conv2d on the expanded tensors.
        # The 3-D filter gets a unit axis at position 2 and is then permuted
        # with the same (2, 3, 1, 0) order used for regular 2-D filters.
        filter_node = node.args[1]
        filter_tensor = (
            get_parameter(filter_node, self.edge_program).unsqueeze(2).contiguous()
        )
        filter_axis_order = (2, 3, 1, 0)
        filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous()
        filter_tensor_wrapper = self.define_tensor(
            filter_node,
            filter_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=False,
        )
        conv_input_tensors = [unsqueeze_output_tensor_wrapper, filter_tensor_wrapper]
        if node.args[2] is not None:
            bias_node = node.args[2]
            bias_tensor = get_parameter(bias_node, self.edge_program)
            bias_tensor_wrapper = self.define_tensor(
                bias_node,
                bias_tensor,
                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
                nodes_to_wrappers,
                is_input_tensor=False,
            )
            conv_input_tensors.append(bias_tensor_wrapper)

        # Prepend the neutral value for the inserted unit axis so the 1-D
        # hyper-parameters become 2-D ones.
        stride = [1] + cast(List[int], node.args[3])
        padding = [0] + cast(List[int], node.args[4])
        dilation = [1] + cast(List[int], node.args[5])
        groups = cast(int, node.args[8])

        stride_shape = [len(stride)]
        padding_shape = [2, 2]
        dilation_shape = [len(dilation)]

        # Both the conv op and its intermediate output are named
        # "<node>_squeeze"; the final Reshape op carries the node's own name.
        conv_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name + "_squeeze",
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpConv2d.op_name,
        )
        conv_output_tensor = self.get_tensor(node, node)
        conv_output_tensor = conv_output_tensor.unsqueeze(1).contiguous()
        # NOTE(review): the intermediate conv output reuses the *input*
        # quantization encoding/configs — confirm this is intended rather than
        # the node's own output encoding.
        dtype = self.get_data_type(conv_output_tensor, input_quant_configs)
        conv_output_tensor_wrapper = self.define_custom_tensor_wrapper(
            node_name=node.name + "_squeeze",
            tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            dtype=dtype,
            quant_encoding=input_quant_encoding,
            quant_configs=input_quant_configs,
            dims=conv_output_tensor.size(),
            tensor=conv_output_tensor,
            is_fake_tensor=True,
            nodes_to_wrappers=nodes_to_wrappers,
        )
        # `groups` must be passed by keyword: the next positional slot of
        # _add_conv_op_parameter after `dilation_shape` is `output_padding`,
        # so a positional `groups` would be silently discarded.
        conv_op = self._add_conv_op_parameter(
            OpConv2d,
            conv_op,
            conv_input_tensors,
            [conv_output_tensor_wrapper],
            stride,
            stride_shape,
            padding,
            padding_shape,
            dilation,
            dilation_shape,
            groups=groups,
        )
        op_wrapper_list.append(conv_op)

        # ---- 3) Reshape back to the node's original output shape.
        squeeze_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpReshape.op_name,
        )
        squeeze_output_tensor = self.get_tensor(node, node)
        squeeze_output_tensor_wrapper = self.define_tensor(
            node,
            squeeze_output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
            node_name=node.name,
        )
        squeeze_op.AddInputTensors([conv_output_tensor_wrapper])
        squeeze_op.AddOutputTensors([squeeze_output_tensor_wrapper])
        op_wrapper_list.append(squeeze_op)

        return op_wrapper_list

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """
        Build the QNN op(s) for a convolution node.

        A node whose filter parameter is 3-D is treated as Conv1d and
        delegated to ``_define_conv1d`` (which returns a list of op wrappers,
        or ``None`` when unsupported). Otherwise a single Conv2d,
        DepthWiseConv2d or TransposeConv2d op wrapper is returned.
        """
        if get_parameter(node.args[1], self.edge_program).dim() == 3:
            return self._define_conv1d(node, nodes_to_wrappers)

        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        filter_node = node.args[1]
        filter_tensor = get_parameter(filter_node, self.edge_program)
        # weight of pytorch OIHW(conv2d) | IOHW(conv_transpose2d), yet QNN is HWIO
        is_transpose_conv = cast(bool, node.args[6])
        filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0)
        filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous()
        filter_tensor_wrapper = self.define_tensor(
            filter_node,
            filter_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=False,
        )
        conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper]

        if node.args[2] is not None:
            bias_node = node.args[2]
            bias_tensor = get_parameter(bias_node, self.edge_program)
            bias_tensor_wrapper = self.define_tensor(
                bias_node,
                bias_tensor,
                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
                nodes_to_wrappers,
                is_input_tensor=False,
            )
            conv_input_tensors.append(bias_tensor_wrapper)

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )
        conv_output_tensors = [output_tensor_wrapper]

        stride = cast(List[int], node.args[3])
        padding = cast(List[int], node.args[4])
        dilation = cast(List[int], node.args[5])
        output_padding = cast(List[int], node.args[7])

        groups = cast(int, node.args[8])
        # Qnn filter tensor is (H, W, Cin, Cout)
        group_input_channels = filter_tensor.shape[2]
        group_output_channels = filter_tensor.shape[3] // groups
        # 1) groups = input_channels (i.e. group_input_channels = 1)
        # 2) output_channels is a positive integer multiple of input channels
        # TODO: Currently, negative results will be zero with Depthwise conv2d when input_channel == groups == 1
        # and test on QNN 2.14 rc1. Need to carefully investigate.
        is_depthwise_conv = (
            (group_input_channels == 1)
            and (group_output_channels % group_input_channels == 0)
            and (groups > 2)
        )
        # A single padding value is duplicated so that both padded axes are
        # covered by the 2x2 pad-amount table.
        if len(padding) == 1:
            padding = padding + padding

        stride_shape = [len(stride)]
        padding_shape = [2, 2]
        dilation_shape = [len(dilation)]
        output_padding_shape = [len(output_padding)]

        # Depthwise takes precedence over transposed when choosing the op.
        if is_depthwise_conv:
            op_class = OpDepthWiseConv2d
        elif is_transpose_conv:
            op_class = OpTransposeConv2d
        else:
            op_class = OpConv2d

        conv_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            op_class.op_name,
        )
        # The depthwise op is not given a group param, hence None.
        conv_op = self._add_conv_op_parameter(
            op_class,
            conv_op,
            conv_input_tensors,
            conv_output_tensors,
            stride,
            stride_shape,
            padding,
            padding_shape,
            dilation,
            dilation_shape,
            output_padding,
            output_padding_shape,
            is_transpose_conv,
            None if is_depthwise_conv else groups,
        )

        return conv_op
376