xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/convert_prelu.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import torch
7from executorch.exir.dialects._ops import ops as exir_ops
8from executorch.exir.pass_base import ExportPass, PassResult
9from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
10
11
class ConvertPReLU(ExportPass):
    """
    Merge decomposed operators from prelu back to one super node.

    Every source partition that originates from ``torch.nn.PReLU`` is replaced
    by a single ``exir_ops.edge.aten.prelu.default`` call. Its arguments are the
    partition's input node and the partition's single placeholder, which is
    used as the prelu weight. The decomposed nodes left without users are then
    removed by dead-code elimination.
    """

    def __init__(self, edge_program: torch.export.ExportedProgram):
        """
        Args:
            edge_program: The exported program this pass belongs to. It is
                stored on the instance; ``call`` does not read it.
        """
        super().__init__()
        self.edge_program = edge_program

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Rewrite every PReLU source partition of ``graph_module`` in place.

        Args:
            graph_module: Graph module whose graph is mutated and recompiled.

        Returns:
            ``PassResult`` wrapping the same ``graph_module``; ``modified`` is
            always ``True``.

        Raises:
            AssertionError: If a PReLU partition does not contain exactly one
                placeholder node.
        """
        graph = graph_module.graph
        partitions = get_source_partitions(graph, [torch.nn.PReLU])
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                input_node = src_partition.input_nodes[0]
                output_node = src_partition.output_nodes[0]
                # The one placeholder inside the partition becomes the prelu
                # weight argument (presumably the lifted nn.PReLU weight).
                placeholders = [
                    n for n in src_partition.nodes if n.op == "placeholder"
                ]
                # Explicit raise instead of `assert` so the check survives -O.
                if len(placeholders) != 1:
                    raise AssertionError(
                        "Expected exactly one placeholder in PReLU partition, "
                        f"got {len(placeholders)}: {placeholders}"
                    )

                # Insert right before the partition's output: the decomposed
                # output is computed from both prelu arguments, so they are
                # already defined at that point. Inserting after `input_node`
                # could place a call_function among the graph's placeholders
                # when the input is itself a placeholder.
                with graph.inserting_before(output_node):
                    prelu_node = graph.create_node(
                        "call_function",
                        exir_ops.edge.aten.prelu.default,
                        (input_node, placeholders[0]),
                    )
                # Shallow copy so the new node does not alias the meta dict of
                # the node being replaced.
                prelu_node.meta = output_node.meta.copy()
                output_node.replace_all_uses_with(prelu_node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
45