# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

#
# Quantizer for Arm backend
#

from __future__ import annotations

import copy
import functools
from typing import Any, Callable, Dict, List, Optional, Set

import torch
import torch.nn.functional as F
from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager

from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.arm_quantizer_utils import (
    mark_nodes_as_annotated,
    propagate_annotation,
)
from executorch.backends.arm.quantizer.quantization_annotation import (
    OP_TO_ANNOTATOR,
    OperatorConfig,
    OperatorPatternType,
)
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.fake_quantize import (
    FakeQuantize,
    FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import (
    HistogramObserver,
    MinMaxObserver,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    PerChannelMinMaxObserver,
    PlaceholderObserver,
)
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)
from torch.fx import GraphModule, Node

__all__ = [
    "ArmQuantizer",
    "get_symmetric_quantization_config",
]


def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
    supported_operators: Dict[str, List[OperatorPatternType]] = {
        # Both conv and linear should be able to handle relu + hardtanh fusion since
        # those are clamp ops
        "conv2d": [
            [torch.nn.Conv2d, torch.nn.ReLU],
            [torch.nn.Conv2d, F.relu],
            [F.conv2d, torch.nn.ReLU],
            [F.conv2d, F.relu],
        ],
        "linear": [[torch.nn.Linear], [F.linear]],
        "add": [[torch.add]],
        "max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
        "adaptive_avg_pool2d": [
            [torch.nn.AdaptiveAvgPool2d],
            [F.adaptive_avg_pool2d],
        ],
        "mul": [[torch.mul]],
        "sub": [[torch.sub]],
    }
    return copy.deepcopy(supported_operators)


def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
    supported_config_and_operators: List[OperatorConfig] = []
    for quantization_config in [
        get_symmetric_quantization_config(),
        get_symmetric_quantization_config(is_per_channel=True),
    ]:
        ops = _supported_symmetric_quantized_operators()
        for pattern_list in ops.values():
            supported_config_and_operators.append(
                OperatorConfig(quantization_config, pattern_list)
            )
    return copy.deepcopy(supported_config_and_operators)


@functools.lru_cache
def get_symmetric_quantization_config(
    is_per_channel: bool = False,
    is_qat: bool = False,
    is_dynamic: bool = False,
    act_qmin: int = -128,
    act_qmax: int = 127,
    weight_qmin: int = -127,
    weight_qmax: int = 127,
):
    extra_args: Dict[str, Any] = {"eps": 2**-12}
    if is_qat:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = FakeQuantize
            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
                averaging_constant=1
            )
            extra_args["observer"] = dynamic_quant_observer
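        # Static QAT: use the fused moving-average observer + fake-quantize module
        # for activations.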
        else:
            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
    else:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
        else:
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]

    act_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=act_qmin,
        quant_max=act_qmax,
        qscheme=torch.per_tensor_affine,
        is_dynamic=is_dynamic,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
            **extra_args,
        ),
    )
    weight_qscheme = (
        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
    )
    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
        MinMaxObserver
    )
    if is_qat:
        # TODO: qat + per channel?
        weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
    elif is_per_channel:
        weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver

    extra_args: Dict[str, Any] = {"eps": 2**-12}
    if is_qat:
        if weight_qscheme == torch.per_tensor_symmetric:
            extra_args["observer"] = MovingAverageMinMaxObserver
        else:
            extra_args["observer"] = MovingAveragePerChannelMinMaxObserver  # type: ignore[dict-item]
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=weight_qmin,
        quant_max=weight_qmax,
        qscheme=weight_qscheme,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
            **extra_args
        ),
    )

    bias_quantization_spec = None
    if is_dynamic:
        quantization_config = QuantizationConfig(
            act_quantization_spec,
            None,
            weight_quantization_spec,
            bias_quantization_spec,
        )
    else:
        quantization_config = QuantizationConfig(
            act_quantization_spec,
            act_quantization_spec,
            weight_quantization_spec,
            bias_quantization_spec,
        )
    return quantization_config


def _get_supported_config_and_operators() -> List[OperatorConfig]:
    return _get_supported_symmetric_config_and_operators()


NodeFilterType = Callable[[Node], bool]
"""Type for a Node filter used by annotators. A Node filter is a function that takes
    a Node and returns whether the node should be annotated or not.
"""


def _get_module_name_filter(module_name: str) -> NodeFilterType:
    """Get the module_name_filter function for a given module name. The filter accepts
    a node and checks whether the node comes from a module with the given name.

    For example:
        node: linear_op = call_function[...](...)  # comes from a module with name blocks.sub.linear1

    >> module_name_filter = _get_module_name_filter("blocks.sub")
    >> print(module_name_filter(node))
    True  # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"
    """

    name_start = len("L['self'].")

    def module_name_filter(n: Node) -> bool:
        # node_stack example: {
        #     'L__self___sub': ("L['self'].sub", <class '....Sub'>),
        #     'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
        # }
        # get_attr nodes don't have nn_module_stack?
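        # Strip the "L['self']." prefix so the entries can be compared against
        # plain fully qualified module names such as "blocks.sub".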
        nn_module_stack = n.meta.get("nn_module_stack", {})
        names = [name[name_start:] for name, _ in nn_module_stack.values()]
        return module_name in names

    return module_name_filter


def _get_module_type_filter(tp: Callable) -> NodeFilterType:
    """Get the module_type_filter function for a given module type. The filter accepts
    a node and checks whether the node comes from a module of the given type.

    For example:
        node: linear_op = call_function[...](...)  # comes from a module with type Block -> Sub -> Linear

    >> module_type_filter = _get_module_type_filter(Sub)  # submodule with type `Sub`, under the `Block` submodule
    >> print(module_type_filter(node))
    True  # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
    """

    def module_type_filter(n: Node) -> bool:
        # node_stack example: {
        #     'L__self___sub': ("L['self'].sub", <class '....Sub'>),
        #     'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
        # }
        nn_module_stack = n.meta.get("nn_module_stack", {})
        types = [t for _, t in nn_module_stack.values()]
        return tp in types

    return module_type_filter


def _get_not_module_type_or_name_filter(
    tp_list: List[Callable], module_name_list: List[str]
) -> NodeFilterType:
    module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
    module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]

    def not_module_type_or_name_filter(n: Node) -> bool:
        return not any(f(n) for f in module_type_filters + module_name_list_filters)

    return not_module_type_or_name_filter


class ArmQuantizer(Quantizer):
    supported_config_and_operators = _get_supported_config_and_operators()

    # A list of supported static quantization annotators, in order of application.
    # For example, fusions come before singular ops.
    # The name must match the name used when registering the annotator.
    STATIC_ANNOTATION_ORDER = [
        "linear",
        "conv",
        "adaptive_avg_pool2d",
        "max_pool2d",
        "add",
        "sub",
        "mul",
        "mm",
        "one_to_one",
        "generic",
        "upsample_nearest2d",
    ]

    def __init__(self) -> None:
        super().__init__()
        self.global_config: Optional[QuantizationConfig] = None
        self.io_config: Optional[QuantizationConfig] = None
        self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
        self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {}

    def set_global(self, quantization_config: QuantizationConfig) -> ArmQuantizer:
        """Set quantization_config for submodules that are not already annotated by name or type filters."""
        self.global_config = quantization_config
        return self

    def set_module_type(
        self, module_type: Callable, quantization_config: QuantizationConfig
    ) -> ArmQuantizer:
        """Set quantization_config for submodules of type `module_type`, for example
        quantizer.set_module_type(Sub) or quantizer.set_module_type(nn.Linear). All supported
        operators/operator patterns in submodules of this type are quantized with the given
        `quantization_config`.
        """
        self.module_type_config[module_type] = quantization_config
        return self

    def set_module_name(
        self, module_name: str, quantization_config: Optional[QuantizationConfig]
    ) -> ArmQuantizer:
        """Set quantization_config for the submodule named `module_name`, for example
        quantizer.set_module_name("blocks.sub"). All supported operators/operator patterns
        in that submodule are quantized with the given `quantization_config`.
        """
        assert (
            quantization_config is not None
        ), "quantization_config == None is not supported yet"
        self.module_name_config[module_name] = quantization_config
        return self

    def set_io(self, quantization_config):
        """Set quantization_config for input and output nodes."""
        self.io_config = quantization_config
        return self

    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
        """An initial pass for transforming the graph to prepare it for annotation.
        Currently transforms scalar values to tensor attributes.
        """

        return ArmPassManager().transform_for_annotation_pipeline(graph_module=model)

    def annotate(self, model: GraphModule) -> GraphModule:
        """Performs the quantization annotation on the graph.
        Currently only does static quantization annotation.
        Args:
            model: The model to annotate statically.
        Returns:
            The annotated model.
        """
        model = self._annotate_for_static_quantization_config(model)
        propagate_annotation(model)
        return model

    def _annotate_all_static_patterns(
        self,
        model: GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> GraphModule:
        """Loops over all entries in STATIC_ANNOTATION_ORDER and runs the corresponding
        registered annotator.
        Args:
            model: The model to annotate statically.
            quantization_config: Specifies the QuantizationSpecs for the model's
                input activations, output activations, weights, and biases.
            filter_fn: An optional filter function that takes a node and returns whether the node should be annotated.
        Returns:
            The annotated model.
349 """ 350 # TODO: implement the support for None to be canceling out previous annotations 351 if quantization_config is None: 352 return model 353 354 for op in self.STATIC_ANNOTATION_ORDER: 355 OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) 356 return model 357 358 def _annotate_for_static_quantization_config( 359 self, model: GraphModule 360 ) -> GraphModule: 361 """Matches the correct QuantizationConfig with the correct module using a filter 362 when running _annotate_all_static_patterns. 363 """ 364 module_name_list = list(self.module_name_config.keys()) 365 for module_name, config in self.module_name_config.items(): 366 self._annotate_all_static_patterns( 367 model, config, _get_module_name_filter(module_name) 368 ) 369 370 tp_list = list(self.module_type_config.keys()) 371 for module_type, config in self.module_type_config.items(): 372 self._annotate_all_static_patterns( 373 model, config, _get_module_type_filter(module_type) 374 ) 375 376 self._annotate_all_static_patterns( 377 model, 378 self.global_config, 379 _get_not_module_type_or_name_filter(tp_list, module_name_list), 380 ) 381 382 if self.io_config: 383 self._annotate_io(model, self.io_config) 384 385 return model 386 387 def _annotate_io( 388 self, 389 model: GraphModule, 390 quantization_config: QuantizationConfig, 391 ): 392 for node in model.graph.nodes: 393 if arm_quantizer_utils.is_annotated(node): 394 continue 395 if node.op == "placeholder" and len(node.users) > 0: 396 _annotate_output_qspec( 397 node, 398 quantization_config.get_output_act_qspec(), 399 ) 400 mark_nodes_as_annotated([node]) 401 if node.op == "output": 402 parent = node.all_input_nodes[0] 403 _annotate_input_qspec_map( 404 node, parent, quantization_config.get_input_act_qspec() 405 ) 406 mark_nodes_as_annotated([node]) 407 408 def validate(self, model: GraphModule) -> None: 409 pass 410 411 @classmethod 412 def get_supported_operators(cls) -> List[OperatorConfig]: 413 return cls.supported_config_and_operators 414 415 @classmethod 416 def get_supported_quantization_configs(cls) -> List[QuantizationConfig]: 417 op_configs: Set[QuantizationConfig] = set({}) 418 for spec, _ in cls.supported_config_and_operators: 419 op_configs.add(spec) 420 return list(op_configs) 421 422 @classmethod 423 def get_supported_operator_for_quantization_config( 424 cls, quantization_config: Optional[QuantizationConfig] 425 ) -> List[OperatorPatternType]: 426 if quantization_config is None: 427 all_ops = [] 428 for _, ops in cls.supported_config_and_operators: 429 all_ops.extend(ops) 430 return all_ops 431 432 for config, ops in cls.supported_config_and_operators: 433 # note: this assumes each entry in cls.supported_spec_and_operators 434 # corresponds to one spec, e.g. we don't have 435 # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)] 436 # where the first and second entry have the same spec but did not 437 # merge the op list 438 if config == quantization_config: 439 return ops 440 return [] 441