# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch
from executorch.backends.qualcomm.utils.constants import (
    QCOM_QUANT_ATTRS,
    QCOM_SCALES,
    QCOM_ZERO_POINTS,
)

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpFullyConnected, QNN_OP_PACKAGE_NAME_QTI_AISW
from .utils import get_parameter


@register_node_visitor
class LinearVisitor(NodeVisitor):
    """Op builder that maps ``aten.linear.default`` onto the QNN FullyConnected op."""

    target = ["aten.linear.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the FullyConnected op wrapper for a linear ``node``.

        Input tensors are added in the order activation, weight and, when the
        node carries one, bias.

        Args:
            node: The linear node; ``args[0]`` is the activation, ``args[1]``
                the weight and the optional ``args[2]`` the bias.
            nodes_to_wrappers: Mapping handed to ``define_tensor`` for every
                tensor defined here.

        Returns:
            The op wrapper with its input and output tensors attached.
        """
        linear_input_tensors = []

        # Activation input.
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )
        linear_input_tensors.append(input_tensor_wrapper)

        # Static weight.
        weight_node = node.args[1]
        if (
            quant_attrs := weight_node.meta.get(QCOM_QUANT_ATTRS)
        ) and QCOM_SCALES in quant_attrs:
            # Dimension of weight is [m, n], per channel quant params is [m].
            # Change to [m, 1] to fit the tensor.div(s).add(z).
            # NOTE: this rewrites the entries of the node's quant-attrs dict
            # in place, so the reshaped values stay on weight_node.meta.
            quant_attrs[QCOM_SCALES] = quant_attrs[QCOM_SCALES].reshape([-1, 1])
            quant_attrs[QCOM_ZERO_POINTS] = quant_attrs[QCOM_ZERO_POINTS].reshape(
                [-1, 1]
            )

        weight_tensor = get_parameter(weight_node, self.edge_program)
        weight_tensor_wrapper = self.define_tensor(
            weight_node,
            weight_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=False,
        )
        linear_input_tensors.append(weight_tensor_wrapper)

        # Optional static bias. The bias argument of linear is optional, so it
        # may be absent or passed explicitly as None; in both cases no bias
        # tensor is added instead of dereferencing None below.
        bias_node = node.args[2] if len(node.args) >= 3 else None
        if bias_node is not None:
            # TODO remove this when qnn sdk support
            if QCOM_SCALES in bias_node.meta.get(QCOM_QUANT_ATTRS, {}):
                warnings.warn(
                    f"[QNN Delegate Op Builder]: Fallback linear bias, {bias_node}. per channel bias quantization is not support yet.",
                    stacklevel=1,
                )
            bias_tensor = get_parameter(bias_node, self.edge_program)
            bias_tensor_wrapper = self.define_tensor(
                bias_node,
                bias_tensor,
                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
                nodes_to_wrappers,
                is_input_tensor=False,
            )
            linear_input_tensors.append(bias_tensor_wrapper)

        # Output of the linear node itself.
        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        linear_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpFullyConnected.op_name,
        )
        linear_op.AddInputTensors(linear_input_tensors)
        linear_op.AddOutputTensors([output_tensor_wrapper])

        return linear_op