1# Copyright (c) Qualcomm Innovation Center, Inc. 2# All rights reserved 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import warnings 8from typing import cast, Dict, List 9 10import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper 11 12import numpy as np 13import torch 14from executorch.backends.qualcomm.utils.constants import QCOM_DATA 15 16from .node_visitor import NodeVisitor, register_node_visitor 17from .qnn_constants import ( 18 OpConv2d, 19 OpDepthWiseConv2d, 20 OpExpandDims, 21 OpReshape, 22 OpTransposeConv2d, 23 QNN_OP_PACKAGE_NAME_QTI_AISW, 24) 25from .utils import get_parameter 26 27 28@register_node_visitor 29class Conv2d(NodeVisitor): 30 target = ["aten.convolution.default"] 31 32 def __init__(self, *args) -> None: 33 super().__init__(*args) 34 35 def _add_conv_op_parameter( 36 self, 37 OP, 38 conv_op, 39 conv_input_tensors, 40 conv_output_tensors, 41 stride, 42 stride_shape, 43 padding, 44 padding_shape, 45 dilation, 46 dilation_shape, 47 output_padding=None, 48 output_padding_shape=None, 49 transpose_conv=False, 50 groups=None, 51 ) -> PyQnnWrapper.PyQnnOpWrapper: 52 """ 53 This function is shared among Conv1D, Conv2D, and DepthWise Conv2D as most of the required parameters overlaps. 54 """ 55 conv_op.AddInputTensors(conv_input_tensors) 56 conv_op.AddOutputTensors(conv_output_tensors) 57 conv_op.AddTensorParam( 58 OP.param_stride, 59 PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, 60 len(stride_shape), 61 stride_shape, 62 np.array(stride, dtype=np.uint32), 63 True, 64 ) 65 conv_op.AddTensorParam( 66 OP.param_pad_amount, 67 PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, 68 len(padding_shape), 69 padding_shape, 70 np.array( 71 [[padding[0], padding[0]], [padding[1], padding[1]]], 72 dtype=np.uint32, 73 ), 74 True, 75 ) 76 77 if transpose_conv: 78 conv_op.AddTensorParam( 79 OP.param_output_padding, 80 PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, 81 len(output_padding_shape), 82 output_padding_shape, 83 np.array(output_padding, dtype=np.uint32), 84 True, 85 ) 86 else: 87 conv_op.AddTensorParam( 88 OP.param_dilation, 89 PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, 90 len(dilation_shape), 91 dilation_shape, 92 np.array(dilation, dtype=np.uint32), 93 True, 94 ) 95 96 if groups is not None: 97 conv_op.AddScalarParam( 98 OP.param_group, 99 PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, 100 {QCOM_DATA: np.uint32(groups)}, 101 ) 102 103 return conv_op 104 105 def _define_conv1d( 106 self, 107 node: torch.fx.Node, 108 nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper], 109 ) -> PyQnnWrapper.PyQnnOpWrapper: 110 """ 111 Conv1D is a special case for convolutional operation. QNN does not support Conv1D, therefore, 112 we need to cast from input -> Conv1d -> output to input -> unsqueeze -> Conv2d -> squeeze -> output. 113 """ 114 transpose_conv = cast(bool, node.args[6]) 115 if transpose_conv: 116 print("ConvTranspose1d is not yet supported") 117 return 118 119 op_wrapper_list = [] # op_wrapper to return 120 unsqueeze_input_node = node.args[0] 121 input_quant_encoding, input_quant_configs = self.get_quant_encoding_conf( 122 unsqueeze_input_node, 123 ) 124 125 unsqueeze_input_tensor = self.get_tensor(unsqueeze_input_node, node) 126 unsqueeze_input_tensor_wrapper = self.define_tensor( 127 unsqueeze_input_node, 128 unsqueeze_input_tensor, 129 PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, 130 nodes_to_wrappers, 131 is_input_tensor=True, 132 ) 133 unsqueeze_output_tensor = unsqueeze_input_tensor.unsqueeze(1).contiguous() 134 dtype = self.get_data_type(unsqueeze_output_tensor, input_quant_configs) 135 unsqueeze_output_tensor_wrapper = self.define_custom_tensor_wrapper( 136 node_name=node.name + "_unsqueeze", 137 tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, 138 dtype=dtype, 139 quant_encoding=input_quant_encoding, 140 quant_configs=input_quant_configs, 141 dims=unsqueeze_output_tensor.size(), 142 tensor=unsqueeze_output_tensor, 143 is_fake_tensor=True, 144 nodes_to_wrappers=nodes_to_wrappers, 145 ) 146 unsqueeze_op = PyQnnWrapper.PyQnnOpWrapper( 147 node.name + "_unsqueeze", 148 QNN_OP_PACKAGE_NAME_QTI_AISW, 149 OpExpandDims.op_name, 150 ) 151 unsqueeze_op.AddInputTensors([unsqueeze_input_tensor_wrapper]) 152 unsqueeze_op.AddOutputTensors([unsqueeze_output_tensor_wrapper]) 153 unsqueeze_op.AddScalarParam( 154 OpExpandDims.param_axis, 155 PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, 156 {QCOM_DATA: np.uint32(1)}, 157 ) 158 op_wrapper_list.append(unsqueeze_op) 159 160 filter_node = node.args[1] 161 filter_tensor = ( 162 get_parameter(filter_node, self.edge_program).unsqueeze(2).contiguous() 163 ) 164 filter_axis_order = (2, 3, 1, 0) 165 filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous() 166 filter_tensor_wrapper = self.define_tensor( 167 filter_node, 168 filter_tensor, 169 PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, 170 nodes_to_wrappers, 171 is_input_tensor=False, 172 ) 173 conv_input_tensors = [unsqueeze_output_tensor_wrapper, filter_tensor_wrapper] 174 if node.args[2] is not None: 175 bias_node = node.args[2] 176 bias_tensor = get_parameter(bias_node, self.edge_program) 177 bias_tensor_wrapper = self.define_tensor( 178 bias_node, 179 bias_tensor, 180 PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, 181 nodes_to_wrappers, 182 is_input_tensor=False, 183 ) 184 conv_input_tensors.append(bias_tensor_wrapper) 185 186 stride = [1] + cast(List[int], node.args[3]) 187 padding = [0] + cast(List[int], node.args[4]) 188 dilation = [1] + cast(List[int], node.args[5]) 189 groups = cast(int, node.args[8]) 190 191 # args[6] = transposed 192 if cast(bool, node.args[6]): 193 warnings.warn( 194 "[QNN Delegate Op Builder]: Currently, No support for transposed convolution.", 195 stacklevel=1, 196 ) 197 return 198 199 # args[7] = output padding 200 if not all(out_pad == 0 for out_pad in cast(List[int], node.args[7])): 201 warnings.warn( 202 "[QNN Delegate Op Builder]: QNN does not support output padding.", 203 stacklevel=1, 204 ) 205 return 206 207 stride_shape = [len(stride)] 208 padding_shape = [2, 2] 209 dilation_shape = [len(dilation)] 210 211 conv_op = PyQnnWrapper.PyQnnOpWrapper( 212 node.name + "_squeeze", 213 QNN_OP_PACKAGE_NAME_QTI_AISW, 214 OpConv2d.op_name, 215 ) 216 conv_output_tensor = self.get_tensor(node, node) 217 conv_output_tensor = conv_output_tensor.unsqueeze(1).contiguous() 218 dtype = self.get_data_type(conv_output_tensor, input_quant_configs) 219 conv_output_tensor_wrapper = self.define_custom_tensor_wrapper( 220 node_name=node.name + "_squeeze", 221 tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, 222 dtype=dtype, 223 quant_encoding=input_quant_encoding, 224 quant_configs=input_quant_configs, 225 dims=conv_output_tensor.size(), 226 tensor=conv_output_tensor, 227 is_fake_tensor=True, 228 nodes_to_wrappers=nodes_to_wrappers, 229 ) 230 conv_op = self._add_conv_op_parameter( 231 OpConv2d, 232 conv_op, 233 conv_input_tensors, 234 [conv_output_tensor_wrapper], 235 stride, 236 stride_shape, 237 padding, 238 padding_shape, 239 dilation, 240 dilation_shape, 241 groups, 242 ) 243 op_wrapper_list.append(conv_op) 244 245 squeeze_op = PyQnnWrapper.PyQnnOpWrapper( 246 node.name, 247 QNN_OP_PACKAGE_NAME_QTI_AISW, 248 OpReshape.op_name, 249 ) 250 squeeze_output_tensor = self.get_tensor(node, node) 251 squeeze_output_tensor_wrapper = self.define_tensor( 252 node, 253 squeeze_output_tensor, 254 PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, 255 nodes_to_wrappers, 256 is_input_tensor=False, 257 node_name=node.name, 258 ) 259 squeeze_op.AddInputTensors([conv_output_tensor_wrapper]) 260 squeeze_op.AddOutputTensors([squeeze_output_tensor_wrapper]) 261 op_wrapper_list.append(squeeze_op) 262 263 return op_wrapper_list 264 265 def define_node( 266 self, 267 node: torch.fx.Node, 268 nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper], 269 ) -> PyQnnWrapper.PyQnnOpWrapper: 270 if get_parameter(node.args[1], self.edge_program).dim() == 3: 271 return self._define_conv1d(node, nodes_to_wrappers) 272 273 input_node = node.args[0] 274 input_tensor = self.get_tensor(input_node, node) 275 input_tensor_wrapper = self.define_tensor( 276 input_node, 277 input_tensor, 278 PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, 279 nodes_to_wrappers, 280 is_input_tensor=True, 281 ) 282 283 filter_node = node.args[1] 284 filter_tensor = get_parameter(filter_node, self.edge_program) 285 # weight of pytorch OIHW(conv2d) | IOHW(conv_transpose2d), yet QNN is HWIO 286 is_transpose_conv = cast(bool, node.args[6]) 287 filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0) 288 filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous() 289 filter_tensor_wrapper = self.define_tensor( 290 filter_node, 291 filter_tensor, 292 PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, 293 nodes_to_wrappers, 294 is_input_tensor=False, 295 ) 296 conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper] 297 298 if node.args[2] is not None: 299 bias_node = node.args[2] 300 bias_tensor = get_parameter(bias_node, self.edge_program) 301 bias_tensor_wrapper = self.define_tensor( 302 bias_node, 303 bias_tensor, 304 PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, 305 nodes_to_wrappers, 306 is_input_tensor=False, 307 ) 308 conv_input_tensors.append(bias_tensor_wrapper) 309 310 output_tensor = self.get_tensor(node, node) 311 output_tensor_wrapper = self.define_tensor( 312 node, 313 output_tensor, 314 PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, 315 nodes_to_wrappers, 316 is_input_tensor=False, 317 ) 318 conv_output_tensors = [output_tensor_wrapper] 319 320 stride = cast(List[int], node.args[3]) 321 padding = cast(List[int], node.args[4]) 322 dilation = cast(List[int], node.args[5]) 323 output_padding = cast(List[int], node.args[7]) 324 325 groups = cast(int, node.args[8]) 326 # Qnn filter tensor is (H, W, Cin, Cout) 327 group_input_channels = filter_tensor.shape[2] 328 group_output_channels = int(filter_tensor.shape[3] / groups) 329 # 1) groups = input_channels (i.e. group_input_channels = 1) 330 # 2) output_channels is a positive integer multiple of input channels 331 # TODO: Currently, negative results will be zero with Depthwise conv2d when input_channel == groups == 1 332 # and test on QNN 2.14 rc1. Need to carefully investigate. 333 is_depthwise_conv = ( 334 (group_input_channels == 1) 335 and (group_output_channels % group_input_channels == 0) 336 and (groups > 2) 337 ) 338 if len(padding) == 1: 339 padding = padding + padding 340 341 stride_shape = [len(stride)] 342 padding_shape = [2, 2] 343 dilation_shape = [len(dilation)] 344 output_padding_shape = [len(output_padding)] 345 346 if is_depthwise_conv: 347 op_class = OpDepthWiseConv2d 348 elif is_transpose_conv: 349 op_class = OpTransposeConv2d 350 else: 351 op_class = OpConv2d 352 353 conv_op = PyQnnWrapper.PyQnnOpWrapper( 354 node.name, 355 QNN_OP_PACKAGE_NAME_QTI_AISW, 356 op_class.op_name, 357 ) 358 conv_op = self._add_conv_op_parameter( 359 op_class, 360 conv_op, 361 conv_input_tensors, 362 conv_output_tensors, 363 stride, 364 stride_shape, 365 padding, 366 padding_shape, 367 dilation, 368 dilation_shape, 369 output_padding, 370 output_padding_shape, 371 is_transpose_conv, 372 None if is_depthwise_conv else groups, 373 ) 374 375 return conv_op 376