# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpResizeNearestNeighbor, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class ResizeNearestNeighbor(NodeVisitor):
    """Lower ``aten.upsample_nearest2d.default`` to QNN ``ResizeNearestNeighbor``.

    Only the input activation and the output tensor are attached to the QNN
    op; no explicit scale factor or output-size parameter is passed.
    NOTE(review): this presumably relies on QNN deriving the resize ratio from
    the output tensor's dimensions -- confirm against the QNN op definition.
    """

    target = ["aten.upsample_nearest2d.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN op wrapper for one ``upsample_nearest2d`` node.

        Args:
            node: The FX node being lowered; ``node.args[0]`` is its input.
            nodes_to_wrappers: Mapping of already-defined tensor wrappers,
                passed through to ``define_tensor``.

        Returns:
            The ``PyQnnOpWrapper`` describing the nearest-neighbor resize.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        resize_nearest_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpResizeNearestNeighbor.op_name,
        )
        resize_nearest_op.AddInputTensors([input_tensor_wrapper])
        resize_nearest_op.AddOutputTensors([output_tensor_wrapper])

        # aten.upsample_nearest2d takes no align_corners argument, so the
        # corner-aligned coordinate mapping never applies here.
        resize_nearest_op.AddScalarParam(
            OpResizeNearestNeighbor.param_align_corners,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
            {QCOM_DATA: False},
        )
        # aten.upsample_nearest2d is PyTorch's legacy "nearest" mode: the
        # source index is floor(dst * in_size / out_size), with no half-pixel
        # offset. The (dst + 0.5) * in_size / out_size mapping belongs to
        # PyTorch's "nearest-exact" mode instead. The two agree for integer
        # upscale factors but differ for downscaling and fractional ratios,
        # so half-pixel centers must stay disabled to match eager PyTorch.
        # NOTE(review): confirm against the QNN ResizeNearestNeighbor
        # definition of half_pixel_centers.
        resize_nearest_op.AddScalarParam(
            OpResizeNearestNeighbor.param_half_pixel_centers,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
            {QCOM_DATA: False},
        )

        return resize_nearest_op


# Backward-compatible alias for the previous (misleading) class name.
ResizeBilinear = ResizeNearestNeighbor