# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import cast, Optional, Union

import torch
from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
    TagImplicitQDqPass,
)
from executorch.backends.xnnpack.utils.quant_utils import (
    extract_qdq_affine_op_args_for_decomposed_ops,
    is_affine_qdq,
    is_dequant,
    is_dynamic_qdq,
    is_per_channel,
    is_per_channel_group,
    is_quant,
)
from executorch.backends.xnnpack.utils.utils import (
    check_or_raise,
    get_param_tensor,
    is_param_node,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import ExportedProgram


class QuantParams:
    """
    QuantParams class, to represent the parameters and metadata needed
    to quantize a tensor. The metadata can technically all be encapsulated
    within the quant torch.fx.Node; however, there are some cases in which
    nodes that are meant to be quantized for XNNPACK are not quantized
    in PyTorch IR, specifically bias nodes. In that case, we can still build
    a QuantParams instance to serialize the quantized attributes needed for
    XNNPACK.

    Attributes:
        per_channel: whether this quantization is per channel or per tensor
        q_input: node that is the input to this quantization
        scale: tensor or float that is used as the quantization scale
        zp: tensor or float that is used as the quantization zero point
        axis: axis used for per_channel quantization
        dtype: dtype being quantized to
        qmin: quantization minimum
        qmax: quantization maximum
        is_output: whether this is an output node or not
        is_input: whether this is an input node or not
    """

    def __init__(
        self,
        per_channel: bool,
        q_input: torch.fx.Node,
        scale: Union[torch.Tensor, float],
        zp: Union[torch.Tensor, float],
        axis: int,
        dtype: torch.dtype,
        qmax: int,
        qmin: int,
        is_output: bool,
        is_input: bool,
        is_dynamic: bool = False,
        num_nonbatch_dims: int = 1,
        group_size: int = 0,
    ) -> None:
        self.per_channel = per_channel
        self.q_input = q_input
        self.scale = scale
        self.zp = zp
        self.axis = axis
        self.dtype = dtype
        self.qmax = qmax
        self.qmin = qmin
        self.is_output = is_output
        self.is_input = is_input
        self.is_dynamic = is_dynamic
        self.num_nonbatch_dims = num_nonbatch_dims
        # A static per-channel int8 range of [-8, 7] is treated as 4-bit signed
        # per-channel weights (QC4W).
        self.is_qc4w = (
            self.per_channel
            and not self.is_dynamic
            and self.qmin == -8
            and self.qmax == 7
            and self.dtype == torch.int8
        )

        # Groupwise quantization for weight
        self.per_channel_group = False
        self.group_size = group_size
        if self.group_size > 0:
            assert (
                self.per_channel is True
            ), "Only per channel quantization supports groupwise quantization"
            assert (
                cast(torch.Tensor, scale).ndim == 2
            ), "Scale must be 2D for per channel groupwise quant"
            self.per_channel_group = True
            assert group_size > 0, "Group size must be greater than 0"
        self.is_per_channel_group = self.per_channel and self.group_size > 0

    def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        # Do nothing if already quantized by the Quantizer
        if tensor.dtype == self.dtype:
            return tensor

        if self.per_channel:
            assert (
                self.per_channel_group is False
            ), f"Not expecting per channel group quantization, got q dtype: {self.dtype}, tensor.dtype {tensor.dtype}"
            assert (
                tensor.shape[self.axis] == cast(torch.Tensor, self.scale).shape[0]
            ), f"Invalid size of per channel quantization scales, axis: {self.axis}, scale size: {self.scale.shape}, tensor shape: {tensor.shape}"

            assert (
                tensor.shape[self.axis] == cast(torch.Tensor, self.zp).shape[0]
            ), f"Invalid size of per channel quantization zero-points, axis: {self.axis}, zp size: {self.zp.shape}, tensor shape: {tensor.shape}"

            # Assuming folded quant weights
            # TODO Add support for unfolded weights
            assert not self.is_qc4w, "Not expecting QC4W per channel tensor"

            return exir_ops.edge.quantized_decomposed.quantize_per_channel.default(
                tensor, self.scale, self.zp, self.axis, self.qmin, self.qmax, self.dtype
            )
        else:
            return exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
                tensor, self.scale, self.zp, self.qmin, self.qmax, self.dtype
            )
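
    # The classmethods below are alternative constructors: they derive
    # QuantParams from q/dq nodes in the exported graph (static and dynamic
    # activations, weights, outputs) or synthesize them for bias, which is
    # never quantized in PyTorch IR.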
    @classmethod
    def _from_dynamic_input_node(cls, quant_node: torch.fx.Node) -> QuantParams:
        q_input = quant_node.args[0]  # fp32 input
        assert isinstance(q_input, torch.fx.Node)
        # TODO - materialize this from the quant_node scale count and val shape
        num_nonbatch_dims = 1

        return cls(
            per_channel=False,  # True is not valid
            q_input=q_input,
            scale=0.0,  # no need
            zp=0.0,  # no need
            axis=0,  # no need
            dtype=torch.float32,  # will be quantized at runtime
            qmax=0,  # no need
            qmin=0,  # no need
            is_output=False,
            is_input=q_input.op == "placeholder",
            is_dynamic=True,
            num_nonbatch_dims=num_nonbatch_dims,
        )
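
    # Assumed arg layout of the [de]quantize nodes handled by from_q_dq_node:
    # (input, scale, zero_point, [axis,] qmin, qmax, dtype[, group_size, output_dtype]).
    # This is why qmin/qmax/dtype/group_size are read from the end of the arg list.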
    @classmethod
    def from_q_dq_node(
        cls, quant_node: torch.fx.Node, ep: Optional[ExportedProgram] = None
    ) -> QuantParams:
        check_or_raise(
            is_quant(quant_node) or is_dequant(quant_node),
            f"building quantizer from q/dq node but was given node: {quant_node}",
        )
        q_input = quant_node.all_input_nodes[0]

        # TODO: Use presence of choose_qparam node to determine if this is a dynamic quantization
        if is_dynamic_qdq(quant_node):
            return cls._from_dynamic_input_node(quant_node)

        per_channel = is_per_channel(quant_node)

        _groupwise = is_per_channel_group(quant_node)
        quant_node_args = quant_node.args
        if _groupwise and is_affine_qdq(quant_node):
            quant_node_args = extract_qdq_affine_op_args_for_decomposed_ops(quant_node)

        scale = quant_node_args[1]
        zp = quant_node_args[2]
        axis = 0
        if per_channel:
            assert isinstance(scale, torch.fx.Node) and isinstance(scale.target, str)
            assert isinstance(zp, torch.fx.Node) and isinstance(zp.target, str)
            assert (
                ep is not None
            ), "ExportedProgram must be provided to extract per channel params"

            def _get_tensor(node):
                param = get_param_tensor(ep, node)
                assert param is not None, f"Expected to find param tensor for {node}"
                return cast(torch.Tensor, param)

            scale = _get_tensor(scale)
            zp = _get_tensor(zp)
            axis = cast(int, quant_node_args[3])

            if _groupwise:
                scale_tensor = cast(torch.Tensor, scale)
                if scale_tensor.ndim == 1:
                    scale_tensor = scale_tensor.reshape(-1, 1)
                    zp = zp.reshape(-1, 1)
                    scale = scale_tensor

                assert (
                    scale_tensor.ndim == 2
                ), f"Weight scale must be 2D for per_channel_group [de]quant node, got {scale_tensor.ndim}D"
                axis = 0  # axis is ignored for groupwise quantization

        check_or_raise(
            bool(
                quant_node_args[-1] != torch.uint8
                and quant_node_args[-1] != torch.quint8
            ),
            "XNNPACK does not support unsigned quantization",
        )

        if _groupwise:
            _ = quant_node_args[-1]  # output dtype - not used
            group_size = cast(int, quant_node_args[-2])
            dtype = cast(torch.dtype, quant_node_args[-3])
            qmax = cast(int, quant_node_args[-4])
            qmin = cast(int, quant_node_args[-5])
        else:
            group_size = 0
            dtype = cast(torch.dtype, quant_node_args[-1])
            qmax = cast(int, quant_node_args[-2])
            qmin = cast(int, quant_node_args[-3])

        is_output = any(
            user_node.op == "output" for user_node in quant_node.users.keys()
        )
        is_input = q_input.op == "placeholder"
        return cls(
            per_channel,
            q_input,
            scale,
            zp,
            axis,
            dtype,
            qmax,
            qmin,
            is_output,
            is_input,
            group_size=group_size,
        )

    @classmethod
    def from_weights(
        cls, tensor_node: torch.fx.Node, ep: Optional[ExportedProgram] = None
    ) -> Optional[QuantParams]:
        if not is_dequant(tensor_node):
            return None

        # source node for quant params
        src = tensor_node

        # is the input of the dq node a q node?
        dq_input = src.all_input_nodes[0]
        if is_quant(dq_input):
            src = dq_input

        # replace this with pointing to the actual weight value.
        # if no one else uses this weight value then take it out of the toplevel module
        check_or_raise(
            src.all_input_nodes[0].op in ["get_attr", "placeholder"],
            f"q->dq->permute_copy not derived from static weight, input to the q or dq (for folded quant) node: {src.all_input_nodes[0]}",
        )

        return cls.from_q_dq_node(src, ep)

    @classmethod
    def from_inputs(
        cls, tensor_node: torch.fx.Node, ep: ExportedProgram
    ) -> Optional[QuantParams]:
        # tensor_node is quantized if it is produced by a dequant node
        if is_dequant(tensor_node) and TagImplicitQDqPass.is_tagged_as_implicit_q_dq(
            tensor_node
        ):
            dq_input = cast(torch.fx.Node, tensor_node.args[0])
            if is_quant(dq_input):
                q_input = cast(torch.fx.Node, dq_input.args[0])
                if is_param_node(ep, q_input):
                    return cls.from_q_dq_node(dq_input)
            return cls.from_q_dq_node(tensor_node)

        return None

    @classmethod
    def from_outputs(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
        # tensor_node can also be quantized if it is consumed by a q -> dq pair
        if len(tensor_node.users) == 1:
            q = list(tensor_node.users.keys())[0]
            # Check if user is a q node
            if is_quant(q) and TagImplicitQDqPass.is_tagged_as_implicit_q_dq(q):
                return cls.from_q_dq_node(q)

        return None

    @classmethod
    def from_bias(
        cls,
        bias: torch.fx.Node,
        weight_quantizer: Optional[QuantParams],
        input_quantizer: Optional[QuantParams],
    ) -> Optional[QuantParams]:
        if weight_quantizer is None or input_quantizer is None:
            check_or_raise(
                weight_quantizer is None and input_quantizer is None,
                "Weight and Input should both be quantized",
            )
            return None

        if input_quantizer.is_dynamic:
            # No need to quantize bias for dynamic quantization
            return None

        check_or_raise(
            not input_quantizer.per_channel,
            "Input can not be quantized per channel",
        )

        # Only per_tensor quantization is supported for input here
        check_or_raise(
            isinstance(input_quantizer.scale, float),
            f"q_input scale should be float, but got {input_quantizer.scale}",
        )
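        # The dequantized conv/linear output is scaled by
        # input_scale * weight_scale, so (per the usual affine-quantization
        # convention) the bias is serialized as int32 with exactly that scale
        # and a zero point of 0.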
        return cls(
            per_channel=weight_quantizer.per_channel,
            q_input=bias,
            scale=weight_quantizer.scale * cast(float, input_quantizer.scale),
            zp=weight_quantizer.zp * 0,
            axis=0,  # not using weight_quantizer.axis because bias is always of shape [out_channels], i.e. 1D
            dtype=torch.int32,
            qmin=-(2**31),
            qmax=(2**31) - 1,
            is_output=False,
            is_input=False,
        )
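
# Illustrative usage sketch (not part of this module's API). Assuming `ep` is the
# ExportedProgram and `node` is a statically quantized linear/conv node whose
# args are (input, weight, bias); the argument indices here are hypothetical:
#
#   input_qp = QuantParams.from_inputs(node.args[0], ep)
#   weight_qp = QuantParams.from_weights(node.args[1], ep)
#   bias_qp = QuantParams.from_bias(node.args[2], weight_qp, input_qp)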