# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class ConvertPReLU(ExportPass):
    """
    Merge decomposed operators from prelu back to one super node.

    Each ``torch.nn.PReLU`` source partition found in the graph is replaced
    by a single ``exir_ops.edge.aten.prelu.default`` node whose arguments are
    the partition's input node and its single weight placeholder. The
    decomposed nodes left without users are then removed by dead-code
    elimination.
    """

    def __init__(self, edge_program: torch.export.ExportedProgram):
        super().__init__()
        # Stored for interface parity with sibling passes; not read by `call`.
        self.edge_program = edge_program

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Fuse every decomposed ``torch.nn.PReLU`` in ``graph_module`` in place.

        Args:
            graph_module: the graph module to rewrite; it is mutated and
                recompiled in place.

        Returns:
            ``PassResult`` wrapping the same ``graph_module``. ``modified`` is
            True if at least one PReLU was fused or dead-code elimination
            removed any node, and False otherwise.

        Raises:
            AssertionError: if a PReLU partition does not contain exactly one
                placeholder node.
        """
        graph = graph_module.graph
        modified = False
        partitions = get_source_partitions(graph, [torch.nn.PReLU])
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                # NOTE(review): assumes input_nodes[0] is the activation; the
                # weight is expected *inside* the partition as a placeholder.
                input_node = src_partition.input_nodes[0]
                output_node = src_partition.output_nodes[0]
                placeholders = [
                    n for n in src_partition.nodes if n.op == "placeholder"
                ]
                assert len(placeholders) == 1, (
                    "Expected exactly one placeholder (PReLU weight) in the "
                    f"partition, got {len(placeholders)}"
                )
                weight_node = placeholders[0]

                # Insert directly before the decomposition's result: both the
                # activation and the weight necessarily precede it, and all of
                # its users follow it, so the graph stays topologically sorted
                # regardless of placeholder ordering.
                with graph.inserting_before(output_node):
                    prelu_node = graph.create_node(
                        "call_function",
                        exir_ops.edge.aten.prelu.default,
                        (input_node, weight_node),
                    )
                # Copy rather than alias, so the two nodes never share one
                # mutable metadata dict.
                prelu_node.meta = output_node.meta.copy()
                # Snapshot the users: replace_input_with mutates output_node.users.
                for user in list(output_node.users):
                    user.replace_input_with(output_node, prelu_node)
                modified = True

        # Always run DCE (no short-circuit) to drop the now-unused decomposed
        # ops; it reports whether it removed anything.
        if graph.eliminate_dead_code():
            modified = True
        graph_module.recompile()
        return PassResult(graph_module, modified)