# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from enum import IntEnum, unique
from typing import Callable, Optional, Sequence, Set

import torch
from executorch.backends.qualcomm._passes.decompose_einsum import DecomposeEinsum
from executorch.backends.qualcomm._passes.decompose_silu import DecomposeSilu
from executorch.backends.qualcomm._passes.recompose_pixel_unshuffle import (
    RecomposePixelUnshuffle,
)
from executorch.backends.qualcomm._passes.reduce_dynamic_range import (
    ReduceDynamicRange,
)
from executorch.backends.qualcomm._passes.replace_inf_buffer import ReplaceInfBuffer
from executorch.backends.transforms.decompose_sdpa import (
    DecomposeScaledDotProductAttention,
)

from torch._ops import OpOverload
from torch.ao.quantization.quantizer import Quantizer
from torch.fx import GraphModule

from .annotators import OP_ANNOTATOR

from .qconfig import (
    get_16a16w_qnn_ptq_config,
    get_16a4w_qnn_ptq_config,
    get_16a4w_qnn_qat_config,
    get_16a8w_qnn_ptq_config,
    get_8a8w_qnn_ptq_config,
    get_8a8w_qnn_qat_config,
    get_ptq_per_channel_quant_config,
    get_qat_per_channel_quant_config,
    QuantizationConfig,
)

# To bypass the meta internal test error
get_default_16bit_qnn_ptq_config = get_16a16w_qnn_ptq_config

__all__ = [
    "QnnQuantizer",
    "QuantDtype",
    "get_16a4w_qnn_ptq_config",
    "get_16a8w_qnn_ptq_config",
    "get_16a16w_qnn_ptq_config",
    "get_8a8w_qnn_ptq_config",
    "get_8a8w_qnn_qat_config",
    "get_16a4w_qnn_qat_config",
]


@unique
class QuantDtype(IntEnum):
    """
    Bit widths of activations and weights, e.g. use_16a4w means 16-bit
    activations with 4-bit weights.
    """

    use_16a16w = 0
    use_16a8w = 1
    use_16a4w = 2
    use_8a8w = 3


quant_config_dict = {
    # PTQ
    (QuantDtype.use_16a16w, False): (
        get_16a16w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(torch.uint16, torch.int16),
    ),
    (QuantDtype.use_16a8w, False): (
        get_16a8w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(torch.uint16, torch.int8),
    ),
    (QuantDtype.use_16a4w, False): (
        get_16a4w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(torch.uint16, "int4"),
    ),
    (QuantDtype.use_8a8w, False): (
        get_8a8w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(),
    ),
    # QAT
    (QuantDtype.use_16a4w, True): (
        get_16a4w_qnn_qat_config,
        get_qat_per_channel_quant_config(torch.uint16, "int4"),
    ),
    (QuantDtype.use_8a8w, True): (
        get_8a8w_qnn_qat_config,
        get_qat_per_channel_quant_config(),
    ),
}
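
# A minimal sketch of how the table above is consumed (it mirrors
# set_quant_config below; the names here are illustrative only): each
# (QuantDtype, is_qat) key yields a per-tensor quant config factory plus a
# prebuilt per-channel quant config.
#
#   quant_config_func, per_channel_config = quant_config_dict[
#       (QuantDtype.use_16a8w, False)
#   ]
#   quant_config = quant_config_func()  # per-tensor activation/weight config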

class QnnQuantizer(Quantizer):
    SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())

    def __init__(self):
        super().__init__()
        self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()

        self.is_qat = False
        self.quant_dtype = QuantDtype.use_8a8w
        self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config()
        self.per_channel_quant_config = get_ptq_per_channel_quant_config()
        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()

        self.custom_quant_annotations: Sequence[Callable] = []
        self.discard_nodes: Set[str] = set()

    def _annotate(self, gm: GraphModule) -> None:
        for node in gm.graph.nodes:
            if node.name in self.discard_nodes:
                continue

            quant_config = self._get_quant_config(node.target)
            if quant_config:
                OP_ANNOTATOR[node.target](node, quant_config)

    def _annotate_custom_annotation(self, gm: GraphModule) -> None:
        for annotation_func in self.custom_quant_annotations:
            annotation_func(gm)

    def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]:
        """
        Priority:
        1. op is in use_per_channel_weight_quant_ops: return the per-channel config
        2. op is in quant_ops: return the default quant config
        Returns None otherwise.
        """
        if isinstance(op, str):
            return

        if op in self.use_per_channel_weight_quant_ops:
            return self.per_channel_quant_config

        if op in self.quant_ops:
            return self.quant_config

        print(f"No quant config is implemented for op: {op}")

    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
        if enable:
            self.use_per_channel_weight_quant_ops.update(ops)
        else:
            self.use_per_channel_weight_quant_ops.difference_update(ops)

    def add_custom_quant_annotations(
        self, custom_quant_annotations: Sequence[Callable]
    ) -> None:
        self.custom_quant_annotations = custom_quant_annotations

    def add_discard_nodes(self, nodes: Sequence[str]) -> None:
        self.discard_nodes = set(nodes)

    def add_discard_ops(self, ops: Sequence[OpOverload]) -> None:
        for op in ops:
            self.quant_ops.remove(op)

    def annotate(self, model: GraphModule) -> GraphModule:
        self._annotate(model)
        self._annotate_custom_annotation(model)

        return model

    def get_supported_ops(self) -> Set[OpOverload]:
        return self.SUPPORTED_OPS

    def set_quant_config(
        self, quant_dtype: QuantDtype, is_qat=False, act_observer=None
    ) -> None:
        self.quant_dtype = quant_dtype
        self.is_qat = is_qat
        if (quant_dtype, is_qat) not in quant_config_dict:
            raise RuntimeError(
                f"The quant config (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not supported"
            )

        quant_config_func, self.per_channel_quant_config = quant_config_dict[
            (quant_dtype, is_qat)
        ]
        self.quant_config = (
            quant_config_func(act_observer) if act_observer else quant_config_func()
        )

    def set_per_channel_conv_quant(self, enable: bool) -> None:
        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
        self._update_per_channel_weight_quant_ops(conv_ops, enable)

    def set_per_channel_linear_quant(self, enable: bool) -> None:
        linear_ops = {
            torch.ops.aten.linear.default,
        }
        self._update_per_channel_weight_quant_ops(linear_ops, enable)

    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
        model = ReduceDynamicRange()(model).graph_module
        model = RecomposePixelUnshuffle(quantization_capture=True)(model).graph_module
        model = DecomposeScaledDotProductAttention()(model).graph_module
        model = DecomposeSilu()(model).graph_module
        model = DecomposeEinsum()(model).graph_module
        model = ReplaceInfBuffer()(model).graph_module
        return model

    def validate(self, model: GraphModule) -> None:
        pass
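
# -----------------------------------------------------------------------------
# Example usage: a minimal sketch of plugging this quantizer into the PT2E
# quantization flow. `MyModel`, `example_inputs`, and `calibration_data` are
# hypothetical placeholders, and the capture API varies across torch versions
# (torch.export.export_for_training is used on recent releases).
#
#   from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#
#   quantizer = QnnQuantizer()
#   quantizer.set_quant_config(QuantDtype.use_8a8w, is_qat=False)
#   quantizer.set_per_channel_conv_quant(enable=True)
#
#   gm = torch.export.export_for_training(MyModel(), example_inputs).module()
#   prepared = prepare_pt2e(gm, quantizer)
#   for data in calibration_data:  # run calibration to populate observers
#       prepared(*data)
#   quantized = convert_pt2e(prepared)
# -----------------------------------------------------------------------------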