# xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/decompose_silu.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6from typing import Dict
7
8import torch
9from executorch.exir.pass_base import ExportPass, PassResult
10from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
11
12
class DecomposeSilu(ExportPass):
    """Decompose ``silu`` into ``sigmoid`` and ``mul``.

    Every ``torch.nn.functional.silu`` call found in the graph (located via
    the ``source_fn_stack`` metadata that ``get_source_partitions`` reads) is
    rewritten as ``mul(x, sigmoid(x))``, which equals ``silu(x)``.
    """

    def __init__(self):
        super().__init__()

    def _copy_meta(self, meta: Dict) -> Dict:
        """Return a shallow copy of an FX node's ``meta`` dict.

        The top-level mapping is new, so adding or replacing keys on the copy
        does not affect ``meta``; nested values are still shared with the
        source node.
        """
        return dict(meta)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Replace each silu node with ``mul(x, sigmoid(x))`` in place.

        Args:
            graph_module: Graph module to rewrite. Its graph is mutated,
                dead nodes are eliminated, and the module is recompiled.

        Returns:
            ``PassResult`` wrapping the same ``graph_module``. ``modified`` is
            True when at least one silu was decomposed or dead-code
            elimination removed a node.
        """
        graph = graph_module.graph
        modified = False
        partitions = get_source_partitions(graph, [torch.nn.functional.silu])
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                # An undecomposed silu call is a single node. The nodes built
                # below inherit silu's meta (including ``source_fn_stack``),
                # so on a later run of this pass they reappear as a two-node
                # [sigmoid, mul] "silu" partition. Treating its first node
                # (the sigmoid) as silu would rewrite the graph into
                # x * (x * sigmoid(x)), so skip anything that is not one node.
                if len(src_partition.nodes) != 1:
                    continue
                silu_node = src_partition.nodes[0]
                silu_input = silu_node.args[0]

                # Explicit overloads (rather than OpOverloadPackets) keep the
                # new nodes consistent with the other aten nodes in the graph.
                with graph.inserting_after(silu_input):
                    sigmoid_node = graph.create_node(
                        "call_function",
                        torch.ops.aten.sigmoid.default,
                        (silu_input,),
                    )
                sigmoid_node.meta = self._copy_meta(silu_node.meta)

                with graph.inserting_after(sigmoid_node):
                    mul_node = graph.create_node(
                        "call_function",
                        torch.ops.aten.mul.Tensor,
                        (silu_input, sigmoid_node),
                    )
                # Both new ops are elementwise in x, so their outputs have the
                # same shape/dtype as silu(x); silu's meta is reused for both.
                mul_node.meta = self._copy_meta(silu_node.meta)

                silu_node.replace_all_uses_with(mul_node)
                modified = True

        # Removes the now-unused silu nodes (and any other dead node in the
        # graph); the return value reports whether anything was removed.
        modified = graph.eliminate_dead_code() or modified
        graph_module.recompile()
        return PassResult(graph_module, modified)
49