# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpElementWiseMultiply, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Mul(NodeVisitor):
    """Node visitor that builds a QNN op wrapper for ``aten.mul.Tensor``."""

    # Graph targets this visitor is registered for.
    target = ["aten.mul.Tensor"]

    def __init__(self, *args) -> None:
        """Forward all positional arguments unchanged to ``NodeVisitor``."""
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the element-wise multiply op wrapper for ``node``.

        Args:
            node: The FX node being lowered; its first two ``args`` are used
                as the operands of the multiplication.
            nodes_to_wrappers: Mapping from FX nodes to tensor wrappers,
                passed through to ``define_tensor`` for every tensor defined.

        Returns:
            An op wrapper named after ``node`` with two input tensors and
            one output tensor attached.
        """
        native = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE

        # The output tensor is defined before the inputs, as the wrappers
        # are registered through the shared ``nodes_to_wrappers`` mapping.
        result_wrapper = self.define_tensor(
            node,
            self.get_tensor(node, node),
            native,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        # Operands are taken positionally (args[0], then args[1]).
        operand_wrappers = [
            self.define_tensor(
                node.args[position],
                self.get_tensor(node.args[position], node),
                native,
                nodes_to_wrappers,
                is_input_tensor=True,
            )
            for position in (0, 1)
        ]

        op_wrapper = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpElementWiseMultiply.op_name,
        )
        op_wrapper.AddInputTensors(operand_wrappers)
        op_wrapper.AddOutputTensors([result_wrapper])

        return op_wrapper