xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/convert_to_linear.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from collections import Counter
from typing import Callable, List

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
from executorch.backends.transforms.addmm_mm_to_linear import (
    apply_addmm_mm_to_linear_transform,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload as edge_op
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass

from torch.fx.passes.utils.source_matcher_utils import (
    get_source_partitions,
    SourcePartition,
)

from .utils import dq_ops, get_quant_attrs, q_ops


class ConvertToLinear(ExportPass):
    """
    Handle the missing quantization tag for the addmm op after decomposition.
    """
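    # A minimal usage sketch (assuming an edge-dialect GraphModule `gm` from the
    # export flow; like other ExportPasses, an instance is callable and returns
    # a PassResult):
    #
    #   gm = ConvertToLinear()(gm).graph_module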

    view_copy = exir_ops.edge.aten.view_copy.default
    permute_copy = exir_ops.edge.aten.permute_copy.default
    expand_copy = exir_ops.edge.aten.expand_copy.default
    linear = exir_ops.edge.aten.linear.default
    add = exir_ops.edge.aten.add.Tensor
    addmm = exir_ops.edge.aten.addmm.default
    bmm = exir_ops.edge.aten.bmm.default
    mm = exir_ops.edge.aten.mm.default

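    # Each pattern below is a bag of edge-dialect op targets that one
    # decomposition of torch.nn.Linear may produce. As a rough sketch (the exact
    # decomposition depends on the export flow), a biased nn.Linear applied to a
    # 3D input lowers to
    #
    #   x2d = view_copy(x, [batch * seq, in_features])
    #   out2d = addmm(bias, x2d, permute_copy(weight, [1, 0]))
    #   out = view_copy(out2d, [batch, seq, out_features])
    #
    # which matches the first addmm pattern: two view_copy, one permute_copy,
    # one addmm.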
    addmm_patterns = [
        {view_copy: 2, permute_copy: 1, addmm: 1},
        {permute_copy: 1, addmm: 1},
    ]

    bmm_patterns = [
        {view_copy: 3, permute_copy: 1, expand_copy: 2, add: 1, bmm: 1},
        {view_copy: 3, permute_copy: 1, expand_copy: 2, bmm: 1},
    ]

    mm_patterns = [
        {view_copy: 2, permute_copy: 1, mm: 1},
        {permute_copy: 1, mm: 1},
    ]

    def __init__(self):
        super().__init__()

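    # Walk producer edges (args[0]) upward from `cur_node` until one of the
    # partition's input nodes is reached or the chain runs out of args.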
    def _get_original_input(
        self, inputs: List[torch.fx.Node], cur_node: torch.fx.Node
    ) -> torch.fx.Node:
        while cur_node not in inputs and cur_node.args:
            cur_node = cur_node.args[0]
        return cur_node

    def _convert_to_linear(
        self,
        gm: torch.fx.GraphModule,
        src_partition: SourcePartition,
        extract_ops_fn: Callable,
    ):
        inputs = src_partition.input_nodes
        # output_nodes contains the output node as well as input buffers such as argX_X
        outputs = [
            node
            for node in src_partition.output_nodes
            if node.target != torch.ops.aten.sym_size.int and node.op != "placeholder"
        ]
        assert (
            len(outputs) == 1
        ), f"Unexpected number of outputs for a torch.nn.Linear module, expected 1 but got {outputs}"
        output = outputs[0]

        ops = extract_ops_fn(src_partition.nodes)
        input_node, weight_node, fn_node = ops[:3]
        bias_node = None if len(ops) == 3 else ops[3]
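        # Each extract_ops_fn returns [input_node, weight_node, fn_node] plus an
        # optional trailing bias_node; fn_node is the op whose meta seeds the
        # new linear node below.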

        # QNN HTP does not support keepdim; the view_copy (reshape) must be kept for now
        if self._get_original_input(inputs, input_node).target in dq_ops:
            input_node.meta[QCOM_QUANT_ATTRS] = get_quant_attrs(
                gm, self._get_original_input(inputs, input_node).args[0]
            )
        args = [input_node, weight_node]
        if bias_node:
            args.append(bias_node)

        # Create the linear node in front of the partition output and rewire
        # users of the matched op onto it
        with gm.graph.inserting_before(output):
            linear_node = gm.graph.create_node(
                "call_function", self.linear, tuple(args)
            )
            linear_node.meta = fn_node.meta
            if list(output.users)[0].target in q_ops:
                linear_node.meta[QCOM_QUANT_ATTRS] = get_quant_attrs(
                    gm, list(output.users)[0]
                )
            for user in fn_node.users.copy():
                user.replace_input_with(fn_node, linear_node)

        # Since the QNN linear op has no keep-dims behavior, add squeeze and
        # unsqueeze (as view_copy) around the linear node
        # TODO: Find a more general conditional statement.
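        # For example, a (1, S, K) input is reshaped to (S, K) before the linear
        # op, and the (S, N) result is reshaped back to (1, S, N) afterwards,
        # both via view_copy.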
        linear_output = linear_node.meta["val"]
        if linear_output.dim() == 3 and linear_output.shape[0] == 1:
            with gm.graph.inserting_after(input_node):
                input_users = list(input_node.users.keys())
                input_tensor = input_node.meta["val"]
                squeeze_dim = input_tensor.shape[-2:]
                squeeze_node = gm.graph.create_node(
                    "call_function",
                    self.view_copy,
                    (
                        input_node,
                        squeeze_dim,
                    ),
                )
                # meta needs to be copied element-wise for the fake tensor
                # to be updated correctly without affecting input_node's meta
                for k, v in input_node.meta.items():
                    squeeze_node.meta[k] = v
                squeeze_node.meta["val"] = input_tensor.reshape(squeeze_dim)
                for user in input_users:
                    if user == linear_node:
                        user.replace_input_with(input_node, squeeze_node)

            with gm.graph.inserting_after(linear_node):
                output_users = list(linear_node.users.keys())
                unsqueeze_dim = linear_output.shape
                unsqueeze_node = gm.graph.create_node(
                    "call_function",
                    self.view_copy,
                    (
                        linear_node,
                        unsqueeze_dim,
                    ),
                )
                # meta needs to be copied element-wise for the fake tensor
                # to be updated correctly without affecting unsqueeze_node's meta
                for k, v in linear_node.meta.items():
                    unsqueeze_node.meta[k] = v
                # update linear node's shape
                linear_node.meta["val"] = linear_output.reshape(
                    linear_output.shape[-2:]
                )
                for user in output_users:
                    user.replace_input_with(linear_node, unsqueeze_node)

    def _extract_mm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.Node]:
        mm_node = [n for n in partitioned_nodes if n.target == self.mm][0]
        # weight -> permute -> input of mm
        weight_node = mm_node.args[1].args[0]
        input_node = mm_node.args[0]
        return [input_node, weight_node, mm_node]

    def _extract_addmm_ops(
        self, partitioned_nodes: List[edge_op]
    ) -> List[torch.fx.Node]:
        addmm_node = [n for n in partitioned_nodes if n.target == self.addmm][0]
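        # Recall that addmm(bias, mat1, mat2) computes bias + mat1 @ mat2, so
        # the args are (bias, input, transposed weight) in that order.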
        # weight -> permute -> input of addmm
        weight_node = addmm_node.args[2].args[0]
        input_node = addmm_node.args[1]
        bias_node = addmm_node.args[0]
        return [input_node, weight_node, addmm_node, bias_node]

    def _extract_bmm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.Node]:
        bmm_node = [n for n in partitioned_nodes if n.target == self.bmm][0]
        add_node = [n for n in partitioned_nodes if n.target == self.add]

        # weight -> permute_copy -> expand_copy -> view_copy -> input of bmm
        weight_node = bmm_node.args[1].args[0].args[0].args[0]
        # input -> expand_copy -> view_copy -> input of bmm
        input_node = bmm_node.args[0].args[0].args[0]

        if add_node:
            bias_node = add_node[0].args[1]
            ret = [input_node, weight_node, add_node[0], bias_node]
        else:
            ret = [input_node, weight_node, bmm_node]

        return ret

    def _convert(self, graph_module: torch.fx.GraphModule):
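        # get_source_partitions groups nodes by the nn.Module (or functional op)
        # they were traced from; only torch.nn.Linear partitions matter here.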
        partitions = get_source_partitions(graph_module.graph, [torch.nn.Linear])
        for _, src_partitions in partitions.items():
            for src_partition in src_partitions:
                op_cnt = Counter(
                    [
                        n.target
                        for n in src_partition.nodes
                        if isinstance(n.target, edge_op)
                    ]
                )
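                # A Counter compares equal to a plain dict with the same
                # non-zero counts, so membership in the pattern lists above is
                # an exact multiset match over the partition's edge ops.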
                if self.linear in op_cnt:
                    continue
                elif op_cnt in self.addmm_patterns:
                    self._convert_to_linear(
                        graph_module, src_partition, self._extract_addmm_ops
                    )
                elif op_cnt in self.mm_patterns:
                    self._convert_to_linear(
                        graph_module, src_partition, self._extract_mm_ops
                    )
                elif op_cnt in self.bmm_patterns:
                    self._convert_to_linear(
                        graph_module, src_partition, self._extract_bmm_ops
                    )
                else:
                    raise AssertionError(
                        "Found a new pattern that needs to be converted to a linear op"
                    )

    def call(self, graph_module: torch.fx.GraphModule):
        self._convert(graph_module)
        # get_source_partitions cannot cover every case here, since ops inside
        # MultiheadAttention all share the same source; fall back to the generic
        # addmm/mm-to-linear transform for the rest
        apply_addmm_mm_to_linear_transform(graph_module.graph)
        dead_code_elimination_pass(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)