# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

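"""Node visitors that lower ``aten.amax.default`` and ``aten.max.dim`` onto
XNNPACK max-pooling operators.

Both lowerings use the same observation: a max reduction over the full
height or width of a 4-dimensional tensor is a max pooling whose kernel
spans that entire axis.
"""
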
from typing import cast, Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    get_tensor_value,
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNArgMaxPooling2d,
    XNNGraph,
    XNNMaxPooling2d,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import check_or_raise, get_input_node


@register_node_visitor
class MaxDim(NodeVisitor):
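    """Lower ``aten.amax.default`` over dim 2 or 3 of an NCHW tensor to an
    XNNPACK max-pooling node whose kernel spans the entire reduced axis.

    Illustrative sketch of the equivalence this lowering relies on (eager
    PyTorch, for intuition only, not part of the serialized graph):

        x = torch.randn(1, 3, 8, 6)  # NCHW
        amax = torch.amax(x, dim=2, keepdim=True)
        pool = torch.nn.functional.max_pool2d(x, kernel_size=(8, 1))
        assert torch.equal(amax, pool)
    """
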
    target = "aten.amax.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:

        # Expect args == (input, dim, keepdim); this lowering only supports
        # keepdim == True.
        check_or_raise(
            len(node.args) == 3 and bool(node.args[2]),
            "amax.default only supports keep_dim == True",
        )

        dim_val = cast(int, node.args[1])
        check_or_raise(
            dim_val == 2 or dim_val == 3,
            "amax.default only supports dim == 2 or dim == 3",
        )

        # Serialize the (NHWC-converted) input and output tensors before
        # looking up their ids.
        self.define_nodes_tensor_inputs_outputs(
            node, xnn_graph, vals_to_ids, convert_to_nhwc=True
        )

        input_id = vals_to_ids[get_input_node(node, 0)]
        output_id = vals_to_ids[node]

        input_shape = get_tensor_value(xnn_graph.xvalues[input_id]).dims
        check_or_raise(
            len(input_shape) == 4, "Require input to amax.default be 4 dimensional"
        )

        # input_shape is in NHWC (the tensors were converted above), so
        # H = input_shape[1] and W = input_shape[2]; dim_val still refers to
        # the original NCHW layout (2 == H, 3 == W). Reducing over a full
        # axis is a pooling whose kernel and stride span that axis.
        pooling_height = 1
        pooling_width = 1
        stride_height = 1
        stride_width = 1
        if dim_val == 2:
            pooling_height = input_shape[1]
            pooling_width = 1
            stride_height = input_shape[1]
        elif dim_val == 3:
            pooling_height = 1
            pooling_width = input_shape[2]
            stride_width = input_shape[2]

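        # For example, an NHWC input with dims [1, 8, 6, 3] and dim_val == 2
        # (reduce over H) serializes a max pool with an 8x1 kernel and
        # stride, producing output dims [1, 1, 6, 3].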
        ser_node = XNode(
            xnode_union=XNNMaxPooling2d(
                padding_top=0,
                padding_right=0,
                padding_bottom=0,
                padding_left=0,
                pooling_height=pooling_height,
                pooling_width=pooling_width,
                stride_height=stride_height,
                stride_width=stride_width,
                dilation_height=1,
                dilation_width=1,
                input_id=input_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )

        xnn_graph.xnodes.append(ser_node)


@register_node_visitor
class ArgMaxDim(NodeVisitor):
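    """Lower ``aten.max.dim``, which returns both values and indices, over
    dim 2 or 3 of an NCHW tensor to an XNNPACK argmax-pooling node.

    Illustrative sketch of the operation being serialized (eager PyTorch,
    for intuition only):

        x = torch.randn(1, 3, 8, 6)  # NCHW
        values, indices = torch.max(x, dim=2, keepdim=True)
        # values.shape == indices.shape == (1, 3, 1, 6)
    """
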
    target = "aten.max.dim"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:

        # Expect args == (input, dim, keepdim); this lowering only supports
        # keepdim == True.
        check_or_raise(
            len(node.args) == 3 and bool(node.args[2]),
            "max.dim only supports keep_dim == True",
        )

        dim_val = cast(int, node.args[1])
        check_or_raise(
            dim_val == 2 or dim_val == 3,
            "max.dim only supports dim == 2 or dim == 3",
        )

        # node.meta["val"] is a tuple (values_tensor, indices_tensor).
        # define_nodes_tensor_inputs_outputs expects a single tensor, so we
        # temporarily swap in one element (arbitrarily, the values tensor)
        # just to make it work, and restore the tuple afterwards.
        original_val = node.meta["val"]
        node.meta["val"] = original_val[0]

        self.define_nodes_tensor_inputs_outputs(
            node, xnn_graph, vals_to_ids, convert_to_nhwc=True
        )
        # The getitem users carry the actual (values, indices) outputs, so
        # serialize their tensors as well.
        for user in node.users:
            self.define_nodes_tensor_inputs_outputs(
                user, xnn_graph, vals_to_ids, convert_to_nhwc=True
            )

        # Restore node.meta["val"] to the original tuple.
        node.meta["val"] = original_val

        input_id = vals_to_ids[get_input_node(node, 0)]

        input_shape = get_tensor_value(xnn_graph.xvalues[input_id]).dims
        check_or_raise(
            len(input_shape) == 4, "Require input to max.dim be 4 dimensional"
        )

        users = list(node.users.keys())

        if len(users) != 2:
            raise AssertionError(
                f"Invalid number of users for max.dim (Expected 2, Got: {len(users)})"
            )

        values_node = None
        indices_node = None

        # Match each getitem user to the tuple element it extracts:
        # index 0 is the values output, index 1 is the indices output.
        for getitem_node in users:
            target_name = cast(torch._ops.OpOverload, getitem_node.target).__name__
            if target_name != "getitem":
                raise AssertionError(
                    f"Expected max node's user to be getitem, got: {target_name}"
                )

            if getitem_node.args[1] == 0:
                values_node = getitem_node
            elif getitem_node.args[1] == 1:
                indices_node = getitem_node

        if values_node is None or indices_node is None:
            raise AssertionError(
                f"Expected max node's getitem args to be 0 and 1, got: {[user.args[1] for user in users]}"
            )

        output_index_id = vals_to_ids[indices_node]
        output_value_id = vals_to_ids[values_node]

        # input_shape is in NHWC, so H = input_shape[1] and W = input_shape[2];
        # dim_val refers to the original NCHW layout (2 == H, 3 == W).
        pooling_height = 1
        pooling_width = 1
        if dim_val == 2:
            pooling_height = input_shape[1]
            pooling_width = 1
        elif dim_val == 3:
            pooling_height = 1
            pooling_width = input_shape[2]

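        # For example, an NHWC input with dims [1, 8, 6, 3] and dim_val == 2
        # serializes an argmax pool with an 8x1 kernel; both outputs have
        # dims [1, 1, 6, 3], and each index is the maximum's offset within
        # its pooling window.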
        ser_node = XNode(
            xnode_union=XNNArgMaxPooling2d(
                padding_top=0,
                padding_right=0,
                padding_bottom=0,
                padding_left=0,
                pooling_height=pooling_height,
                pooling_width=pooling_width,
                input_id=input_id,
                output_value_id=output_value_id,
                output_index_id=output_index_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )

        xnn_graph.xnodes.append(ser_node)