# mypy: allow-untyped-defs

from typing import Callable, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .fake_quantize import *  # noqa: F403
from .fuse_modules import fuse_modules, fuse_modules_qat  # noqa: F403
from .fuser_method_mappings import *  # noqa: F403
from .observer import *  # noqa: F403
from .pt2e._numeric_debugger import (  # noqa: F401
    compare_results,
    CUSTOM_KEY,
    extract_results_from_loggers,
    generate_numeric_debug_handle,
    NUMERIC_DEBUG_HANDLE_KEY,
    prepare_for_propagation_comparison,
)
from .pt2e.export_utils import (
    _allow_exported_model_train_eval as allow_exported_model_train_eval,
    _move_exported_model_to_eval as move_exported_model_to_eval,
    _move_exported_model_to_train as move_exported_model_to_train,
)
from .qconfig import *  # noqa: F403
from .qconfig_mapping import *  # noqa: F403
from .quant_type import *  # noqa: F403
from .quantization_mappings import *  # noqa: F403 # type: ignore[no-redef]
from .quantize import *  # noqa: F403
from .quantize_jit import *  # noqa: F403
from .stubs import *  # noqa: F403


# Public alias: anything that can be attached as an observer — either a
# statistics-collecting ObserverBase or a FakeQuantizeBase.
ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase]

# ensure __module__ is set correctly for public APIs
ObserverOrFakeQuantize.__module__ = "torch.ao.quantization"
for _f in [
    compare_results,
    extract_results_from_loggers,
    generate_numeric_debug_handle,
    prepare_for_propagation_comparison,
]:
    _f.__module__ = "torch.ao.quantization"

__all__ = [
    "DeQuantStub",
    "FakeQuantize",
    "FakeQuantizeBase",
    "FixedQParamsFakeQuantize",
    "FixedQParamsObserver",
    "FusedMovingAvgObsFakeQuantize",
    "HistogramObserver",
    "MatchAllNode",
    "MinMaxObserver",
    "MovingAverageMinMaxObserver",
    "MovingAveragePerChannelMinMaxObserver",
    "NoopObserver",
    "ObserverBase",
    "ObserverOrFakeQuantize",
    "Pattern",
    "PerChannelMinMaxObserver",
    "PlaceholderObserver",
    "QConfig",
    "QConfigAny",
    "QConfigDynamic",
    "QConfigMapping",
    "QuantStub",
    "QuantType",
    "QuantWrapper",
    "RecordingObserver",
    "ReuseInputObserver",
    "UniformQuantizationObserverBase",
    "add_quant_dequant",
    "convert",
    "convert_dynamic_jit",
    "convert_jit",
    "default_affine_fixed_qparams_fake_quant",
    "default_affine_fixed_qparams_observer",
    "default_debug_observer",
    "default_dynamic_fake_quant",
    "default_dynamic_quant_observer",
    "default_embedding_fake_quant",
    "default_embedding_fake_quant_4bit",
    "default_eval_fn",
    "default_fake_quant",
    "default_fixed_qparams_range_0to1_fake_quant",
    "default_fixed_qparams_range_0to1_observer",
    "default_fixed_qparams_range_neg1to1_fake_quant",
    "default_fixed_qparams_range_neg1to1_observer",
    "default_float_qparams_observer",
    "default_float_qparams_observer_4bit",
    "default_fused_act_fake_quant",
    "default_fused_per_channel_wt_fake_quant",
    "default_fused_wt_fake_quant",
    "default_histogram_fake_quant",
    "default_histogram_observer",
    "default_observer",
    "default_per_channel_weight_fake_quant",
    "default_per_channel_weight_observer",
    "default_placeholder_observer",
    "default_reuse_input_observer",
    "default_symmetric_fixed_qparams_fake_quant",
    "default_symmetric_fixed_qparams_observer",
    "default_weight_fake_quant",
    "default_weight_observer",
    "disable_fake_quant",
    "disable_observer",
    "enable_fake_quant",
    "enable_observer",
    "fuse_conv_bn",
    "fuse_conv_bn_jit",
    "fuse_conv_bn_relu",
    "fuse_convtranspose_bn",
    "fuse_linear_bn",
    "fuse_modules",
    "fuse_modules_qat",
    "fused_per_channel_wt_fake_quant_range_neg_127_to_127",
    "fused_wt_fake_quant_range_neg_127_to_127",
    "get_combined_dict",
    "get_default_compare_output_module_list",
    "get_default_custom_config_dict",
    "get_default_dynamic_quant_module_mappings",
    "get_default_dynamic_sparse_quant_module_mappings",
    "get_default_float_to_quantized_operator_mappings",
    "get_default_qat_module_mappings",
    "get_default_qat_qconfig",
    "get_default_qat_qconfig_dict",
    "get_default_qat_qconfig_mapping",
    "get_default_qconfig",
    "get_default_qconfig_dict",
    "get_default_qconfig_mapping",
    "get_default_qconfig_propagation_list",
    "get_default_static_quant_module_mappings",
    "get_default_static_quant_reference_module_mappings",
    "get_default_static_sparse_quant_module_mappings",
    "get_dynamic_quant_module_class",
    "get_embedding_qat_module_mappings",
    "get_embedding_static_quant_module_mappings",
    "get_fuser_method",
    "get_fuser_method_new",
    "get_observer_state_dict",
    "get_quantized_operator",
    "get_static_quant_module_class",
    "load_observer_state_dict",
    "move_exported_model_to_eval",
    "move_exported_model_to_train",
    "allow_exported_model_train_eval",
    "no_observer_set",
    "per_channel_weight_observer_range_neg_127_to_127",
    "prepare",
    "prepare_dynamic_jit",
    "prepare_jit",
    "prepare_qat",
    "propagate_qconfig_",
    "qconfig_equals",
    "quantize",
    "quantize_dynamic",
    "quantize_dynamic_jit",
    "quantize_jit",
    "quantize_qat",
    "script_qconfig",
    "script_qconfig_dict",
    "swap_module",
    "weight_observer_range_neg_127_to_127",
    "generate_numeric_debug_handle",
    "CUSTOM_KEY",
    "NUMERIC_DEBUG_HANDLE_KEY",
    "prepare_for_propagation_comparison",
    "extract_results_from_loggers",
    "compare_results",
]


