# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import cast, Dict, Optional

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpLogSoftmax, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class LogSoftmax(NodeVisitor):
    """Lower ``aten._log_softmax.default`` to the QNN LogSoftmax op.

    Only a reduction over the last dimension of the input tensor (after any
    axis remapping recorded under ``QCOM_AXIS_ORDER``) is lowered; every other
    case is reported as unsupported by returning ``None`` from
    :meth:`define_node`.
    """

    target = ["aten._log_softmax.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> Optional[PyQnnWrapper.PyQnnOpWrapper]:
        """Build the QNN op wrapper for a log-softmax ``node``.

        Args:
            node: The FX node being lowered. ``node.args[0]`` is the input
                node and ``node.args[1]`` is the reduction dimension.
            nodes_to_wrappers: Mapping of FX nodes to tensor wrappers that is
                passed through to ``define_tensor``.

        Returns:
            The configured ``PyQnnOpWrapper``, or ``None`` when the op cannot
            be lowered: the input is 0-d, or the (remapped) reduction
            dimension is not the last one.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)

        log_softmax_inp_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )
        log_softmax_input_tensors = [log_softmax_inp_tensor_wrapper]
        output_tensor = self.get_tensor(node, node)

        log_softmax_output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )
        log_softmax_output_tensors = [log_softmax_output_tensor_wrapper]

        rank = input_tensor.dim()
        # A 0-d input has no axis to reduce over, and normalizing a negative
        # dim with `% rank` below would divide by zero; report it unsupported.
        if rank == 0:
            return None

        dim = cast(int, node.args[1])
        if dim < 0:
            dim = dim % rank

        # When a layout pass has permuted this node's axes, translate the
        # PyTorch dim into its position in the permuted order.
        # NOTE(review): assumes QCOM_AXIS_ORDER holds that permutation — confirm.
        if QCOM_AXIS_ORDER in node.meta:
            dim = node.meta[QCOM_AXIS_ORDER].index(dim)

        # logsoftmax only supports last dimension for now, which is channel in QNN
        if dim != rank - 1:
            return None

        log_softmax_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpLogSoftmax.op_name,
        )
        log_softmax_op.AddInputTensors(log_softmax_input_tensors)
        log_softmax_op.AddOutputTensors(log_softmax_output_tensors)

        # The axis is passed to QNN as an unsigned 32-bit scalar, so it must
        # already be non-negative at this point.
        log_softmax_op.AddScalarParam(
            OpLogSoftmax.param_axis,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(dim)},
        )
        return log_softmax_op