# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import cast, Dict, List

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpPoolMax2d, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class MaxPool2d(NodeVisitor):
    target = ["aten.max_pool2d_with_indices.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        # max_pool2d_with_indices has two outputs; only the pooled values
        # (getitem index 0) are supported by this builder.
        users = list(node.users.keys())
        for user in users:
            if user.target.__name__ == "getitem":
                getitem_index = user.args[1]
                if getitem_index != 0:
                    warnings.warn(
                        f"[QNN Delegate Op Builder]: Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}",
                        stacklevel=1,
                    )
                    return

        output_tensor = self.get_tensor(node, node, 0)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )
        # kernel info
        filter_size = cast(List[int], node.args[1])
        if len(filter_size) == 1:
            filter_size = filter_size + filter_size
        filter_size_shape = [len(filter_size)]

        # stride info
        stride = cast(List[int], node.args[2])
        if len(stride) == 1:
            stride = stride + stride
        stride_shape = [len(stride)]

        # padding info
        padding = [0, 0]
        if len(node.args) > 3:
            padding = cast(List[int], node.args[3])
            if len(padding) == 1:
                padding = padding + padding
        padding_shape = [len(padding), len(padding)]

        # dilation info
        if len(node.args) > 4:
            dilation = cast(List[int], node.args[4])
            if not (dilation == 1 or dilation == [1, 1]):
                warnings.warn(
                    f"[QNN Delegate Op Builder]: Dilation is not supported for max_pool2d, but got {dilation}",
                    stacklevel=1,
                )
                return

        # if ceil_mode is True, use ceil instead of floor to compute the output shape
        mode = OpPoolMax2d.RoundingMode.FLOOR
        if len(node.args) > 5:
            ceil_mode = cast(bool, node.args[5])
            if ceil_mode:
                mode = OpPoolMax2d.RoundingMode.CEIL

        max_pool2d_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpPoolMax2d.op_name,
        )
        max_pool2d_op.AddInputTensors([input_tensor_wrapper])
        max_pool2d_op.AddOutputTensors([output_tensor_wrapper])

        max_pool2d_op.AddTensorParam(
            OpPoolMax2d.param_filter_size,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(filter_size_shape),
            filter_size_shape,
            np.array(
                filter_size,
                dtype=np.uint32,
            ),
            True,
        )
        max_pool2d_op.AddTensorParam(
            OpPoolMax2d.param_stride,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(stride_shape),
            stride_shape,
            np.array(
                stride,
                dtype=np.uint32,
            ),
            True,
        )
        # pad_amount is expressed as (before, after) pairs per spatial dimension
        max_pool2d_op.AddTensorParam(
            OpPoolMax2d.param_pad_amount,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(padding_shape),
            padding_shape,
            np.array(
                [[padding[0], padding[0]], [padding[1], padding[1]]],
                dtype=np.uint32,
            ),
            True,
        )

        max_pool2d_op.AddScalarParam(
            OpPoolMax2d.param_rounding_mode,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(mode)},
        )

        return max_pool2d_op