# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from typing import Any, Dict, Optional, Tuple

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import (
    QCOM_AXIS,
    QCOM_AXIS_ORDER,
    QCOM_BITWIDTH,
    QCOM_DTYPE,
    QCOM_ENCODING,
    QCOM_OFFSET,
    QCOM_QUANT_ATTRS,
    QCOM_QUANT_MAX,
    QCOM_QUANT_MIN,
    QCOM_REQUANTIZE,
    QCOM_SCALE,
    QCOM_SCALE_OFFSET,
    QCOM_SCALES,
    QCOM_ZERO_POINT,
    QCOM_ZERO_POINTS,
)

from executorch.exir.dialects._ops import ops as exir_ops

from .utils import (
    deduce_dtype,
    get_parameter,
    is_graph_input,
    is_graph_output,
    is_parameter,
)


# torch dtype -> QNN fixed-point data type, used for tensors that carry a
# quantization config.
QNN_QUANT_TYPE_MAP = {
    torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8,
    torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_16,
    torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_32,
    # Note that there is no int64 tensor data type in Qnn.
    torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UNDEFINED,
    torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
}
# torch dtype -> QNN data type, used for tensors without a quantization config.
QNN_TENSOR_TYPE_MAP = {
    torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
    torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
    torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8,
    torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16,
    torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
    torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64,
    torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8,
    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
    float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
}

# Quantize/dequantize ops whose encoding is per channel.
PER_CHANNEL_ENCODING = {
    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
}

# Quantize/dequantize ops whose encoding is per tensor.
PER_TENSOR_ENCODING = {
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
}


def _is_4bit_quant(quant_config: Dict) -> bool:
    """Return True when an int8 config spans at most 16 quantization levels."""
    return (
        quant_config[QCOM_DTYPE] == torch.int8
        and quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 15
    )


