import torch
from torch._export.passes.constant_folding import constant_fold
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
from torch.ao.quantization.quantizer import (  # noqa: F401
    DerivedQuantizationSpec,
    FixedQParamsQuantizationSpec,
    QuantizationAnnotation,
    QuantizationSpec,
    QuantizationSpecBase,
    Quantizer,
    SharedQuantizationSpec,
)
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_manager import PassManager

from .pt2e.prepare import prepare
from .pt2e.qat_utils import _fold_conv_bn_qat, _fuse_conv_bn_qat
from .pt2e.representation import reference_representation_rewrite
from .pt2e.utils import _disallow_eval_train, _fuse_conv_bn_, _get_node_name_to_scope
from .quantize_fx import _convert_to_reference_decomposed_fx


__all__ = [
    "prepare_pt2e",
    "prepare_qat_pt2e",
    "convert_pt2e",
]


def prepare_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    """Prepare a model for post-training quantization

    Args:
      * `model` (torch.fx.GraphModule): a model captured by the `torch.export` API.
        In the short term we are using `torch._export.capture_pre_autograd_graph`;
        in the long term we'll migrate to the `torch.export` API.
      * `quantizer`: A backend-specific quantizer that conveys how the user wants the
        model to be quantized. A tutorial on how to write a quantizer can be found here:
        https://pytorch.org/tutorials/prototype/pt2e_quantizer.html

    Return:
      A GraphModule with observers (based on the quantizer annotation), ready for calibration

    Example::

        import torch
        from torch.ao.quantization.quantize_pt2e import prepare_pt2e
        from torch._export import capture_pre_autograd_graph
        from torch.ao.quantization.quantizer.xnnpack_quantizer import (
            XNNPACKQuantizer,
            get_symmetric_quantization_config,
        )

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        # initialize a floating point model
        float_model = M().eval()
        example_inputs = (torch.randn(1, 5),)

        # define calibration function
        def calibrate(model, data_loader):
            model.eval()
            with torch.no_grad():
                for image, target in data_loader:
                    model(image)

        # Step 1. program capture
        # NOTE: this API will be updated to the torch.export API in the future, but the
        # captured result should mostly stay the same
        m = capture_pre_autograd_graph(float_model, example_inputs)
        # we get a model with aten ops

        # Step 2. quantization
        # Backend developers will write their own Quantizer and expose methods to allow
        # users to express how they want the model to be quantized
        quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
        m = prepare_pt2e(m, quantizer)

        # run calibration
        # calibrate(m, sample_inference_data)
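
        # Step 3. lowering (illustrative next step, not performed by this API):
        # after calibration, convert the prepared model with `convert_pt2e`
        # quantized_model = convert_pt2e(m)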
    """
    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e")
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    model = quantizer.transform_for_annotation(model)
    quantizer.annotate(model)
    quantizer.validate(model)
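    # insert observers into the graph based on the quantizer's annotations;
    # this is what makes the returned module ready for calibration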
    model = prepare(model, node_name_to_scope, is_qat=False)
    model.meta.update(original_graph_meta)
    model = _disallow_eval_train(model)
    return model


def prepare_qat_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    """Prepare a model for quantization-aware training

    Args:
      * `model` (torch.fx.GraphModule): see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e`
      * `quantizer`: see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e`

    Return:
      A GraphModule with fake quant modules (based on the quantizer annotation), ready for
      quantization-aware training

    Example::

        import torch
        from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e
        from torch._export import capture_pre_autograd_graph
        from torch.ao.quantization.quantizer.xnnpack_quantizer import (
            XNNPACKQuantizer,
            get_symmetric_quantization_config,
        )

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        # initialize a floating point model
        float_model = M().eval()
        example_inputs = (torch.randn(1, 5),)

        # define the training loop for quantization-aware training
        def train_loop(model, train_data):
            model.train()
            for image, target in train_data:
                ...

        # Step 1. program capture
        # NOTE: this API will be updated to the torch.export API in the future, but the
        # captured result should mostly stay the same
        m = capture_pre_autograd_graph(float_model, example_inputs)
        # we get a model with aten ops

        # Step 2. quantization
        # Backend developers will write their own Quantizer and expose methods to allow
        # users to express how they want the model to be quantized
        quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
        m = prepare_qat_pt2e(m, quantizer)

        # run quantization aware training
        train_loop(m, train_data)
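
        # after training, the quantized model can be produced with `convert_pt2e`
        # (illustrative next step, not performed by this API)
        # quantized_model = convert_pt2e(m)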

    """
    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e")
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    model = quantizer.transform_for_annotation(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    # Perform fusion after annotation to avoid quantizing ops in the new
    # subgraph that don't need to be quantized
    # TODO: only fuse if conv and bn are both configured to be quantized
    _fuse_conv_bn_qat(model)
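    # insert fake quantize modules into the graph based on the quantizer's
    # annotations (is_qat=True), ready for quantization-aware training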
    model = prepare(model, node_name_to_scope, is_qat=True)
    model.meta.update(original_graph_meta)
    model = _disallow_eval_train(model)
    return model


_QUANT_OPS = [
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
]


def _quant_node_constraint(n: Node) -> bool:
    """If there are any pure ops between a get_attr node and a quantize op, they will be
    constant propagated, e.g. get_attr(weight) -> transpose -> quantize -> dequantize*
    (Note: the dequantize op is not going to be constant propagated)

    This filter is added because we don't want to constant fold the things that are not
    related to quantization
    """
    return n.op == "call_function" and n.target in _QUANT_OPS

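# Note (illustrative, mirroring the call in `convert_pt2e` below): this constraint
# is passed to `constant_fold`, e.g.
#
#     constant_fold(model, _quant_node_constraint)
#
# so that only subgraphs feeding the quantize ops in `_QUANT_OPS` are folded, while
# dequantize ops (and anything unrelated to quantization) are left in the graph.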

def convert_pt2e(
    model: GraphModule,
    use_reference_representation: bool = False,
    fold_quantize: bool = True,
) -> GraphModule:
    """Convert a calibrated/trained model to a quantized model

    Args:
      * `model` (torch.fx.GraphModule): calibrated/trained model
      * `use_reference_representation` (bool): boolean flag to indicate whether to produce the reference representation or not
      * `fold_quantize` (bool): boolean flag for whether to fold the quantize op or not

    Returns:
        quantized model, either in q/dq representation or reference representation

    Example::

        # prepared_model: the model produced by `prepare_pt2e`/`prepare_qat_pt2e` and calibration/training
        # `convert_pt2e` produces a quantized model that represents quantized computation with
        # quantize/dequantize ops and fp32 ops by default.
        # Please refer to
        # https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_static.html#convert-the-calibrated-model-to-a-quantized-model
        # for a detailed explanation of the output quantized model
        quantized_model = convert_pt2e(prepared_model)
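
        # Illustrative variants based on the arguments above (sketch, not
        # backend-specific guidance):
        # reference_model = convert_pt2e(prepared_model, use_reference_representation=True)
        # unfolded_model = convert_pt2e(prepared_model, fold_quantize=False)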

    """  # flake8: noqa
    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e")
    if not isinstance(use_reference_representation, bool):
        raise ValueError(
            "Unexpected argument type for `use_reference_representation`, "
            f"please make sure you intend to pass argument {use_reference_representation} to convert_pt2e"
        )
    original_graph_meta = model.meta
    model = _convert_to_reference_decomposed_fx(model)
    model = _fold_conv_bn_qat(model)

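    # DuplicateDQPass: duplicate dequantize nodes that are shared by multiple
    # users so each consumer gets its own dequantize node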
    pm = PassManager([DuplicateDQPass()])
    model = pm(model).graph_module

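    # PortNodeMetaForQDQ: port node metadata from the original nodes onto the
    # newly inserted quantize/dequantize nodes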
    pm = PassManager([PortNodeMetaForQDQ()])
    model = pm(model).graph_module

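    # constant-fold the weight quantize subgraphs (get_attr -> ... -> quantize,
    # see `_quant_node_constraint`) so that weights are stored in quantized form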
    if fold_quantize:
        constant_fold(model, _quant_node_constraint)

    if use_reference_representation:
        model = reference_representation_rewrite(model)

    model.meta.update(original_graph_meta)
    model = _disallow_eval_train(model)
    return model