xref: /aosp_15_r20/external/executorch/backends/xnnpack/_passes/remove_getitem_op.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import torch
8from executorch.exir.dialects._ops import ops as exir_ops
9
10from executorch.exir.pass_base import ExportPass, PassResult
11
12
class RemoveGetItemPass(ExportPass):
    """
    Fold the ``getitem`` that follows a two-output max op into one node.

    This pass looks at ``aten.max_pool2d_with_indices.default`` and
    ``aten.max.dim`` nodes. When such a node has a single ``getitem`` user
    that extracts the first output (the max values), the op/getitem pair is
    replaced with a single-output op:

    Before Pass:
        max_pool2d_with_indices ---> getitem[0]
        max.dim                 ---> getitem[0]
    After Pass:
        max_pool2d
        amax

    A ``max.dim`` node with exactly two users is left untouched; its getitem
    nodes are handled by the op's visitor when serializing.
    """

    # ``__name__`` of the targets this pass rewrites.
    _MAX_POOL2D_WITH_INDICES = "aten.max_pool2d_with_indices.default"
    _MAX_DIM = "aten.max.dim"

    @staticmethod
    def _target_name(node: torch.fx.Node) -> str:
        """Return a printable name for ``node.target``.

        Not every target has ``__name__`` (string targets such as the output
        node's, or callables like ``functools.partial``), so fall back to
        ``str(target)`` instead of raising ``AttributeError``.
        """
        target = node.target
        return getattr(target, "__name__", str(target))

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Rewrite every eligible max node in ``graph_module`` and retrace it.

        Args:
            graph_module: Module whose graph is rewritten in place before
                being retraced by ``ExportPass.call``.

        Returns:
            A ``PassResult`` wrapping the retraced graph module.

        Raises:
            AssertionError: If a matched node has an unsupported number of
                users, if its single user is not a ``getitem`` call, or if
                that ``getitem`` extracts anything but index 0.
        """
        graph = graph_module.graph

        # Walk a snapshot: nodes are erased from the graph inside the loop.
        for node in list(graph.nodes):
            if node.op != "call_function":
                continue

            target_name = self._target_name(node)
            is_max_pool = target_name == self._MAX_POOL2D_WITH_INDICES
            is_max_dim = target_name == self._MAX_DIM
            if not (is_max_pool or is_max_dim):
                continue

            users = list(node.users)
            if len(users) != 1:
                if is_max_dim and len(users) == 2:
                    # Two users is allowed for max.dim. For that case,
                    # rather than removing the getitem node in this
                    # pass, we handle the getitem nodes in the op's
                    # visitor when serializing
                    continue
                raise AssertionError(
                    f"Invalid number of users for {target_name}: {len(users)}"
                )

            getitem_node = users[0]
            user_name = self._target_name(getitem_node)
            # Check the op kind as well as the name so that a user with a
            # string target is reported through the intended AssertionError.
            if getitem_node.op != "call_function" or user_name != "getitem":
                raise AssertionError(
                    f"Expected max node's user to be getitem, got {user_name}"
                )

            getitem_index = getitem_node.args[1]
            if getitem_index != 0:
                if is_max_pool:
                    raise AssertionError(
                        f"Expected second argument of getitem node for {target_name} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values from the op but not getting the corresponding indices."
                    )
                raise AssertionError(
                    f"Expected second argument of getitem node for {target_name} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values or getting both the max values and their corresponding indices from the op, but not getting the indices alone."
                )

            # Single-output op that stands in for the op + getitem[0] pair.
            if is_max_pool:
                replacement_op = exir_ops.edge.aten.max_pool2d.default
            else:
                replacement_op = exir_ops.edge.aten.amax.default

            with graph.inserting_before(node):
                new_max_wd = graph.create_node(
                    "call_function",
                    replacement_op,
                    args=node.args,
                    kwargs=node.kwargs,
                )

            getitem_node.replace_all_uses_with(new_max_wd)

            # Erase the user first so the original op has no remaining users.
            graph.erase_node(getitem_node)
            graph.erase_node(node)

        graph_module.recompile()
        # Propagate metadata and retrace module
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
91