# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult


class RemoveGetItemPass(ExportPass):
    """
    Fold a multi-output max op and its single ``getitem`` user into one
    single-output op.

    Two ops are handled, matched by the ``__name__`` of the node's target:

    * ``aten.max_pool2d_with_indices.default`` followed by ``getitem(..., 0)``
      is replaced with ``aten.max_pool2d.default``.
    * ``aten.max.dim`` followed by ``getitem(..., 0)`` is replaced with
      ``aten.amax.default``. A ``max.dim`` node with exactly two users is left
      untouched; its getitem nodes are handled by the op's visitor when
      serializing.

    In both cases the replacement node receives the original node's ``args``
    and ``kwargs`` unchanged.

    Before Pass:
        MaxPool2d ---> GetItem[max_values, max_indexes]
    After Pass:
        MaxPool2d -> max_values
    """

    _MAX_POOL2D_WITH_INDICES = "aten.max_pool2d_with_indices.default"
    _MAX_DIM = "aten.max.dim"

    @staticmethod
    def _target_name(node):
        """Return ``node.target.__name__``, or ``str(node.target)`` if absent.

        Targets of non-``call_function`` nodes (for example the ``output``
        node) are plain strings with no ``__name__``; falling back to the
        string form lets callers report them instead of crashing with an
        ``AttributeError``.
        """
        target = node.target
        return getattr(target, "__name__", str(target))

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Rewrite every eligible max node in ``graph_module`` in place, then
        recompile and retrace the module.

        Args:
            graph_module: The graph module to transform.

        Returns:
            A ``PassResult`` wrapping the retraced graph module, always
            flagged as modified.

        Raises:
            AssertionError: If a matched node has an unsupported number of
                users, if its single user is not a ``getitem`` node, or if
                that ``getitem`` selects an index other than 0.
        """
        graph = graph_module.graph

        # Iterate over a snapshot because matched nodes are erased in the loop.
        for node in list(graph.nodes):
            if node.op != "call_function":
                continue

            op_name = self._target_name(node)
            is_max_pool = op_name == self._MAX_POOL2D_WITH_INDICES
            if not is_max_pool and op_name != self._MAX_DIM:
                continue

            users = list(node.users)
            if len(users) != 1:
                if len(users) == 2 and not is_max_pool:
                    # Two users is allowed for max.dim. For that case,
                    # rather than removing the getitem node in this
                    # pass, we handle the getitem nodes in the op's
                    # visitor when serializing
                    continue
                raise AssertionError(
                    f"Invalid number of users for {op_name}: {len(users)}"
                )

            getitem_node = users[0]
            user_name = self._target_name(getitem_node)
            if user_name != "getitem":
                raise AssertionError(
                    f"Expected max node's user to be getitem, got {user_name}"
                )

            getitem_index = getitem_node.args[1]
            if getitem_index != 0:
                if is_max_pool:
                    raise AssertionError(
                        f"Expected second argument of getitem node for {op_name} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values from the op but not getting the corresponding indices."
                    )
                raise AssertionError(
                    f"Expected second argument of getitem node for {op_name} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values or getting both the max values and their corresponding indices from the op, but not getting the indices alone."
                )

            # Resolve the replacement op lazily, only for the branch taken.
            replacement_target = (
                exir_ops.edge.aten.max_pool2d.default
                if is_max_pool
                else exir_ops.edge.aten.amax.default
            )

            with graph.inserting_before(node):
                single_output_node = graph.create_node(
                    "call_function",
                    replacement_target,
                    args=node.args,
                    kwargs=node.kwargs,
                )

            getitem_node.replace_all_uses_with(single_output_node)

            # Erase the user first so the producer has no remaining users.
            graph.erase_node(getitem_node)
            graph.erase_node(node)

        graph_module.recompile()
        # Propagate metadata and retrace module
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)