# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import cast, Dict, List, Optional

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpPoolAvg2d, QNN_OP_PACKAGE_NAME_QTI_AISW


def _expand_to_2d(value: object, default: Optional[List[int]] = None) -> List[int]:
    """Normalize an aten ``int[2]`` pooling argument to an explicit ``[h, w]`` list.

    Args:
        value: A bare int, or a sequence of ints (typically 0, 1 or 2 entries).
        default: Returned (as a copy) when ``value`` is an empty sequence, which
            is how the aten schema spells an unset ``stride``. If None, an empty
            sequence is returned unchanged.

    Returns:
        A new list. A bare int or a one-element sequence is broadcast to both
        spatial dims; any other sequence is returned element-for-element.
    """
    if isinstance(value, int):
        return [value, value]
    values = list(cast(List[int], value))
    if not values and default is not None:
        return list(default)
    if len(values) == 1:
        return values + values
    return values


@register_node_visitor
class AvgPool2d(NodeVisitor):
    """Lower ``aten.avg_pool2d.default`` to the QNN ``PoolAvg2d`` op."""

    target = ["aten.avg_pool2d.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN ``PoolAvg2d`` op wrapper for an ``avg_pool2d`` node.

        Positional args read from ``node.args``: [0] input, [1] kernel size,
        [2] stride, [3] padding, [4] ceil_mode, [5] count_include_pad,
        [6] divisor_override. Args [2:] are optional.

        Args:
            node: The ``aten.avg_pool2d.default`` FX node to lower.
            nodes_to_wrappers: Cache of already-defined tensor wrappers, passed
                through to ``define_tensor``.

        Returns:
            The configured ``PyQnnOpWrapper``. Returns None (after a warning)
            when a ``divisor_override`` different from the kernel area is
            requested, since the QNN op built here has no divisor parameter.
        """
        args = node.args

        input_node = args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        # Kernel size; an int or one-element list applies to both H and W.
        filter_size = _expand_to_2d(args[1])
        filter_size_shape = [len(filter_size)]

        # Stride defaults to the kernel size both when omitted and when given as
        # the schema default ``[]``. Export fills that default in positionally
        # whenever a later argument such as padding is present.
        stride = list(filter_size)
        if len(args) > 2:
            stride = _expand_to_2d(args[2], default=filter_size)
        stride_shape = [len(stride)]

        padding = [0, 0]
        if len(args) > 3:
            padding = _expand_to_2d(args[3], default=[0, 0])
        # One row per spatial dim, each (pad_before, pad_after). torch pads
        # symmetrically, so both entries of a row are equal.
        pad_amount = np.array([[pad, pad] for pad in padding], dtype=np.uint32)
        padding_shape = list(pad_amount.shape)

        # ceil_mode=True: use ceil instead of floor to compute the output shape.
        mode = OpPoolAvg2d.RoundingMode.FLOOR
        if len(args) > 4 and args[4]:
            mode = OpPoolAvg2d.RoundingMode.CEIL

        count_include_pad = True
        if len(args) > 5:
            count_include_pad = cast(bool, args[5])
        # TODO(review): the original author observed that torch, with
        # count_include_pad=False, still pads but shrinks the divisor at the
        # edges, and was unsure whether QNN's count_pad_for_edges=False matches
        # exactly. Verify numerically against torch on padded inputs.

        # The QNN op below has no divisor parameter. Only an absent/None
        # override, or one equal to the kernel area, can be represented.
        pooling_region = filter_size[0] * filter_size[1]
        divisor_override = args[6] if len(args) > 6 else None
        if divisor_override is not None:
            if divisor_override != pooling_region:
                warnings.warn(
                    "[QNN Delegate Op Builder]: Not support divisor_override which is not equal to pooling region.",
                    stacklevel=1,
                )
                return
            # In torch an explicit divisor_override takes precedence over
            # count_include_pad: every window is divided by the kernel area,
            # including windows that overlap padding.
            count_include_pad = True

        avg_pool2d_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpPoolAvg2d.op_name,
        )
        avg_pool2d_op.AddInputTensors([input_tensor_wrapper])
        avg_pool2d_op.AddOutputTensors([output_tensor_wrapper])

        avg_pool2d_op.AddTensorParam(
            OpPoolAvg2d.param_filter_size,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(filter_size_shape),
            filter_size_shape,
            np.array(filter_size, dtype=np.uint32),
            True,
        )
        avg_pool2d_op.AddTensorParam(
            OpPoolAvg2d.param_stride,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(stride_shape),
            stride_shape,
            np.array(stride, dtype=np.uint32),
            True,
        )
        avg_pool2d_op.AddTensorParam(
            OpPoolAvg2d.param_pad_amount,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(padding_shape),
            padding_shape,
            pad_amount,
            True,
        )

        avg_pool2d_op.AddScalarParam(
            OpPoolAvg2d.param_rounding_mode,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(mode)},
        )
        avg_pool2d_op.AddScalarParam(
            OpPoolAvg2d.param_count_pad_for_edges,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
            {QCOM_DATA: count_include_pad},
        )

        return avg_pool2d_op