xref: /aosp_15_r20/external/executorch/backends/qualcomm/builders/op_avg_pool2d.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import warnings
7from typing import cast, Dict, List
8
9import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
10
11import numpy as np
12import torch
13from executorch.backends.qualcomm.utils.constants import QCOM_DATA
14
15from .node_visitor import NodeVisitor, register_node_visitor
16from .qnn_constants import OpPoolAvg2d, QNN_OP_PACKAGE_NAME_QTI_AISW
17
18
@register_node_visitor
class AvgPool2d(NodeVisitor):
    """Lower ``aten.avg_pool2d.default`` to the QNN ``PoolAvg2d`` op."""

    target = ["aten.avg_pool2d.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def _to_pair(values) -> List[int]:
        """Return ``values`` as a new list, broadcasting a single value to both spatial dims."""
        if isinstance(values, int):
            values = [values]
        values = list(values)
        if len(values) == 1:
            values = values * 2
        return values

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN PoolAvg2d op wrapper for an ``avg_pool2d`` FX node.

        Args follow the ``aten.avg_pool2d`` positional layout:
        ``(input, kernel_size, stride, padding, ceil_mode, count_include_pad,
        divisor_override)``; trailing arguments may be absent.

        Args:
            node: The ``aten.avg_pool2d.default`` FX node to lower.
            nodes_to_wrappers: Cache of already-defined tensor wrappers, passed
                through to ``define_tensor``.

        Returns:
            The configured ``PyQnnOpWrapper``, or ``None`` (after emitting a
            warning) when ``divisor_override`` differs from the pooling region,
            which this builder does not support.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        # kernel info
        filter_size = self._to_pair(cast(List[int], node.args[1]))
        filter_size_shape = [len(filter_size)]

        # stride info - default to kernel_size when the argument is omitted or
        # is the empty list (aten's schema default for stride is `[]`).
        stride = filter_size
        if len(node.args) > 2 and node.args[2]:
            stride = self._to_pair(cast(List[int], node.args[2]))
        stride_shape = [len(stride)]

        padding = [0, 0]
        if len(node.args) > 3:
            padding = self._to_pair(cast(List[int], node.args[3]))
        # pad_amount is laid out as [num_spatial_dims, 2]: one (before, after)
        # pair per spatial dim; torch pads symmetrically, so both entries match.
        padding_shape = [len(padding), 2]

        # if ceil mode is True, use ceil instead of floor to compute the output shape
        mode = OpPoolAvg2d.RoundingMode.FLOOR
        if len(node.args) > 4 and cast(bool, node.args[4]):
            mode = OpPoolAvg2d.RoundingMode.CEIL

        count_include_pad = True
        if len(node.args) > 5:
            count_include_pad = cast(bool, node.args[5])
        # TODO: with count_include_pad=False, QNN appears to exclude padded
        # elements from the average, whereas torch still sums over the padded
        # values and only changes the divisor. Confirm the two agree.

        pooling_region = filter_size[0] * filter_size[1]
        # aten's divisor_override defaults to None, which means "divide by the
        # pooling region"; treat an explicit None the same as an omitted arg.
        divisor_override = pooling_region
        if len(node.args) > 6 and node.args[6] is not None:
            divisor_override = cast(int, node.args[6])
        if divisor_override != pooling_region:
            warnings.warn(
                "[QNN Delegate Op Builder]: Not support divisor_override which is not equal to pooling region.",
                stacklevel=1,
            )
            return

        avg_pool2d_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpPoolAvg2d.op_name,
        )
        avg_pool2d_op.AddInputTensors([input_tensor_wrapper])
        avg_pool2d_op.AddOutputTensors([output_tensor_wrapper])

        avg_pool2d_op.AddTensorParam(
            OpPoolAvg2d.param_filter_size,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(filter_size_shape),
            filter_size_shape,
            np.array(filter_size, dtype=np.uint32),
            True,
        )
        avg_pool2d_op.AddTensorParam(
            OpPoolAvg2d.param_stride,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(stride_shape),
            stride_shape,
            np.array(stride, dtype=np.uint32),
            True,
        )
        avg_pool2d_op.AddTensorParam(
            OpPoolAvg2d.param_pad_amount,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(padding_shape),
            padding_shape,
            np.array(
                [[padding[0], padding[0]], [padding[1], padding[1]]],
                dtype=np.uint32,
            ),
            True,
        )

        avg_pool2d_op.AddScalarParam(
            OpPoolAvg2d.param_rounding_mode,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(mode)},
        )
        avg_pool2d_op.AddScalarParam(
            OpPoolAvg2d.param_count_pad_for_edges,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
            {QCOM_DATA: count_include_pad},
        )

        return avg_pool2d_op
148        return avg_pool2d_op
149