1# mypy: allow-untyped-defs 2from abc import ABC 3from typing import Callable, Dict, List, Optional, Type 4 5import torch 6from torch.ao.quantization.backend_config import ( 7 BackendConfig, 8 DTypeConfig, 9 ObservationType, 10) 11from torch.ao.quantization.utils import NodePattern, Pattern, QuantizerCls 12from torch.fx.graph import Node 13 14from .utils import all_node_args_have_no_tensors 15 16 17__all__ = [ 18 "QuantizeHandler", 19 "BinaryOpQuantizeHandler", 20 "CatQuantizeHandler", 21 "ConvReluQuantizeHandler", 22 "LinearReLUQuantizeHandler", 23 "BatchNormQuantizeHandler", 24 "EmbeddingQuantizeHandler", 25 "RNNDynamicQuantizeHandler", 26 "DefaultNodeQuantizeHandler", 27 "FixedQParamsOpQuantizeHandler", 28 "CopyNodeQuantizeHandler", 29 "GeneralTensorShapeOpQuantizeHandler", 30 "CustomModuleQuantizeHandler", 31 "StandaloneModuleQuantizeHandler", 32] 33 34 35def _default_root_node_getter(node_pattern): 36 if node_pattern is None: 37 return node_pattern 38 while not isinstance(node_pattern, Node): 39 node_pattern = node_pattern[-1] 40 return node_pattern 41 42 43# Base Pattern Handler 44class QuantizeHandler(ABC): # noqa: B024 45 """Base handler class for the quantizer patterns""" 46 47 def __init__( 48 self, 49 node_pattern: NodePattern, 50 modules: Dict[str, torch.nn.Module], 51 root_node_getter: Optional[Callable] = None, 52 is_custom_module=False, 53 is_standalone_module=False, 54 ): 55 """Records pattern information in __init__, which will be used 56 in convert 57 """ 58 self.node_pattern = node_pattern 59 self.modules = modules 60 if root_node_getter is None: 61 root_node_getter = _default_root_node_getter 62 self.root_node = root_node_getter(node_pattern) 63 self.is_custom_module_ = is_custom_module 64 self.is_standalone_module_ = is_standalone_module 65 self.num_tensor_args = 0 66 # determine how many of the first two args are Tensors (versus scalars) 67 # this distinguishes things like "x + y" from "x + 2" or "2 + x" 68 if isinstance(self.root_node, Node): 69 cache_for_no_tensor_check: Dict[Node, bool] = {} 70 for arg_idx in range(len(self.root_node.args)): 71 arg = self.root_node.args[arg_idx] 72 if isinstance(arg, Node) and ( 73 not all_node_args_have_no_tensors( 74 arg, self.modules, cache_for_no_tensor_check 75 ) 76 ): 77 self.num_tensor_args += 1 78 79 def is_general_tensor_value_op(self) -> bool: 80 """ 81 Returns True if the operator works for both floating point and 82 quantized input, and does some computation based on the input Tensor, 83 or the ops that only re-arranges the Tensor values or query some metadata 84 about the Tensor 85 so we need to insert observer/fake_quant for the output of the 86 operator (same observer instance as input) 87 since the distribution of values is different for input and output 88 Tensors (for HistogramObserver) while they share the same quantization 89 parameters 90 Example operator: avgpool2d, reshape, transpose, maxpool2d 91 Example observed operator: 92 observer_0 - avgpool2d - observer_0 (same observer instance as input) 93 """ 94 return False 95 96 def is_custom_module(self): 97 return self.is_custom_module_ 98 99 def is_standalone_module(self): 100 return self.is_standalone_module_ 101 102 103def _get_quantize_handler_cls( 104 observation_type: ObservationType, 105 dtype_configs: List[DTypeConfig], 106 num_tensor_args_to_observation_type: Dict[int, ObservationType], 107) -> Type[QuantizeHandler]: 108 """ 109 Return a configurable QuantizeHandler that matches the given specifications from the backend. 110 """ 111 112 class ConfigurableQuantizeHandler(QuantizeHandler): 113 def __init__( 114 self, 115 node_pattern: NodePattern, 116 modules: Dict[str, torch.nn.Module], 117 root_node_getter: Optional[Callable] = None, 118 ): 119 super().__init__(node_pattern, modules, root_node_getter) 120 if num_tensor_args_to_observation_type: 121 assert self.num_tensor_args in num_tensor_args_to_observation_type, ( 122 f"Must provide observation_type config for tensor number {self.num_tensor_args}" 123 f" in num_tensor_args_to_observation_type for {node_pattern}" 124 ) 125 self.observation_type = num_tensor_args_to_observation_type[ 126 self.num_tensor_args 127 ] 128 else: 129 self.observation_type = observation_type 130 self.dtype_configs = dtype_configs 131 132 def is_general_tensor_value_op(self) -> bool: 133 return ( 134 self.observation_type 135 == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT 136 ) 137 138 return ConfigurableQuantizeHandler 139 140 141def _get_pattern_to_quantize_handlers( 142 backend_config: BackendConfig, 143) -> Dict[Pattern, QuantizerCls]: 144 """ 145 Note: Quantize handler is just a holder for some check methods like 146 (should_insert_observer_for_output), maybe this can be a enum as well, 147 we can refactor this after we convert the path for fbgemm/qnnpack fully to the 148 new path, this is not exposed to backend developers 149 """ 150 pattern_to_quantize_handlers = {} 151 for pattern, config in backend_config._pattern_complex_format_to_config.items(): 152 observation_type = config.observation_type 153 dtype_configs = config.dtype_configs 154 num_tensor_args_to_observation_type = ( 155 config._num_tensor_args_to_observation_type 156 ) 157 pattern_to_quantize_handlers[pattern] = _get_quantize_handler_cls( 158 observation_type, dtype_configs, num_tensor_args_to_observation_type 159 ) 160 return pattern_to_quantize_handlers 161 162 163# TODO: remove this class, this is still exposed in torch.ao.quantization 164# but we should be able to break bc 165class BinaryOpQuantizeHandler(QuantizeHandler): 166 pass 167 168 169class CatQuantizeHandler(QuantizeHandler): 170 pass 171 172 173# TODO: remove this class 174class ConvReluQuantizeHandler(QuantizeHandler): 175 pass 176 177 178# TODO: remove this class 179class LinearReLUQuantizeHandler(QuantizeHandler): 180 pass 181 182 183# TODO: remove this class 184class BatchNormQuantizeHandler(QuantizeHandler): 185 pass 186 187 188# TODO: remove this class 189class EmbeddingQuantizeHandler(QuantizeHandler): 190 pass 191 192 193# TODO: remove this class 194class RNNDynamicQuantizeHandler(QuantizeHandler): 195 pass 196 197 198# TODO: remove this class 199class DefaultNodeQuantizeHandler(QuantizeHandler): 200 """Common quantized op, first input and first output will be quantized""" 201 202 203# TODO: remove this class 204class FixedQParamsOpQuantizeHandler(QuantizeHandler): 205 pass 206 207 208# TODO: remove 209class CopyNodeQuantizeHandler(QuantizeHandler): 210 pass 211 212 213# TODO: remove 214class GeneralTensorShapeOpQuantizeHandler(QuantizeHandler): 215 pass 216 217 218# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated 219class CustomModuleQuantizeHandler(QuantizeHandler): 220 pass 221 222 223# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated 224class StandaloneModuleQuantizeHandler(QuantizeHandler): 225 pass 226