# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from collections import Counter
from typing import Callable, List

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
from executorch.backends.transforms.addmm_mm_to_linear import (
    apply_addmm_mm_to_linear_transform,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload as edge_op
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass

from torch.fx.passes.utils.source_matcher_utils import (
    get_source_partitions,
    SourcePartition,
)

from .utils import dq_ops, get_quant_attrs, q_ops


class ConvertToLinear(ExportPass):
    """
    Handle missing quantization tag for addmm op after decomposing
    """

    view_copy = exir_ops.edge.aten.view_copy.default
    permute_copy = exir_ops.edge.aten.permute_copy.default
    expand_copy = exir_ops.edge.aten.expand_copy.default
    linear = exir_ops.edge.aten.linear.default
    add = exir_ops.edge.aten.add.Tensor
    addmm = exir_ops.edge.aten.addmm.default
    bmm = exir_ops.edge.aten.bmm.default
    mm = exir_ops.edge.aten.mm.default

    addmm_patterns = [
        {view_copy: 2, permute_copy: 1, addmm: 1},
        {permute_copy: 1, addmm: 1},
    ]

    bmm_patterns = [
        {view_copy: 3, permute_copy: 1, expand_copy: 2, add: 1, bmm: 1},
        {view_copy: 3, permute_copy: 1, expand_copy: 2, bmm: 1},
    ]

    mm_patterns = [
        {view_copy: 2, permute_copy: 1, mm: 1},
        {permute_copy: 1, mm: 1},
    ]

    def __init__(self):
        super(ConvertToLinear, self).__init__()

    def _get_original_input(
        self, inputs: List[torch.fx.Node], cur_node: torch.fx.Node
    ) -> torch.fx.Node:
        while cur_node not in inputs and cur_node.args:
            cur_node = cur_node.args[0]
        return cur_node

    def _convert_to_linear(
        self,
        gm: torch.fx.GraphModule,
        src_partition: SourcePartition,
        extract_ops_fn: Callable,
    ):
        inputs = src_partition.input_nodes
        # output_nodes contains output node and input buffer such as argX_X
        outputs = [
            node
            for node in src_partition.output_nodes
            if node.target != torch.ops.aten.sym_size.int and node.op != "placeholder"
        ]
        assert (
            len(outputs) == 1
        ), f"Unexpected number of outputs for a torch.nn.Linear module, expecting 1 but got {outputs}"
        output = outputs[0]

        ops = extract_ops_fn(src_partition.nodes)
        input_node, weight_node, fn_node = ops[:3]
        bias_node = None if len(ops) == 3 else ops[3]

        # qnn htp does not support keepdim, the view_copy(reshape) should exist for now
        if self._get_original_input(inputs, input_node).target in dq_ops:
            input_node.meta[QCOM_QUANT_ATTRS] = get_quant_attrs(
                gm, self._get_original_input(inputs, input_node).args[0]
            )
        args = [input_node, weight_node]
        if bias_node:
            args.append(bias_node)

        # We need a view copy node after linear op
        with gm.graph.inserting_before(output):
            linear_node = gm.graph.create_node(
                "call_function", self.linear, tuple(args)
            )
            linear_node.meta = fn_node.meta
            if list(output.users)[0].target in q_ops:
                linear_node.meta[QCOM_QUANT_ATTRS] = get_quant_attrs(
                    gm, list(output.users)[0]
                )
            for user in fn_node.users.copy():
                user.replace_input_with(fn_node, linear_node)

        # Since the QNN linear op does not keep dims, we need to add squeeze and
        # unsqueeze around the linear node
        # TODO: Find a more general conditional statement.
        linear_output = linear_node.meta["val"]
        if linear_output.dim() == 3 and linear_output.shape[0] == 1:
            with gm.graph.inserting_after(input_node):
                input_users = list(input_node.users.keys())
                input_tensor = input_node.meta["val"]
                squeeze_dim = input_tensor.shape[-2:]
                squeeze_node = gm.graph.create_node(
                    "call_function",
                    self.view_copy,
                    (
                        input_node,
                        squeeze_dim,
                    ),
                )
                # meta needs to be copied element by element so the fake tensor
                # is updated correctly without affecting the meta of input_node
                for k, v in input_node.meta.items():
                    squeeze_node.meta[k] = v
                squeeze_node.meta["val"] = input_tensor.reshape(squeeze_dim)
                for user in input_users:
                    if user == linear_node:
                        user.replace_input_with(input_node, squeeze_node)

            with gm.graph.inserting_after(linear_node):
                output_users = list(linear_node.users.keys())
                unsqueeze_dim = linear_output.shape
                unsqueeze_node = gm.graph.create_node(
                    "call_function",
                    self.view_copy,
                    (
                        linear_node,
                        unsqueeze_dim,
                    ),
                )
                # meta needs to be copied element by element so the fake tensor
                # is updated correctly without affecting the meta of unsqueeze_node
                for k, v in linear_node.meta.items():
                    unsqueeze_node.meta[k] = v
                # update linear node's shape
                linear_node.meta["val"] = linear_output.reshape(
                    linear_output.shape[-2:]
                )
                for user in output_users:
                    user.replace_input_with(linear_node, unsqueeze_node)

    def _extract_mm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.Node]:
        mm_node = [n for n in partitioned_nodes if n.target == self.mm][0]
        # weight -> permute -> input of mm
        weight_node = mm_node.args[1].args[0]
        input_node = mm_node.args[0]
        return [input_node, weight_node, mm_node]

    def _extract_addmm_ops(
        self, partitioned_nodes: List[edge_op]
    ) -> List[torch.fx.Node]:
        addmm_node = [n for n in partitioned_nodes if n.target == self.addmm][0]
        # weight -> permute -> input of addmm
        weight_node = addmm_node.args[2].args[0]
        input_node = addmm_node.args[1]
        bias_node = addmm_node.args[0]
        return [input_node, weight_node, addmm_node, bias_node]

    def _extract_bmm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.Node]:
        bmm_node = [n for n in partitioned_nodes if n.target == self.bmm][0]
        add_node = [n for n in partitioned_nodes if n.target == self.add]

        # weight -> expand_copy -> view_copy -> input of bmm
        weight_node = bmm_node.args[1].args[0].args[0].args[0]
        # input -> expand_copy -> view_copy -> input of bmm
        input_node = bmm_node.args[0].args[0].args[0]

        ret = [input_node, weight_node, bmm_node]
        if add_node:
            bias_node = add_node[0].args[1]
            ret = [input_node, weight_node, add_node[0], bias_node]
        else:
            ret = [input_node, weight_node, bmm_node]

        return ret

    def _convert(self, graph_module: torch.fx.GraphModule):
        partitions = get_source_partitions(graph_module.graph, [torch.nn.Linear])
        for _, src_partitions in partitions.items():
            for src_partition in src_partitions:
                op_cnt = Counter(
                    [
                        n.target
                        for n in src_partition.nodes
                        if isinstance(n.target, edge_op)
                    ]
                )
                if self.linear in op_cnt:
                    continue
                elif op_cnt in self.addmm_patterns:
                    self._convert_to_linear(
                        graph_module, src_partition, self._extract_addmm_ops
                    )
                elif op_cnt in self.mm_patterns:
                    self._convert_to_linear(
                        graph_module, src_partition, self._extract_mm_ops
                    )
                elif op_cnt in self.bmm_patterns:
                    self._convert_to_linear(
                        graph_module, src_partition, self._extract_bmm_ops
                    )
                else:
                    raise AssertionError(
                        "Found a new pattern that needs to be converted to a linear op"
                    )

    def call(self, graph_module: torch.fx.GraphModule):
        self._convert(graph_module)
        # We could not use get_source_partitions here because MultiheadAttention
        # shares the same source
        apply_addmm_mm_to_linear_transform(graph_module.graph)
        dead_code_elimination_pass(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)
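
# Example usage (illustrative sketch only, kept as a comment so the module's
# behavior is unchanged): assuming `graph_module` is an edge-dialect
# torch.fx.GraphModule obtained elsewhere (e.g. from to_edge()), the pass can
# be invoked directly; ExportPass.__call__ runs call() and returns a
# PassResult wrapping the transformed module.
#
#     pass_result = ConvertToLinear()(graph_module)
#     graph_module = pass_result.graph_module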