xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/convert_bmm_to_matmul.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import operator
7from collections import Counter
8from typing import List
9
10import torch
11from executorch.exir.dialects._ops import ops as exir_ops
12from executorch.exir.pass_base import ExportPass, PassResult
13from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
14
15
class ConvertBmmToMatmul(ExportPass):
    """
    Replace bmm with matmul, because bmm is equal to matmul in QNN.

    Source partitions originating from ``operator.matmul``, ``torch.matmul``
    or ``torch.bmm`` are collected. A partition is rewritten only when the
    histogram of its node targets equals one of ``patterns`` exactly. In that
    case a single ``aten.matmul`` node is built from the partition's original
    inputs, in the bmm's argument order, and replaces the partition's output
    node. The now-unused partition nodes are removed by dead-code elimination.

    Also handles a missing quantization tag for the bmm op. The new matmul
    node takes a copy of the partition output node's metadata; presumably
    that is where the quantization annotation lives (confirm).
    """

    view_copy = exir_ops.edge.aten.view_copy.default
    expand_copy = exir_ops.edge.aten.expand_copy.default
    clone = exir_ops.edge.aten.clone.default
    bmm = exir_ops.edge.aten.bmm.default
    matmul = exir_ops.edge.aten.matmul.default
    # Accepted node-target histograms (target -> count) for a partition.
    # Matching is exact: a partition with any extra or missing node is skipped.
    patterns = [
        {expand_copy: 2, view_copy: 3, bmm: 1},
        {expand_copy: 2, view_copy: 3, bmm: 1, clone: 1},
        {bmm: 1},
    ]

    def __init__(self):
        super().__init__()

    def _get_ordered_inputs(
        self, inputs: List[torch.fx.Node], output: torch.fx.Node
    ) -> List[torch.fx.Node]:
        """Map each argument of ``output`` back to the partition input feeding it.

        ``src_partition.input_nodes`` has no guaranteed order. This method
        starts from each argument of ``output`` (the bmm node) and follows
        ``args[0]`` until it reaches a node in ``inputs``. That recovers which
        partition input is the left-hand operand and which is the right-hand
        operand.

        Args:
            inputs: The partition's input nodes.
            output: The node whose arguments are traced (the bmm node).

        Returns:
            One node from ``inputs`` per argument of ``output``, in the same
            order as ``output.args``.

        Note:
            This assumes every intermediate node's data input is its first
            argument. There is no guard for the walk leaving the partition; a
            node with empty ``args`` reached first would raise ``IndexError``.
        """
        bmm_inputs = []
        for arg in output.args:
            # Climb through intermediate partition nodes (e.g. view/expand/
            # clone copies) until one of the partition inputs is reached.
            while arg not in inputs:
                arg = arg.args[0]
            bmm_inputs.append(arg)
        return bmm_inputs

    def call(self, graph_module: torch.fx.GraphModule):
        """Collapse every matching matmul/bmm source partition into one matmul.

        Args:
            graph_module: The module whose graph is rewritten in place.

        Returns:
            ``PassResult(graph_module, True)``. The modified flag is always
            ``True``, even when no partition matched.
        """
        graph = graph_module.graph
        partitions = get_source_partitions(
            graph, [operator.matmul, torch.matmul, torch.bmm]
        )
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                op_cnt = Counter(n.target for n in src_partition.nodes)
                if op_cnt not in self.patterns:
                    continue

                inputs = src_partition.input_nodes
                # Every accepted pattern contains exactly one bmm node.
                bmm_node = next(
                    n for n in src_partition.nodes if n.target == self.bmm
                )
                output = src_partition.output_nodes[0]
                # The order of src_partition.input_nodes is not guaranteed,
                # so derive lhs/rhs from the bmm node's own argument order.
                lhs, rhs = self._get_ordered_inputs(inputs, bmm_node)
                with graph.inserting_before(output):
                    # Replace bmm with matmul, because bmm is equal to matmul
                    # in QNN.
                    matmul_node = graph.create_node(
                        "call_function", self.matmul, (lhs, rhs)
                    )
                    # Copy rather than alias the dict, so later edits to one
                    # node's meta cannot leak into the other node.
                    matmul_node.meta = output.meta.copy()
                    for user in output.users.copy():
                        user.replace_input_with(output, matmul_node)

        # The replaced partition nodes no longer have users; remove them.
        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
74