# xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2
3from typing import Callable, List, Optional, Tuple, Union
4
5import torch
6from torch import Tensor
7
8from .fake_quantize import *  # noqa: F403
9from .fuse_modules import fuse_modules, fuse_modules_qat  # noqa: F403
10from .fuser_method_mappings import *  # noqa: F403
11from .observer import *  # noqa: F403
12from .pt2e._numeric_debugger import (  # noqa: F401
13    compare_results,
14    CUSTOM_KEY,
15    extract_results_from_loggers,
16    generate_numeric_debug_handle,
17    NUMERIC_DEBUG_HANDLE_KEY,
18    prepare_for_propagation_comparison,
19)
20from .pt2e.export_utils import (
21    _allow_exported_model_train_eval as allow_exported_model_train_eval,
22    _move_exported_model_to_eval as move_exported_model_to_eval,
23    _move_exported_model_to_train as move_exported_model_to_train,
24)
25from .qconfig import *  # noqa: F403
26from .qconfig_mapping import *  # noqa: F403
27from .quant_type import *  # noqa: F403
28from .quantization_mappings import *  # noqa: F403 # type: ignore[no-redef]
29from .quantize import *  # noqa: F403
30from .quantize_jit import *  # noqa: F403
31from .stubs import *  # noqa: F403
32
33
# ensure __module__ is set correctly for public APIs
# Public alias: anything usable to collect quantization statistics —
# either an observer or a fake-quantize module.
ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase]
# Re-home the alias and the numeric-debug helpers so they report the public
# "torch.ao.quantization" namespace instead of the private submodules that
# define them (presumably for docs/pickling — the code only sets the attribute).
ObserverOrFakeQuantize.__module__ = "torch.ao.quantization"
for _f in [
    compare_results,
    extract_results_from_loggers,
    generate_numeric_debug_handle,
    prepare_for_propagation_comparison,
]:
    _f.__module__ = "torch.ao.quantization"
