# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import cast, Dict, Optional

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpSoftmax, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Softmax(NodeVisitor):
    """Lower ``aten._softmax`` / ``aten._safe_softmax`` to the QNN Softmax op.

    Only a reduction over the last dimension (after translating the dim
    through ``QCOM_AXIS_ORDER`` when that key is present in ``node.meta``)
    is lowered; any other case is rejected by returning ``None``.
    """

    target = ["aten._softmax.default", "aten._safe_softmax.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> Optional[PyQnnWrapper.PyQnnOpWrapper]:
        """Build the QNN Softmax op wrapper for ``node``.

        Args:
            node: The softmax FX node. ``node.args[0]`` is the input node and
                ``node.args[1]`` is the (possibly negative) reduction dim.
            nodes_to_wrappers: Map of already-defined tensor wrappers, passed
                through to ``define_tensor``.

        Returns:
            The configured ``PyQnnOpWrapper``, or ``None`` (with a warning)
            when the op cannot be expressed: a 0-d input, or a reduction
            axis other than the last one.
        """
        # NOTE(review): aten._safe_softmax yields 0 (not NaN) for rows that
        # are entirely -inf; confirm QNN Softmax matches, or that such rows
        # cannot reach this builder.
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        softmax_inp_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )
        softmax_input_tensors = [softmax_inp_tensor_wrapper]

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )
        softmax_output_tensors = [output_tensor_wrapper]

        rank = input_tensor.dim()
        # A 0-d input has no axis to reduce over; previously a negative dim
        # here crashed with ZeroDivisionError in the modulo below.
        if rank == 0:
            warnings.warn(
                f"[QNN Delegate Op Builder]: {node.name}: softmax on a 0-d "
                "tensor is not supported. Skip the current op.",
                stacklevel=1,
            )
            return None

        dim = cast(int, node.args[1])
        if dim < 0:
            dim = dim % rank
        if QCOM_AXIS_ORDER in node.meta:
            # QCOM_AXIS_ORDER appears to record a permutation of this
            # tensor's dims; .index() finds where the original dim landed.
            dim = node.meta[QCOM_AXIS_ORDER].index(dim)

        # softmax only supports last dimension for now, which is channel in QNN
        if dim != rank - 1:
            warnings.warn(
                f"[QNN Delegate Op Builder]: {node.name}: softmax dim must be "
                "the last dimension of the input tensor. Skip the current op.",
                stacklevel=1,
            )
            return None

        softmax_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpSoftmax.op_name,
        )
        softmax_op.AddInputTensors(softmax_input_tensors)
        softmax_op.AddOutputTensors(softmax_output_tensors)

        softmax_op.AddScalarParam(
            OpSoftmax.param_axis,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(dim)},
        )

        return softmax_op