class NodeVisitor:
    """
    Node visitor pattern for visiting nodes in an edge IR graph
    """

    def __init__(
        self,
        external_ids,
        edge_program: torch.export.ExportedProgram,
        enable_tensor_dump,
    ) -> None:
        """
        Args:
            external_ids: mapping from graph input/output node to its external
                id; a falsy value is replaced by an empty dict
            edge_program: exported program the visited nodes belong to
            enable_tensor_dump: when truthy, native tensors are turned into
                app-read tensors (see get_tensor_type)
        """
        self.external_ids = external_ids or {}
        self.edge_program = edge_program
        self.enable_tensor_dump = enable_tensor_dump

    def get_tensor(self, input_node, op_node, idx=None):
        """
        Get tensor value/shape with axis_order

        Args:
            input_node: node whose tensor is fetched; the parameter value is
                used when the node is a parameter, node.meta["val"] otherwise
            op_node: node whose meta may carry QCOM_AXIS_ORDER
            idx: optional int index applied to the fetched value

        Returns:
            The tensor, permuted by op_node.meta[QCOM_AXIS_ORDER] and made
            contiguous when that key exists and the tensor is not 0-d.
        """

        def _get_tensor(node, index):
            if index is not None:
                assert isinstance(index, int)
                if is_parameter(node, self.edge_program):
                    return get_parameter(node, self.edge_program)[index]
                return node.meta["val"][index]

            if is_parameter(node, self.edge_program):
                return get_parameter(node, self.edge_program)
            return node.meta["val"]

        tensor = _get_tensor(input_node, idx)
        if len(tensor.shape) != 0 and QCOM_AXIS_ORDER in op_node.meta:
            tensor = tensor.permute(dims=op_node.meta[QCOM_AXIS_ORDER]).contiguous()
        return tensor

    def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
        """
        Build the QNN per-channel (axis scale/offset) quantization config.

        Args:
            node: node the quantization attributes belong to
            quant_attrs: attributes providing QCOM_SCALES, QCOM_ZERO_POINTS,
                QCOM_AXIS, QCOM_DTYPE, QCOM_QUANT_MIN and QCOM_QUANT_MAX

        Returns:
            Tuple of (QNN quantization encoding, config dict). The config is a
            deep copy of quant_attrs extended with QCOM_AXIS,
            QCOM_SCALE_OFFSET and, in the 4-bit case, QCOM_BITWIDTH.
        """
        quant_config = copy.deepcopy(quant_attrs)

        scales = quant_attrs[QCOM_SCALES]
        zero_points = quant_attrs[QCOM_ZERO_POINTS]
        assert len(scales) == len(
            zero_points
        ), f"Per channel encoding of node {node}, has different size for scales {len(scales)} and zero_points {len(zero_points)}"

        # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
        # (the offset handed to QNN is the negated zero point)
        scale_offset = [
            PyQnnWrapper.Qnn_ScaleOffset_t(scale, -zero_point)
            for scale, zero_point in zip(scales, zero_points)
        ]

        # A node without users, or a user whose target has no __name__
        # (e.g. a string target), is simply not a convolution weight; fall
        # back to the recorded axis instead of raising.
        first_user = next(iter(node.users), None)
        is_conv_weight = (
            first_user is not None
            and "convolution" in getattr(first_user.target, "__name__", "")
            and len(first_user.args) > 1
            and first_user.args[1] == node
        )
        # Memory layout of QNN conv weight always ends in Output. Like conv2d is HWIO
        quant_config[QCOM_AXIS] = 3 if is_conv_weight else quant_attrs[QCOM_AXIS]

        quant_config[QCOM_SCALE_OFFSET] = scale_offset
        # special case for 4 bits
        if _is_4bit_quant(quant_config):
            quant_config[QCOM_BITWIDTH] = 4
            return (
                PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET,
                quant_config,
            )
        return (
            PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET,
            quant_config,
        )

    def make_qnn_per_tensor_config(self, quant_attrs: Dict):
        """
        Build the QNN per-tensor (scale/offset) quantization config.

        Args:
            quant_attrs: attributes providing QCOM_ZERO_POINT, QCOM_DTYPE,
                QCOM_QUANT_MIN and QCOM_QUANT_MAX

        Returns:
            Tuple of (QNN quantization encoding, config dict). The config is a
            deep copy of quant_attrs extended with QCOM_OFFSET and, in the
            4-bit case, QCOM_BITWIDTH.
        """
        quant_config = copy.deepcopy(quant_attrs)
        # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
        quant_config[QCOM_OFFSET] = -quant_attrs[QCOM_ZERO_POINT]
        # special case for 4 bits
        if _is_4bit_quant(quant_config):
            quant_config[QCOM_BITWIDTH] = 4
            return (
                PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET,
                quant_config,
            )
        return (
            PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET,
            quant_config,
        )

    def get_quant_encoding_conf(
        self, node: torch.fx.Node, is_input_tensor: bool = False
    ) -> Tuple[Any, Dict]:
        """
        Return the QNN quantization encoding and config for a node.

        Args:
            node: node whose meta is inspected
            is_input_tensor: when True and node.meta has QCOM_REQUANTIZE, those
                attributes are used instead of node.meta[QCOM_QUANT_ATTRS]

        Returns:
            (QNN_QUANTIZATION_ENCODING_UNDEFINED, {}) when the node has no
            (or empty) QCOM_QUANT_ATTRS, otherwise the per-channel or
            per-tensor config selected by the attributes' QCOM_ENCODING.
        """
        if not node.meta.get(QCOM_QUANT_ATTRS, None):
            return (
                PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
                {},
            )
        quant_attrs = (
            node.meta[QCOM_REQUANTIZE]
            if QCOM_REQUANTIZE in node.meta and is_input_tensor
            else node.meta[QCOM_QUANT_ATTRS]
        )
        if quant_attrs[QCOM_ENCODING] in PER_CHANNEL_ENCODING:
            return self.make_qnn_per_channel_config(node, quant_attrs)

        return self.make_qnn_per_tensor_config(quant_attrs)

    def get_quant_tensor_value(
        self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict
    ) -> torch.Tensor:
        """
        Quantize a tensor value as round(tensor / scale + zero_point).

        Args:
            tensor: value to quantize
            quant_attrs: provides QCOM_ENCODING and the scale(s)/zero point(s)
            quant_configs: provides the target QCOM_DTYPE and, optionally,
                QCOM_BITWIDTH

        Returns:
            The quantized tensor cast to quant_configs[QCOM_DTYPE]; when the
            bitwidth is 4 only the low 4 bits of every element are kept.
        """
        if quant_attrs[QCOM_ENCODING] in PER_TENSOR_ENCODING:
            scale = quant_attrs[QCOM_SCALE]
            zero_point = quant_attrs[QCOM_ZERO_POINT]
        else:  # per channel case
            scale = quant_attrs[QCOM_SCALES]
            zero_point = quant_attrs[QCOM_ZERO_POINTS]

        dtype = quant_configs[QCOM_DTYPE]

        tensor = tensor.div(scale).add(zero_point).round().to(dtype)
        # Make the backends access data correctly
        if quant_configs.get(QCOM_BITWIDTH) == 4:
            mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
            tensor = torch.bitwise_and(mask, tensor)
        return tensor

    def get_tensor_type(
        self,
        node: torch.fx.Node,
        tensor_type: PyQnnWrapper.Qnn_TensorType_t,
    ) -> PyQnnWrapper.Qnn_TensorType_t:
        """
        Resolve the QNN tensor type for a node.

        Graph inputs become APP_WRITE, graph outputs APP_READ, parameters
        STATIC. With tensor dump enabled, NATIVE tensors become APP_READ.
        Otherwise the requested tensor_type is returned unchanged.

        Args:
            node: node the tensor belongs to
            tensor_type: tensor type requested by the caller
        """
        is_input = is_graph_input(node, self.edge_program)
        is_output = is_graph_output(node)
        # handle logic for input/output tensors
        if is_input or is_output:
            assert (
                node in self.external_ids
            ), f"Node {node}, is_input: {is_input}, is_output: {is_output}, ext_ids: {self.external_ids.keys()}"
            if is_input:
                return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_WRITE
            if is_output:
                return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ

        if is_parameter(node, self.edge_program):
            return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
        # dump all tensor, set to app read, and we only dump native tensors
        if (
            self.enable_tensor_dump
            and tensor_type == PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
        ):
            return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ
        return tensor_type

    def get_data_type(
        self,
        tensor: torch.Tensor,
        quant_config: Dict,
    ) -> PyQnnWrapper.Qnn_DataType_t:
        """
        Map a tensor to its QNN data type.

        Args:
            tensor: tensor whose dtype is inspected
            quant_config: quantization config; when non-empty its QCOM_DTYPE
                entry is overwritten in place with the deduced dtype

        Returns:
            The QNN_QUANT_TYPE_MAP entry for the deduced dtype when a quant
            config is given, otherwise the QNN_TENSOR_TYPE_MAP entry for
            tensor.dtype.
        """
        if quant_config:
            quant_config[QCOM_DTYPE] = deduce_dtype(tensor, quant_config)
            return QNN_QUANT_TYPE_MAP[quant_config[QCOM_DTYPE]]

        return QNN_TENSOR_TYPE_MAP[tensor.dtype]

    def define_custom_tensor_wrapper(
        self,
        node_name: str,
        tensor_type: PyQnnWrapper.Qnn_TensorType_t,
        dtype: PyQnnWrapper.Qnn_DataType_t,
        quant_encoding: PyQnnWrapper.Qnn_QuantizationEncoding_t,
        quant_configs: dict,
        dims: torch.Size,
        tensor: torch.Tensor,
        is_fake_tensor: bool,
        nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
        wrapper_idx: int = 0,
    ) -> Optional[PyQnnWrapper.TensorWrapper]:
        """
        Create (or fetch from cache) a TensorWrapper from explicit attributes.

        Only fake tensors are supported: the wrapper is built without data.
        For a non-fake tensor None is returned and nothing is cached.

        Args:
            node_name: wrapper name and first-level key of nodes_to_wrappers
            tensor_type: QNN tensor type
            dtype: QNN data type
            quant_encoding: QNN quantization encoding
            quant_configs: quantization config passed to the wrapper
            dims: tensor dimensions
            tensor: not read by this method
            is_fake_tensor: whether the wrapper carries no data
            nodes_to_wrappers: cache of node name -> {wrapper_idx: wrapper}
            wrapper_idx: second-level key of nodes_to_wrappers
        """
        if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
            return cached
        if is_fake_tensor:
            tensor_wrapper = PyQnnWrapper.TensorWrapper(
                node_name,
                tensor_type,
                dtype,
                quant_encoding,
                quant_configs,
                len(dims),
                dims,
                np.array([]),
                False,
            )
        else:
            # Can implement non-fake tensor when there is a need
            return None
        nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
        return tensor_wrapper

    def define_tensor(
        self,
        node: torch.fx.Node,
        tensor: torch.Tensor,
        tensor_type: PyQnnWrapper.Qnn_TensorType_t,
        nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
        is_input_tensor: bool,
        node_name: Optional[str] = None,
        wrapper_idx: int = 0,
    ) -> PyQnnWrapper.TensorWrapper:
        """
        Convert torch.Tensor to TensorWrapper

        Args:
            node: EdgeIR Node
            tensor: EdgeIR Tensor
            tensor_type: QNN tensor type
            nodes_to_wrappers: Cache of node name -> {wrapper_idx: wrapper};
                it is indexed by node name without a membership check
            is_input_tensor: Whether tensor is a fake input tensor relatively to
                             the op builder that is calling this function
            node_name: Cache key to use; defaults to node.name
            wrapper_idx: Index of the wrapper within the node's cache entry
        """
        if node_name is None:
            node_name = node.name

        if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
            return cached

        tensor_name = f"{node.name}_{wrapper_idx}"
        if is_graph_input(node, self.edge_program):
            tensor_name = "input_" + str(self.external_ids[node]) + "_" + tensor_name
        if is_graph_output(node):
            tensor_name = "output_" + tensor_name
        # 0-d tensors are described to QNN with a single dimension of size 1
        dims = [1] if len(tensor.size()) == 0 else tensor.size()
        tensor_type = self.get_tensor_type(node, tensor_type)
        quant_encoding, quant_configs = self.get_quant_encoding_conf(
            node, is_input_tensor
        )
        dtype = self.get_data_type(tensor, quant_configs)
        if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
            # fake tensors carry no data: pass an empty array
            tensor_wrapper = PyQnnWrapper.TensorWrapper(
                tensor_name,
                tensor_type,
                dtype,
                quant_encoding,
                quant_configs,
                len(dims),
                dims,
                np.array([]),
                False,
            )
        else:
            if quant_configs:
                tensor = self.get_quant_tensor_value(
                    tensor,
                    node.meta[QCOM_QUANT_ATTRS],
                    quant_configs,
                )
            tensor_wrapper = PyQnnWrapper.TensorWrapper(
                tensor_name,
                tensor_type,
                dtype,
                quant_encoding,
                quant_configs,
                len(dims),
                dims,
                tensor.detach().numpy(),
                True,
            )
        nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
        return tensor_wrapper

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Convert torch.fx.Node to OpWrapper"""
        raise NotImplementedError("NodeVisitor must be extended!")


# This will hold mapping of all node names to the visitor class
_node_visitor_dict = {}


def register_node_visitor(visitor):
    """Register node visitor into _node_visitor_dict

    The visitor must be a NodeVisitor subclass with a ``target`` attribute;
    it is registered under every entry of ``visitor.target``.
    """
    assert (
        isinstance(visitor, type)
        and issubclass(visitor, NodeVisitor)
        and hasattr(visitor, "target")
    ), f"Illformed NodeVisitor subclass, can't register!, got: {visitor}"
    for target in visitor.target:
        _node_visitor_dict[target] = visitor


def generate_node_to_external_map(
    edge_program: torch.export.ExportedProgram,
) -> Dict[torch.fx.Node, int]:
    """Assign consecutive external ids to graph inputs, then graph outputs."""
    node_to_external_map = {}
    for node in edge_program.graph_module.graph.nodes:
        # The order in which we visit the placeholder node is same as the *args
        # order for the forward(*args) signature for this gm. Using the order of
        # the nodes as external_id to extract the right arg from *args at runtime
        if is_graph_input(node, edge_program):
            node_to_external_map[node] = len(node_to_external_map)
    # outputs are numbered after all inputs
    for node in edge_program.graph_module.graph.nodes:
        if is_graph_output(node):
            node_to_external_map[node] = len(node_to_external_map)
    return node_to_external_map


def get_node_visitors(
    edge_program: torch.export.ExportedProgram,
    enable_tensor_dump=False,
) -> Dict[str, NodeVisitor]:
    """Create a new class instance at runtime, and put them in a dict"""
    node_to_external_map = generate_node_to_external_map(edge_program)
    node_visitors = {}
    for target, visitor in _node_visitor_dict.items():
        assert callable(
            visitor
        ), f"Expeting a callable class, but got {visitor} of type {type(visitor)}"
        # every visitor instance shares the same external-id map
        node_visitors[target] = visitor(
            node_to_external_map, edge_program, enable_tensor_dump
        )
    return node_visitors