# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

from .utils import dq_ops


class RecomposeRmsNorm(ExportPass):
    """
    Merge decomposed operators back to one super node.

    Every source partition that ``get_source_partitions`` attributes to
    ``torch.nn.RMSNorm`` gets a single ``aten.rms_norm.default`` edge node
    inserted before its output node; all users of the partition output are
    rewired to that node and the now-unused nodes are removed by dead code
    elimination.
    """

    def __init__(self):
        super().__init__()

    def _get_eps_node(self, nodes):
        """Return the epsilon operand of the partition's ``add`` node.

        Args:
            nodes: nodes of one source partition.

        Returns:
            The first argument of the first node whose name contains "add"
            that is either not an fx node (e.g. a Python scalar) or an fx
            node whose op is not "call_function"; ``None`` if the add node
            has no such argument.

        Raises:
            RuntimeError: if no node in ``nodes`` has "add" in its name.
        """
        # eps: one of inputs of add node
        add_node = next(
            (n for n in nodes if hasattr(n, "name") and "add" in n.name),
            None,
        )
        if add_node is None:
            raise RuntimeError(
                "Failed to locate the add node holding eps in rms_norm partition"
            )
        for a in add_node.args:
            # Node.args may hold plain Python values (float, int, None, ...);
            # only fx nodes carry an `op` attribute, so test the type first
            # instead of assuming every non-float argument is a node.
            if not isinstance(a, torch.fx.Node) or a.op != "call_function":
                return a
        return None

    def _get_gamma_node(self, output_node):
        """Return the weight (gamma) operand of the partition output node.

        Args:
            output_node: output node of one source partition.

        Returns:
            The first argument of ``output_node`` whose op is not
            "call_function" or whose target is one of ``dq_ops``; ``None``
            if no argument matches.
        """
        # gamma: one of inputs of output node
        for a in output_node.args:
            if a.op != "call_function" or a.target in dq_ops:
                return a
        return None

    def call(self, graph_module: torch.fx.GraphModule):
        """Replace each RMSNorm source partition with one rms_norm node.

        Args:
            graph_module: graph module to rewrite in place.

        Returns:
            ``PassResult(graph_module, True)``; the modified flag is always
            ``True``, whether or not a partition was found.

        Raises:
            RuntimeError: if a partition has neither one nor two input nodes.
        """
        graph = graph_module.graph
        partitions = get_source_partitions(graph, [torch.nn.RMSNorm])
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                input_len = len(src_partition.input_nodes)
                if input_len == 1:
                    input_node = src_partition.input_nodes[0]
                elif input_len == 2:
                    inp_0, inp_1 = src_partition.input_nodes
                    # Pick the input that has exactly two users.
                    # NOTE(review): presumably the activation is the one
                    # consumed twice inside the decomposition - confirm.
                    input_node = inp_0 if len(inp_0.users) == 2 else inp_1
                else:
                    raise RuntimeError(
                        f"Found a edge case of rms_node partitoin {src_partition}, which has {input_len} inputs"
                    )

                output_node = src_partition.output_nodes[0]
                eps_node = self._get_eps_node(src_partition.nodes)
                gamma_node = self._get_gamma_node(output_node)

                with graph.inserting_before(output_node):
                    # args schema
                    # (Tensor input, int[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor
                    rms_node = graph.create_node(
                        "call_function",
                        exir_ops.edge.aten.rms_norm.default,
                        (
                            input_node,
                            # normalized_shape is taken from gamma's shape
                            list(gamma_node.meta["val"].shape),
                            gamma_node,
                            eps_node,
                        ),
                    )
                    # Iterate over a snapshot: replace_input_with mutates
                    # output_node.users while we walk it.
                    users = output_node.users.copy()
                    for user in users:
                        user.replace_input_with(output_node, rms_node)
                    # copy metadata (the dict object itself is shared with
                    # output_node, not duplicated)
                    rms_node.meta = output_node.meta

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)