def default_eval_fn(model, calib_data):
    r"""Define the default evaluation function.

    Runs ``model`` once on every example in ``calib_data``, an iterable
    yielding ``(data, target)`` pairs (e.g. a ``torch.utils.data.DataLoader``
    or a list of ``(input, label)`` tuples).  Only the forward passes are
    performed; the targets and the model outputs are discarded, which is all
    that is needed for observer calibration.
    """
    # `_target` is deliberately unused: calibration only needs forward passes.
    for data, _target in calib_data:
        model(data)


class _DerivedObserverOrFakeQuantize(ObserverBase):
    r"""Observer whose quantization parameters are derived from other observers.

    ``forward`` is the identity — this observer collects no statistics of its
    own.  Instead, ``calculate_qparams`` delegates to ``derive_qparams_fn``,
    which maps the wrapped list of observers/fake-quantizes to a pair of
    tensors (typically scale and zero_point).
    """

    def __init__(
        self,
        dtype: torch.dtype,
        obs_or_fqs: List[ObserverOrFakeQuantize],
        derive_qparams_fn: Callable[
            [List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]
        ],
        quant_min: Optional[int] = None,
        quant_max: Optional[int] = None,
        qscheme: Optional[torch.qscheme] = None,
        ch_axis: Optional[int] = None,
    ):
        super().__init__(dtype)
        self.obs_or_fqs = obs_or_fqs
        self.derive_qparams_fn = derive_qparams_fn
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.qscheme = qscheme
        self.ch_axis = ch_axis

        # NOTE(review): imported locally, presumably to avoid a circular
        # import at package load time — confirm before hoisting to the top.
        from .utils import is_per_channel

        # A per-channel qscheme is meaningless without a channel axis.
        if is_per_channel(self.qscheme):
            assert (
                self.ch_axis is not None
            ), "Must provide a valid ch_axis if qscheme is per channel"

    def forward(self, x: Tensor) -> Tensor:
        # Pass-through: derived observers never record statistics themselves.
        return x

    def calculate_qparams(self):
        # Qparam computation is fully delegated to the user-supplied function.
        return self.derive_qparams_fn(self.obs_or_fqs)