xref: /aosp_15_r20/external/executorch/backends/qualcomm/builders/op_max_pool2d.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import warnings
7from typing import cast, Dict, List
8
9import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
10
11import numpy as np
12import torch
13from executorch.backends.qualcomm.utils.constants import QCOM_DATA
14
15from .node_visitor import NodeVisitor, register_node_visitor
16from .qnn_constants import OpPoolMax2d, QNN_OP_PACKAGE_NAME_QTI_AISW
17
18
@register_node_visitor
class MaxPool2d(NodeVisitor):
    """Op builder that lowers ``aten.max_pool2d_with_indices`` to a QNN max-pool op.

    Only the first output of the ATen op (selected via ``getitem(..., 0)``)
    is lowered. If any ``getitem`` user selects another index, or a
    dilation other than 1 is requested, ``define_node`` emits a warning and
    returns ``None`` instead of an op wrapper.
    """

    target = ["aten.max_pool2d_with_indices.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def _as_pair(value) -> List[int]:
        """Normalise an int or a sequence into a list, broadcasting singletons.

        A bare int ``v`` and a one-element sequence ``[v]`` both become
        ``[v, v]``. Sequences of any other length (including empty ones) are
        returned as a plain list with their elements unchanged.
        """
        if isinstance(value, int):
            return [value, value]
        pair = list(value)
        if len(pair) == 1:
            pair = pair * 2
        return pair

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN op wrapper for a max-pool node.

        Args:
            node: The ``aten.max_pool2d_with_indices`` node. Positional args
                are read as (input, kernel_size, stride, padding, dilation,
                ceil_mode); everything after kernel_size is optional.
            nodes_to_wrappers: Mapping handed to ``define_tensor`` for the
                input and output tensors.

        Returns:
            The populated op wrapper, or ``None`` (after a warning) when the
            node uses an unsupported feature: a ``getitem`` user with a
            non-zero index or a dilation different from 1.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        # Only output 0 is produced by this builder, so bail out if any
        # consumer asks for another element of the result tuple.
        for user in node.users:
            # A node target may be a plain string (no __name__), hence getattr.
            if getattr(user.target, "__name__", None) != "getitem":
                continue
            getitem_index = user.args[1]
            if getitem_index != 0:
                warnings.warn(
                    f"[QNN Delegate Op Builder]: Expected second argument of getitem node for {node.target.__name__ } to be 0, got {getitem_index}",
                    stacklevel=1,
                )
                return

        output_tensor = self.get_tensor(node, node, 0)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        # kernel info
        filter_size = self._as_pair(node.args[1])
        filter_size_shape = [len(filter_size)]

        # stride info: the ATen schema defaults stride to an empty list,
        # which means "same as kernel_size"; the argument may also be absent.
        raw_stride = node.args[2] if len(node.args) > 2 else []
        stride = self._as_pair(raw_stride) or list(filter_size)
        stride_shape = [len(stride)]

        # padding info
        padding = [0, 0]
        if len(node.args) > 3:
            padding = self._as_pair(node.args[3])
        padding_shape = [len(padding), len(padding)]

        # dilation info: only the no-op dilation of 1 is supported. Accept
        # every spelling of it (1, [1], [1, 1]) like the other arguments.
        if len(node.args) > 4:
            dilation = cast(List[int], node.args[4])
            if self._as_pair(dilation) != [1, 1]:
                warnings.warn(
                    f"[QNN Delegate Op Builder]: Not support dilation argument for max pool2d, but got {dilation}",
                    stacklevel=1,
                )
                return

        # if ceil_mode is True, use ceil instead of floor to compute the output shape
        mode = OpPoolMax2d.RoundingMode.FLOOR
        if len(node.args) > 5 and cast(bool, node.args[5]):
            mode = OpPoolMax2d.RoundingMode.CEIL

        max_pool2d_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpPoolMax2d.op_name,
        )
        max_pool2d_op.AddInputTensors([input_tensor_wrapper])
        max_pool2d_op.AddOutputTensors([output_tensor_wrapper])

        max_pool2d_op.AddTensorParam(
            OpPoolMax2d.param_filter_size,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(filter_size_shape),
            filter_size_shape,
            np.array(filter_size, dtype=np.uint32),
            True,
        )
        max_pool2d_op.AddTensorParam(
            OpPoolMax2d.param_stride,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(stride_shape),
            stride_shape,
            np.array(stride, dtype=np.uint32),
            True,
        )
        # Each padding value is duplicated into a two-element row, giving a
        # 2x2 pad_amount tensor (presumably (before, after) per spatial
        # dimension -- confirm against the QNN op definition).
        max_pool2d_op.AddTensorParam(
            OpPoolMax2d.param_pad_amount,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(padding_shape),
            padding_shape,
            np.array(
                [[padding[0], padding[0]], [padding[1], padding[1]]],
                dtype=np.uint32,
            ),
            True,
        )

        max_pool2d_op.AddScalarParam(
            OpPoolMax2d.param_rounding_mode,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(mode)},
        )

        return max_pool2d_op
145