# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import torch
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class DecomposeSilu(ExportPass):
    """Rewrite ``silu(x)`` as ``x * sigmoid(x)`` in an FX graph.

    Every source partition produced by ``torch.nn.functional.silu`` is
    replaced by an ``aten.sigmoid`` node feeding an ``aten.mul`` node. The
    now-unused silu node(s) are left for dead-code elimination to remove.
    """

    def __init__(self):
        super(DecomposeSilu, self).__init__()

    def _copy_meta(self, meta: Dict) -> Dict:
        """Return a shallow copy of a node's ``meta`` dict.

        Only the top-level mapping is new, so the inserted nodes do not
        alias the silu node's dict. The values themselves are shared, not
        deep-copied.
        """
        return dict(meta)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Decompose every silu partition in ``graph_module`` in place.

        Args:
            graph_module: Module whose graph is rewritten. The graph is
                mutated and the module is recompiled.

        Returns:
            ``PassResult`` wrapping the same ``graph_module``. ``modified``
            is always reported as ``True``.
        """
        graph = graph_module.graph
        partitions = get_source_partitions(graph, [torch.nn.functional.silu])
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                # silu has a single tensor operand.
                x = src_partition.input_nodes[0]
                # Use the partition's output node rather than nodes[0]. If
                # silu was traced into several nodes, only the output node's
                # users must be redirected.
                silu_node = src_partition.output_nodes[0]

                # Insert directly after the input. Both new nodes depend only
                # on `x`, so this position is always topologically valid.
                # Use concrete OpOverloads: exported graphs reject
                # OpOverloadPacket call_function targets.
                with graph.inserting_after(x):
                    sigmoid_node = graph.create_node(
                        "call_function", torch.ops.aten.sigmoid.default, (x,)
                    )
                # sigmoid's output has the same shape/dtype as silu's, so the
                # silu metadata (e.g. "val") applies to both new nodes.
                sigmoid_node.meta = self._copy_meta(silu_node.meta)

                with graph.inserting_after(sigmoid_node):
                    mul_node = graph.create_node(
                        "call_function",
                        torch.ops.aten.mul.Tensor,
                        (x, sigmoid_node),
                    )
                mul_node.meta = self._copy_meta(silu_node.meta)

                # Snapshot the users first: replace_input_with mutates
                # silu_node.users while we iterate.
                for user in list(silu_node.users):
                    user.replace_input_with(silu_node, mul_node)

        # The orphaned silu node(s) have no users now and should be dropped.
        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)