# mypy: allow-untyped-defs
from __future__ import annotations

import copy
import functools
from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING

import torch
import torch._dynamo as torchdynamo
import torch.nn.functional as F
from torch.ao.quantization.fake_quantize import (
    FakeQuantize,
    FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import (
    HistogramObserver,
    MinMaxObserver,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    PerChannelMinMaxObserver,
    PlaceholderObserver,
)
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.utils import _get_module_name_filter
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    _convert_scalars_to_attrs,
    OP_TO_ANNOTATOR,
    OperatorConfig,
    OperatorPatternType,
    propagate_annotation,
    QuantizationConfig,
)


if TYPE_CHECKING:
    from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
    from torch.fx import Node


__all__ = [
    "XNNPACKQuantizer",
    "get_symmetric_quantization_config",
]


def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph:
    gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs)
    gm.graph.eliminate_dead_code()
    return gm.graph


def _get_linear_patterns(input_size: List[int]):
    in_channels = input_size[-1]
    out_channels = 8  # hard coding but this should not matter
    weight = torch.ones((out_channels, in_channels))
    bias = torch.ones((out_channels,))
    act = torch.ones(input_size)

    def linear_op(act, weight, bias=None):
        return F.linear(act, weight, bias)

    pattern_w_bias = _get_dynamo_graph(linear_op, (act, weight, bias))
    pattern_wo_bias = _get_dynamo_graph(linear_op, (act, weight))
    return [pattern_w_bias, pattern_wo_bias]


def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
    supported_operators: Dict[str, List[OperatorPatternType]] = {
        # Both conv and linear should be able to handle relu + hardtanh fusion since
        # those are clamp ops
        "conv2d": [
            [torch.nn.Conv2d, torch.nn.ReLU],
            [torch.nn.Conv2d, F.relu],
            [F.conv2d, torch.nn.ReLU],
            [F.conv2d, F.relu],
        ],
        "linear": [[torch.nn.Linear], [F.linear]],
        "add": [[torch.add]],
        "adaptive_avg_pool2d": [
            [torch.nn.AdaptiveAvgPool2d],
            [F.adaptive_avg_pool2d],
        ],
    }
    return copy.deepcopy(supported_operators)


def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
    supported_config_and_operators: List[OperatorConfig] = []
    for quantization_config in [
        get_symmetric_quantization_config(),
        get_symmetric_quantization_config(is_qat=True),
        get_symmetric_quantization_config(is_per_channel=True),
        get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
    ]:
        ops = _supported_symmetric_quantized_operators()
        for pattern_list in ops.values():
            supported_config_and_operators.append(
                OperatorConfig(quantization_config, pattern_list)
            )
    return copy.deepcopy(supported_config_and_operators)

@functools.lru_cache
def get_symmetric_quantization_config(
    is_per_channel: bool = False,
    is_qat: bool = False,
    is_dynamic: bool = False,
    act_qmin: int = -128,
    act_qmax: int = 127,
    weight_qmin: int = -127,
    weight_qmax: int = 127,
):
    extra_args: Dict[str, Any] = {"eps": 2**-12}
    if is_qat:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = FakeQuantize
            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
                averaging_constant=1
            )
            extra_args["observer"] = dynamic_quant_observer
        else:
            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
    else:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
        else:
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]

    act_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=act_qmin,
        quant_max=act_qmax,
        qscheme=torch.per_tensor_affine,
        is_dynamic=is_dynamic,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
            **extra_args,
        ),
    )
    weight_qscheme = (
        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
    )
    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
        MinMaxObserver
    )
    if is_qat:
        # TODO: qat + per channel?
        weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
    elif is_per_channel:
        weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver

    extra_args: Dict[str, Any] = {"eps": 2**-12}
    if is_qat:
        if weight_qscheme == torch.per_tensor_symmetric:
            extra_args["observer"] = MovingAverageMinMaxObserver
        else:
            extra_args["observer"] = MovingAveragePerChannelMinMaxObserver  # type: ignore[dict-item]
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=weight_qmin,
        quant_max=weight_qmax,
        qscheme=weight_qscheme,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
            **extra_args
        ),
    )

    bias_quantization_spec = None
    if is_dynamic:
        quantization_config = QuantizationConfig(
            act_quantization_spec,
            None,
            weight_quantization_spec,
            bias_quantization_spec,
            is_qat,
        )
    else:
        quantization_config = QuantizationConfig(
            act_quantization_spec,
            act_quantization_spec,
            weight_quantization_spec,
            bias_quantization_spec,
            is_qat,
        )
    return quantization_config
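
# A minimal, hedged usage sketch for the factory above. The calls only use parameters
# defined in this file (defaults are shown in the signature); the variable names are
# illustrative and not part of any API.
#
#     per_tensor_ptq = get_symmetric_quantization_config()
#     per_channel_ptq = get_symmetric_quantization_config(is_per_channel=True)
#     per_channel_qat = get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
#     dynamic_linear = get_symmetric_quantization_config(is_dynamic=True)
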

def _get_supported_config_and_operators() -> List[OperatorConfig]:
    return _get_supported_symmetric_config_and_operators()


def _get_module_type_filter(tp: Callable):
    """Get the module_type_filter function for a given module type; the filter accepts
    a node and checks whether the node comes from a module of that type.

    For example:
        node: linear_op = call_function[...](...)  # comes from a module with type Block -> Sub -> Linear


    >> module_type_filter = _get_module_type_filter(Sub)  # submodule with type `Sub`, under the `Block` submodule
    >> print(module_type_filter(node))
    True  # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
    """

    tp_str = tp.__module__ + "." + tp.__qualname__

    def module_type_filter(n: Node) -> bool:
        # example: {
        #    'L__self___sub': ("L['self'].sub", <class '....Sub'>),
        #    'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
        # }
        nn_module_stack = n.meta.get("nn_module_stack", {})
        types = []
        for _, t in nn_module_stack.values():
            # export() returns str, but older APIs (e.g. capture_pre_autograd_graph)
            # return type. Handle both cases.
            if isinstance(t, type):
                t = t.__module__ + "." + t.__qualname__
            types.append(t)
        return tp_str in types

    return module_type_filter


def _get_not_module_type_or_name_filter(
    tp_list: List[Callable], module_name_list: List[str]
) -> Callable[[Node], bool]:
    module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
    module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]

    def not_module_type_or_name_filter(n: Node) -> bool:
        return not any(f(n) for f in module_type_filters + module_name_list_filters)

    return not_module_type_or_name_filter

class XNNPACKQuantizer(Quantizer):
    supported_config_and_operators = _get_supported_config_and_operators()
    STATIC_QAT_ONLY_OPS = [
        "conv_bn_relu",
        "conv_bn",
        "conv_transpose_bn_relu",
        "conv_transpose_bn",
    ]

    # static quantization ops (both PTQ and QAT)
    # Preserve the order so that fusions come before singular ops
    STATIC_OPS = [
        "linear_relu",
        "linear",
        "conv_relu",
        "conv",
        "conv_transpose_relu",
        "adaptive_avg_pool2d",
        # TODO: move this to BoltNNQuantizer?
        "gru_io_only",
        "add_relu",
        "add",
        "mul_relu",
        "mul",
        "cat",
    ]

    DYNAMIC_OPS = [
        "linear",
    ]

    def __init__(self) -> None:
        super().__init__()
        self.global_config: Optional[QuantizationConfig] = None
        self.operator_type_config: Dict[
            torch._ops.OpOverloadPacket, Optional[QuantizationConfig]
        ] = {}
        self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
        self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {}

    @classmethod
    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
        op_configs: Set[QuantizationConfig] = {
            spec for spec, _ in cls.supported_config_and_operators
        }
        return list(op_configs)

    @classmethod
    def get_supported_operator_for_quantization_config(
        cls, quantization_config: Optional[QuantizationConfig]
    ) -> List[OperatorPatternType]:
        if quantization_config is None:
            all_ops = []
            for _, ops in cls.supported_config_and_operators:
                all_ops.extend(ops)
            return all_ops

        for config, ops in cls.supported_config_and_operators:
            # note: this assumes each entry in cls.supported_config_and_operators
            # corresponds to one spec, e.g. we don't have
            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
            # where the first and second entries share a spec but the op lists
            # were not merged
            if config == quantization_config:
                return ops
        return []

    def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer:
        self.global_config = quantization_config
        return self

    def set_operator_type(
        self,
        operator_type: torch._ops.OpOverloadPacket,
        quantization_config: QuantizationConfig,
    ) -> XNNPACKQuantizer:
        self.operator_type_config[operator_type] = quantization_config
        return self

    def set_module_type(
        self, module_type: Callable, quantization_config: QuantizationConfig
    ):
        """Set quantization_config for submodules with type `module_type`, for example:
        quantizer.set_module_type(Sub) or quantizer.set_module_type(nn.Linear). All
        supported operator/operator patterns in submodules of this type will be
        quantized with the given `quantization_config`.
        """
        self.module_type_config[module_type] = quantization_config
        return self

    def set_module_name(
        self, module_name: str, quantization_config: Optional[QuantizationConfig]
    ):
        """Set quantization_config for the submodule with name `module_name`, for example:
        quantizer.set_module_name("blocks.sub"). All supported operator/operator patterns
        in the submodule with this name will be quantized with the given
        `quantization_config`.
        """
        assert (
            quantization_config is not None
        ), "quantization_config == None is not supported yet"
        self.module_name_config[module_name] = quantization_config
        return self
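
    # A hedged illustration of the setter methods above; `MyBlock` and "blocks.sub"
    # are placeholder names, not anything defined in this file.
    #
    #     quantizer = XNNPACKQuantizer()
    #     quantizer.set_global(get_symmetric_quantization_config())
    #     quantizer.set_module_type(MyBlock, get_symmetric_quantization_config(is_per_channel=True))
    #     quantizer.set_module_name("blocks.sub", get_symmetric_quantization_config())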
    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        """Transforms scalar values to tensor attributes"""
        return _convert_scalars_to_attrs(model)

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """just handling global spec for now"""
        # hacked for handling dynamic linear quant. will fix later.
        if self.global_config and self.global_config.input_activation.is_dynamic:  # type: ignore[union-attr]
            model = self._annotate_for_dynamic_quantization_config(model)
        else:
            model = self._annotate_for_static_quantization_config(model)
        propagate_annotation(model)
        return model

    def _annotate_all_static_patterns(
        self,
        model: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> torch.fx.GraphModule:
        # TODO: implement support for None to cancel out previous annotations
        if quantization_config is None:
            return model

        if quantization_config.is_qat:
            for op in self.STATIC_QAT_ONLY_OPS:
                OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
        for op in self.STATIC_OPS:
            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
        return model

    def _annotate_all_dynamic_patterns(
        self,
        model: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> torch.fx.GraphModule:
        # TODO: implement support for None to cancel out previous annotations
        if quantization_config is None:
            return model

        for op in self.DYNAMIC_OPS:
            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
        return model

    def _annotate_for_static_quantization_config(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        module_name_list = list(self.module_name_config.keys())
        for module_name, config in self.module_name_config.items():
            self._annotate_all_static_patterns(
                model, config, _get_module_name_filter(module_name)
            )

        tp_list = list(self.module_type_config.keys())
        for module_type, config in self.module_type_config.items():
            self._annotate_all_static_patterns(
                model, config, _get_module_type_filter(module_type)
            )

        self._annotate_all_static_patterns(
            model,
            self.global_config,
            _get_not_module_type_or_name_filter(tp_list, module_name_list),
        )
        return model

    def _annotate_for_dynamic_quantization_config(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        module_name_list = list(self.module_name_config.keys())
        for module_name, config in self.module_name_config.items():
            self._annotate_all_dynamic_patterns(
                model, config, _get_module_name_filter(module_name)
            )

        tp_list = list(self.module_type_config.keys())
        for module_type, config in self.module_type_config.items():
            self._annotate_all_dynamic_patterns(
                model, config, _get_module_type_filter(module_type)
            )

        self._annotate_all_dynamic_patterns(
            model,
            self.global_config,
            _get_not_module_type_or_name_filter(tp_list, module_name_list),
        )
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return cls.supported_config_and_operators
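
# A hedged end-to-end sketch of how this quantizer is typically driven through the
# PT2E flow. `prepare_pt2e`/`convert_pt2e` live in torch.ao.quantization.quantize_pt2e,
# and `model`/`example_inputs` are placeholders; none of this is exported by this file.
#
#     import torch
#     from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#
#     quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
#     exported = torch.export.export(model, example_inputs).module()
#     prepared = prepare_pt2e(exported, quantizer)   # inserts observers per annotations
#     prepared(*example_inputs)                      # calibration pass
#     quantized = convert_pt2e(prepared)             # produces the quantized reference model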