# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import operator
from collections import Counter
from typing import Iterable, List, Optional

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class ConvertBmmToMatmul(ExportPass):
    """
    Replace bmm with matmul, because bmm is equal to matmul in QNN.
    Handle missing quantization tag for bmm op.

    Source partitions traced from ``operator.matmul`` (``@``),
    ``torch.matmul`` and ``torch.bmm`` are matched by their op histogram
    against ``patterns``. Each match is collapsed into one edge
    ``aten.matmul`` node that:

    - consumes the partition's original inputs, and
    - receives a copy of the partition output's meta.

    The replaced nodes are then removed by dead-code elimination.
    """

    view_copy = exir_ops.edge.aten.view_copy.default
    expand_copy = exir_ops.edge.aten.expand_copy.default
    clone = exir_ops.edge.aten.clone.default
    bmm = exir_ops.edge.aten.bmm.default
    matmul = exir_ops.edge.aten.matmul.default
    # Op histograms (target -> count) of the partitions this pass rewrites;
    # partitions with any other composition are left untouched.
    patterns = [
        {expand_copy: 2, view_copy: 3, bmm: 1},
        {expand_copy: 2, view_copy: 3, bmm: 1, clone: 1},
        {bmm: 1},
    ]

    def __init__(self):
        super().__init__()

    def _get_ordered_inputs(
        self,
        inputs: List[torch.fx.Node],
        output: torch.fx.Node,
        partition_nodes: Optional[Iterable[torch.fx.Node]] = None,
    ) -> List[torch.fx.Node]:
        """
        Return the nodes feeding ``output`` (the bmm node), in ``output.args`` order.

        Each bmm argument is followed upward through its first argument
        (the view/expand/clone chain inside the partition). The walk stops
        at the first node that is one of ``inputs``.

        Args:
            inputs: The partition's input nodes, in arbitrary order.
            output: The bmm node whose arguments are traced.
            partition_nodes: The nodes belonging to the partition. When
                given, the walk also stops at the first node outside the
                partition.

                This matters once an earlier rewrite has replaced one of
                ``inputs`` with a new matmul node: ``inputs`` is then stale.
                Without this bound, the walk would climb past the
                replacement into unrelated ancestors.

        Returns:
            One node per bmm argument, i.e. ``[lhs, rhs]``.
        """
        internal = None if partition_nodes is None else set(partition_nodes)
        ordered = []
        for arg in output.args:
            node = arg
            while node not in inputs:
                if internal is not None and node not in internal:
                    # Produced outside this partition (e.g. a matmul inserted
                    # by an earlier rewrite): this is the real input.
                    break
                node = node.args[0]
            ordered.append(node)
        return ordered

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Collapse every matching matmul/bmm source partition into one matmul node.

        Returns:
            A PassResult wrapping the recompiled ``graph_module``. It is
            always reported as modified.
        """
        graph = graph_module.graph
        partitions = get_source_partitions(
            graph, [operator.matmul, torch.matmul, torch.bmm]
        )
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                op_cnt = Counter(n.target for n in src_partition.nodes)
                if op_cnt not in self.patterns:
                    continue

                inputs = src_partition.input_nodes
                # Every accepted pattern contains exactly one bmm.
                bmm_node = next(
                    n for n in src_partition.nodes if n.target == self.bmm
                )
                output = src_partition.output_nodes[0]
                # The order of src_partition.input_nodes is not guaranteed, so
                # recover (lhs, rhs) by tracing the bmm arguments. The trace is
                # bounded by the partition's nodes, which keeps it correct when
                # an input was already replaced by an earlier rewrite
                # (e.g. chained matmuls).
                lhs, rhs = self._get_ordered_inputs(
                    inputs, bmm_node, src_partition.nodes
                )
                with graph.inserting_before(output):
                    # Replace bmm with matmul, because bmm is equal to matmul in QNN.
                    matmul_node = graph.create_node(
                        "call_function", self.matmul, (lhs, rhs)
                    )
                    # Copy instead of aliasing, so that later edits to one
                    # node's meta cannot leak into the other. Presumably this
                    # is what carries the output's quantization annotation
                    # over to the matmul (see class docstring) — confirm.
                    matmul_node.meta = dict(output.meta)
                    output.replace_all_uses_with(matmul_node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)