# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Sequence

import torch
from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY
from executorch.backends.qualcomm.quantizer.quantizer import (
    get_16a8w_qnn_ptq_config,
    get_8a8w_qnn_ptq_config,
    get_ptq_per_channel_quant_config,
    QuantizationConfig,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    SharedQuantizationSpec,
)
from torch.fx import Node


def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:
    """
    Annotate matmul ops with a 16a8w (16-bit activation, 8-bit weight)
    quantization config.
    """

    def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
        input_qspec_map = {}
        input_act = node.args[0]
        input_spec = quantization_config.input_activation
        input_qspec_map[input_act] = input_spec

        input_act1 = node.args[1]
        input_spec1 = quantization_config.weight
        input_qspec_map[input_act1] = input_spec1

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    def annotate_cat(node: Node, quantization_config: QuantizationConfig):
        input_nodes = node.args[0]

        first_input_node = input_nodes[0]
        input_qspec_map = {}
        input_qspec_map[first_input_node] = quantization_config.input_activation
        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
            (first_input_node, node)
        )

        for input_node in input_nodes[1:]:
            if input_node not in input_qspec_map:
                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=share_qparams_with_input_act0_qspec,
            _annotated=True,
        )

    def annotate_single_in_single_out(
        node: Node, quantization_config: QuantizationConfig
    ) -> None:
        input_qspec_map = {}
        input_act = node.args[0]
        input_qspec_map[input_act] = quantization_config.input_activation

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    def annotate_matmul_input1(node: Node):
        quantization_config_8a8w = get_8a8w_qnn_ptq_config(
            act_symmetric=True, act_observer=MinMaxObserver
        )
        # Walk up the producer chain of matmul's second input, annotating the
        # pass-through ops (permute/transpose/cat) with the 8a8w config until
        # a non-call_function node is reached.
        while isinstance(node, Node) and node.op == "call_function":
            if node.target in [
                torch.ops.aten.permute.default,
                torch.ops.aten.transpose.int,
            ]:
                annotate_single_in_single_out(node, quantization_config_8a8w)
                node = node.args[0]
            elif node.target == torch.ops.aten.cat.default:
                annotate_cat(node, quantization_config_8a8w)
                node = node.args[0][0]
            else:
                node = node.args[0]

    quantization_config_16a8w = get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver)

    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
            annotate_matmul(node, quantization_config_16a8w)
            annotate_matmul_input1(node.args[1])
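

# Illustrative sketch only (not used by this module): annotate_matmul_16a8w is
# meant to run as a custom annotation hook before PT2E prepare. The wiring
# below mirrors how the Qualcomm examples register such hooks; treat
# QnnQuantizer.add_custom_quant_annotations as an assumption about your
# executorch version rather than a guaranteed API.
def _example_prepare_with_matmul_annotation(gm: torch.fx.GraphModule):
    from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e

    quantizer = QnnQuantizer()
    # Register the hook; it is invoked on the captured GraphModule during
    # annotation, alongside the built-in per-op annotators.
    quantizer.add_custom_quant_annotations((annotate_matmul_16a8w,))
    return prepare_pt2e(gm, quantizer)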
108 """ 109 110 def annotate_matmul(node: Node, quantization_config: QuantizationConfig): 111 input_qspec_map = {} 112 input_act = node.args[0] 113 input_spec = quantization_config.input_activation 114 input_qspec_map[input_act] = input_spec 115 input_act1 = node.args[1] 116 input_spec1 = quantization_config.weight 117 input_qspec_map[input_act1] = input_spec1 118 node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( 119 input_qspec_map=input_qspec_map, 120 output_qspec=quantization_config.output_activation, 121 _annotated=True, 122 ) 123 124 def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None: 125 input = node.args[0] 126 value = node.args[2] 127 input_qspec_map = {} 128 input_qspec_map[input] = quantization_config.input_activation 129 input_qspec_map[value] = SharedQuantizationSpec((input, node)) 130 node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( 131 input_qspec_map=input_qspec_map, 132 output_qspec=SharedQuantizationSpec((input, node)), 133 _annotated=True, 134 ) 135 136 def annotate_single_in_single_out( 137 node: Node, quantization_config: QuantizationConfig 138 ) -> None: 139 input_qspec_map = {} 140 input_act = node.args[0] 141 input_qspec_map[input_act] = quantization_config.input_activation 142 node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( 143 input_qspec_map=input_qspec_map, 144 output_qspec=quantization_config.output_activation, 145 _annotated=True, 146 ) 147 148 def annotate_cat(node: Node, quantization_config: QuantizationConfig): 149 input_nodes = node.args[0] 150 assert isinstance(input_nodes, Sequence) 151 first_input_node = input_nodes[0] 152 input_qspec_map = {} 153 assert isinstance(first_input_node, Node) 154 assert isinstance(node, Node) 155 input_qspec_map[first_input_node] = quantization_config.input_activation 156 share_qparams_with_input_act0_qspec = SharedQuantizationSpec( 157 (first_input_node, node) 158 ) 159 for input_node in input_nodes[1:]: 160 if input_node not in input_qspec_map: 161 assert isinstance(input_node, Node) 162 input_qspec_map[input_node] = share_qparams_with_input_act0_qspec 163 node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( 164 input_qspec_map=input_qspec_map, 165 output_qspec=share_qparams_with_input_act0_qspec, 166 _annotated=True, 167 ) 168 169 def is_edge_condition(node: Node): 170 if not isinstance(node, Node) or node.op != "call_function": 171 return True 172 return False 173 174 def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig): 175 if is_edge_condition(node): 176 return 177 if node.target in [ 178 torch.ops.aten.index_put.default, 179 torch.ops.aten.index_put_.default, 180 ]: 181 annotate_index_put(node, quantization_config) 182 annotate_matmul_input1(node.args[0], quantization_config) 183 elif node.target == torch.ops.aten.cat.default: 184 annotate_cat(node, quantization_config) 185 # Expect that the inputs of the cat op are select ops 186 for arg in node.args[0]: 187 annotate_matmul_input1(arg, quantization_config) 188 else: 189 annotate_single_in_single_out(node, quantization_config) 190 annotate_matmul_input1(node.args[0], quantization_config) 191 192 # Annotate 16a8w for matmul op to get better performance 193 quantization_config_16a8w = get_16a8w_qnn_ptq_config() 194 # Annotate 8a8w for second input of matmul until past_kv_cache 195 quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True) 196 for node in gm.graph.nodes: 197 if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: 198 if 
"nn_module_stack" in node.meta: 199 module_values_list = list(node.meta["nn_module_stack"].values()) 200 full_qualified_name = module_values_list[-1][0] 201 if "SDPA" in full_qualified_name: 202 annotate_matmul(node, quantization_config_16a8w) 203 annotate_matmul_input1(node.args[1], quantization_config_8a8w) 204 205 206def custom_annotate_llama_last_conv_16a8w(gm: torch.fx.GraphModule) -> None: 207 def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None: 208 input_qspec_map = {} 209 input_act = node.args[0] 210 input_spec = quantization_config.input_activation 211 input_qspec_map[input_act] = input_spec 212 213 weight = node.args[1] 214 input_qspec_map[weight] = quantization_config.weight 215 216 node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( 217 input_qspec_map=input_qspec_map, 218 output_qspec=quantization_config.output_activation, 219 _annotated=True, 220 ) 221 222 quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config( 223 torch.uint16, weight_dtype=torch.int8 224 ) 225 for node in gm.graph.nodes: 226 if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default: 227 if "nn_module_stack" in node.meta: 228 module_values_list = list(node.meta["nn_module_stack"].values()) 229 full_qualified_name = module_values_list[0][0] 230 if full_qualified_name == "L['self'].llama.output": 231 annotate_conv2d( 232 node, quantization_config=quantization_config_16a8w_per_channel 233 ) 234 235 236def custom_annotate_matmul_16a8w(gm: torch.fx.GraphModule): 237 """ 238 Annotate matmul op with 16a8w quantization config 239 """ 240 241 def annotate_matmul(node: Node, quantization_config: QuantizationConfig): 242 input_qspec_map = {} 243 input_act = node.args[0] 244 input_spec = quantization_config.input_activation 245 input_qspec_map[input_act] = input_spec 246 input_act1 = node.args[1] 247 input_spec1 = quantization_config.weight 248 input_qspec_map[input_act1] = input_spec1 249 node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( 250 input_qspec_map=input_qspec_map, 251 output_qspec=quantization_config.output_activation, 252 _annotated=True, 253 ) 254 255 # Annotate 16a8w for matmul op to get better performance 256 quantization_config_16a8w = get_16a8w_qnn_ptq_config() 257 for node in gm.graph.nodes: 258 if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: 259 annotate_matmul(node, quantization_config_16a8w) 260 261 262def get_custom_quant_ios_dtype( 263 cache_shape: torch.Size, 264 node: torch.fx.Node, 265 kv_dtype=torch.uint8, 266 sharding_dtype=torch.uint16, 267): 268 """ 269 This function is specific for llama inputs and outputs 270 """ 271 if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name: 272 return kv_dtype 273 274 # Tag index put node before copy node, because copy is a skipped node in qnn 275 if ( 276 exir_ops.edge.aten.index_put.default == node.target 277 and node.meta["val"].shape == cache_shape 278 ): 279 return kv_dtype 280 281 # Tag sharding io 282 if exir_ops.edge.llama.fallback.default in [ 283 u.target for u in list(node.users.keys()) 284 ] + [node.target]: 285 return sharding_dtype 286 287 # Tag index op as quantized tensors. It is caused by sharding 288 if exir_ops.edge.aten.index.Tensor in [ 289 u.target for u in list(node.users.keys()) 290 ] + [node.target]: 291 return sharding_dtype 292