44
# Public API surface of torch.ao.quantization.
# Kept in ASCII-sorted order (uppercase names first) so additions diff
# cleanly and duplicates are easy to spot; previously-appended entries
# (numeric-debug helpers, allow_exported_model_train_eval) are folded in.
__all__ = [
    "CUSTOM_KEY",
    "DeQuantStub",
    "FakeQuantize",
    "FakeQuantizeBase",
    "FixedQParamsFakeQuantize",
    "FixedQParamsObserver",
    "FusedMovingAvgObsFakeQuantize",
    "HistogramObserver",
    "MatchAllNode",
    "MinMaxObserver",
    "MovingAverageMinMaxObserver",
    "MovingAveragePerChannelMinMaxObserver",
    "NUMERIC_DEBUG_HANDLE_KEY",
    "NoopObserver",
    "ObserverBase",
    "ObserverOrFakeQuantize",
    "Pattern",
    "PerChannelMinMaxObserver",
    "PlaceholderObserver",
    "QConfig",
    "QConfigAny",
    "QConfigDynamic",
    "QConfigMapping",
    "QuantStub",
    "QuantType",
    "QuantWrapper",
    "RecordingObserver",
    "ReuseInputObserver",
    "UniformQuantizationObserverBase",
    "add_quant_dequant",
    "allow_exported_model_train_eval",
    "compare_results",
    "convert",
    "convert_dynamic_jit",
    "convert_jit",
    "default_affine_fixed_qparams_fake_quant",
    "default_affine_fixed_qparams_observer",
    "default_debug_observer",
    "default_dynamic_fake_quant",
    "default_dynamic_quant_observer",
    "default_embedding_fake_quant",
    "default_embedding_fake_quant_4bit",
    "default_eval_fn",
    "default_fake_quant",
    "default_fixed_qparams_range_0to1_fake_quant",
    "default_fixed_qparams_range_0to1_observer",
    "default_fixed_qparams_range_neg1to1_fake_quant",
    "default_fixed_qparams_range_neg1to1_observer",
    "default_float_qparams_observer",
    "default_float_qparams_observer_4bit",
    "default_fused_act_fake_quant",
    "default_fused_per_channel_wt_fake_quant",
    "default_fused_wt_fake_quant",
    "default_histogram_fake_quant",
    "default_histogram_observer",
    "default_observer",
    "default_per_channel_weight_fake_quant",
    "default_per_channel_weight_observer",
    "default_placeholder_observer",
    "default_reuse_input_observer",
    "default_symmetric_fixed_qparams_fake_quant",
    "default_symmetric_fixed_qparams_observer",
    "default_weight_fake_quant",
    "default_weight_observer",
    "disable_fake_quant",
    "disable_observer",
    "enable_fake_quant",
    "enable_observer",
    "extract_results_from_loggers",
    "fuse_conv_bn",
    "fuse_conv_bn_jit",
    "fuse_conv_bn_relu",
    "fuse_convtranspose_bn",
    "fuse_linear_bn",
    "fuse_modules",
    "fuse_modules_qat",
    "fused_per_channel_wt_fake_quant_range_neg_127_to_127",
    "fused_wt_fake_quant_range_neg_127_to_127",
    "generate_numeric_debug_handle",
    "get_combined_dict",
    "get_default_compare_output_module_list",
    "get_default_custom_config_dict",
    "get_default_dynamic_quant_module_mappings",
    "get_default_dynamic_sparse_quant_module_mappings",
    "get_default_float_to_quantized_operator_mappings",
    "get_default_qat_module_mappings",
    "get_default_qat_qconfig",
    "get_default_qat_qconfig_dict",
    "get_default_qat_qconfig_mapping",
    "get_default_qconfig",
    "get_default_qconfig_dict",
    "get_default_qconfig_mapping",
    "get_default_qconfig_propagation_list",
    "get_default_static_quant_module_mappings",
    "get_default_static_quant_reference_module_mappings",
    "get_default_static_sparse_quant_module_mappings",
    "get_dynamic_quant_module_class",
    "get_embedding_qat_module_mappings",
    "get_embedding_static_quant_module_mappings",
    "get_fuser_method",
    "get_fuser_method_new",
    "get_observer_state_dict",
    "get_quantized_operator",
    "get_static_quant_module_class",
    "load_observer_state_dict",
    "move_exported_model_to_eval",
    "move_exported_model_to_train",
    "no_observer_set",
    "per_channel_weight_observer_range_neg_127_to_127",
    "prepare",
    "prepare_dynamic_jit",
    "prepare_for_propagation_comparison",
    "prepare_jit",
    "prepare_qat",
    "propagate_qconfig_",
    "qconfig_equals",
    "quantize",
    "quantize_dynamic",
    "quantize_dynamic_jit",
    "quantize_jit",
    "quantize_qat",
    "script_qconfig",
    "script_qconfig_dict",
    "swap_module",
    "weight_observer_range_neg_127_to_127",
]
172
173
def default_eval_fn(model, calib_data):
    r"""Define the default evaluation function.

    The default evaluation function runs ``model`` once on every batch of
    ``calib_data`` (e.g. a ``torch.utils.data.Dataset`` or a list of
    ``(input, target)`` pairs), which lets attached observers record
    calibration statistics.

    Args:
        model: callable (typically an ``nn.Module``) invoked once per batch
        calib_data: iterable yielding ``(data, target)`` pairs

    Returns:
        None
    """
    # Targets are irrelevant for calibration; only the forward pass matters.
    for data, _ in calib_data:
        model(data)
182
183
class _DerivedObserverOrFakeQuantize(ObserverBase):
    r"""Observer whose quantization parameters are not computed from its own
    input statistics but derived from a group of other observers/fake-quantize
    modules through a user-supplied ``derive_qparams_fn``.
    """

    def __init__(
        self,
        dtype: torch.dtype,
        obs_or_fqs: List[ObserverOrFakeQuantize],
        derive_qparams_fn: Callable[
            [List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]
        ],
        quant_min: Optional[int] = None,
        quant_max: Optional[int] = None,
        qscheme: Optional[torch.qscheme] = None,
        ch_axis: Optional[int] = None,
    ):
        super().__init__(dtype)
        # Stash configuration; qparams are produced lazily in
        # calculate_qparams() by delegating to derive_qparams_fn.
        self.obs_or_fqs = obs_or_fqs
        self.derive_qparams_fn = derive_qparams_fn
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.qscheme = qscheme
        self.ch_axis = ch_axis

        # Deferred (function-local) import, mirroring the original code —
        # presumably to avoid an import cycle with .utils.
        from .utils import is_per_channel

        per_channel = is_per_channel(self.qscheme)
        assert (
            not per_channel or self.ch_axis is not None
        ), "Must provide a valid ch_axis if qscheme is per channel"

    def forward(self, x: Tensor) -> Tensor:
        # Pass-through: this observer records nothing itself.
        return x

    def calculate_qparams(self):
        # Delegate entirely to the user-provided derivation function.
        derived = self.derive_qparams_fn(self.obs_or_fqs)
        return derived
221