xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/recompose_rms_norm.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import torch
7from executorch.exir.dialects._ops import ops as exir_ops
8from executorch.exir.pass_base import ExportPass, PassResult
9from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
10
11from .utils import dq_ops
12
13
class RecomposeRmsNorm(ExportPass):
    """
    Merge decomposed operators back to one super node.

    For every source partition reported for ``torch.nn.RMSNorm``, a single
    ``exir_ops.edge.aten.rms_norm.default`` node is created, all users of the
    partition's output are rewired to it, and the now-unused decomposed nodes
    are removed by dead code elimination.
    """

    def __init__(self):
        super().__init__()

    def _get_eps_node(self, nodes):
        """Return the epsilon operand of the partition's add node.

        Args:
            nodes: the nodes of one source partition.

        Returns:
            The first argument of the add node that is not produced by a
            ``call_function`` node (for example a Python scalar), or ``None``
            when every argument is a ``call_function`` node.

        Raises:
            RuntimeError: if no node whose name contains ``"add"`` exists.
        """
        # eps: one of inputs of add node
        add_node = next(
            (n for n in nodes if hasattr(n, "name") and "add" in n.name),
            None,
        )
        if add_node is None:
            raise RuntimeError(
                "Failed to locate the add node that carries eps in the rms_norm partition"
            )

        for arg in add_node.args:
            # Scalars have no `.op`; getattr avoids an AttributeError on
            # non-float, non-node operands such as ints.
            if getattr(arg, "op", None) != "call_function":
                return arg
        return None

    def _get_gamma_node(self, output_node):
        """Return the gamma (weight) operand of the partition's output node.

        Args:
            output_node: the node producing the partition's result.

        Returns:
            The first node argument that is either not a ``call_function``
            node or whose target is one of ``dq_ops``; ``None`` if there is
            no such argument.
        """
        # gamma: one of inputs of output node
        for arg in output_node.args:
            op = getattr(arg, "op", None)
            if op is None:
                # Not a graph node (e.g. a scalar operand), so it cannot be gamma.
                continue
            if op != "call_function" or arg.target in dq_ops:
                return arg
        return None

    def call(self, graph_module: torch.fx.GraphModule):
        """Replace every RMSNorm source partition with one rms_norm node.

        Args:
            graph_module: the graph module to rewrite in place.

        Returns:
            A ``PassResult`` wrapping ``graph_module``; the modified flag is
            always ``True``.

        Raises:
            RuntimeError: if a partition has an unexpected number of inputs,
                or if its add node or gamma operand cannot be located.
        """
        graph = graph_module.graph
        partitions = get_source_partitions(graph, [torch.nn.RMSNorm])
        for _, src_partitions in partitions.items():
            for src_partition in src_partitions:
                input_len = len(src_partition.input_nodes)
                if input_len == 1:
                    input_node = src_partition.input_nodes[0]
                elif input_len == 2:
                    inp_0, inp_1 = src_partition.input_nodes
                    # Heuristic: the input with exactly two users is taken as
                    # the activation; presumably the other one is gamma.
                    input_node = inp_0 if len(inp_0.users) == 2 else inp_1
                else:
                    raise RuntimeError(
                        f"Found a edge case of rms_node partitoin {src_partition}, which has {input_len} inputs"
                    )

                output_node = src_partition.output_nodes[0]
                eps_node = self._get_eps_node(src_partition.nodes)
                gamma_node = self._get_gamma_node(output_node)
                if gamma_node is None:
                    # normalized_shape is derived from gamma below, so fail
                    # with a clear message instead of an opaque AttributeError.
                    raise RuntimeError(
                        f"Failed to locate the gamma node of rms_norm partition {src_partition}"
                    )

                with graph.inserting_before(output_node):
                    # args schema
                    # (Tensor input, int[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor
                    rms_node = graph.create_node(
                        "call_function",
                        exir_ops.edge.aten.rms_norm.default,
                        (
                            input_node,
                            list(gamma_node.meta["val"].shape),
                            gamma_node,
                            eps_node,
                        ),
                    )
                    # Snapshot the users first: rewiring mutates
                    # output_node.users while we iterate.
                    for user in list(output_node.users):
                        user.replace_input_with(output_node, rms_node)
                    # copy metadata (a real copy, so the new node does not
                    # share its dict with the node about to be removed)
                    rms_node.meta = output_node.meta.copy()

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
77