1.. _quantization-doc:
2
3Quantization
4============
5
6.. automodule:: torch.ao.quantization
7.. automodule:: torch.ao.quantization.fx
8
9.. warning ::
10     Quantization is in beta and subject to change.
11
12Introduction to Quantization
13----------------------------
14
15Quantization refers to techniques for performing computations and storing
16tensors at lower bitwidths than floating point precision. A quantized model
17executes some or all of the operations on tensors with reduced precision rather than
18full precision (floating point) values. This allows for a more compact model representation and
19the use of high performance vectorized operations on many hardware platforms.
Compared to typical FP32 models, PyTorch supports INT8 quantization, allowing for
a 4x reduction in model size and a 4x reduction in memory bandwidth
requirements. Hardware support for INT8 computations typically makes them 2 to 4
times faster than FP32 compute. Quantization is primarily a technique to
speed up inference, and only the forward pass is supported for quantized
operators.
26
27PyTorch supports multiple approaches to quantizing a deep learning model. In
28most cases the model is trained in FP32 and then the model is converted to
29INT8. In addition, PyTorch also supports quantization aware training, which
30models quantization errors in both the forward and backward passes using
31fake-quantization modules. Note that the entire computation is carried out in
32floating point. At the end of quantization aware training, PyTorch provides
33conversion functions to convert the trained model into lower precision.
34
At a lower level, PyTorch provides a way to represent quantized tensors and
perform operations with them. They can be used to directly construct models
that perform all or part of the computation in lower precision. Higher-level
APIs are provided that incorporate typical workflows of converting an FP32 model
to lower precision with minimal accuracy loss.
40
41Quantization API Summary
42-----------------------------
43
44PyTorch provides three different modes of quantization: Eager Mode Quantization, FX Graph Mode Quantization (maintenance) and PyTorch 2 Export Quantization.
45
Eager Mode Quantization is a beta feature. Users need to perform fusion and specify where quantization and dequantization happen manually; it also only supports modules, not functionals.
47
FX Graph Mode Quantization is an automated quantization workflow in PyTorch. It is currently a prototype feature and is in maintenance mode since we have PyTorch 2 Export Quantization. It improves upon Eager Mode Quantization by adding support for functionals and automating the quantization process, although users might need to refactor the model to make it compatible with FX Graph Mode Quantization (symbolically traceable with ``torch.fx``). Note that FX Graph Mode Quantization is not expected to work on arbitrary models, since a model might not be symbolically traceable; we will integrate it into domain libraries like torchvision, and users will be able to quantize models similar to those in supported domain libraries with FX Graph Mode Quantization. For arbitrary models we'll provide general guidelines, but to actually make it work, users might need to be familiar with ``torch.fx``, especially with how to make a model symbolically traceable.
49
PyTorch 2 Export Quantization is the new full graph mode quantization workflow, released as a prototype feature in PyTorch 2.1. With PyTorch 2, we are moving to a better solution for full program capture (``torch.export``), since it can capture a higher percentage of models (88.8% on 14K models) compared to ``torch.fx.symbolic_trace`` (72.7% on 14K models), the program capture solution used by FX Graph Mode Quantization. ``torch.export`` still has limitations around some Python constructs and requires user involvement to support dynamism in the exported model, but overall it is an improvement over the previous program capture solution. PyTorch 2 Export Quantization is built for models captured by ``torch.export``, with the flexibility and productivity of both modeling users and backend developers in mind. The main features are:

1. A programmable API for configuring how a model is quantized that can scale to many more use cases.
2. A simplified UX for modeling users and backend developers, since they only need to interact with a single object (Quantizer) to express how a model should be quantized and what the backend supports.
3. An optional reference quantized model representation that can represent quantized computation with integer operations, mapping more closely to the actual quantized computation that happens in hardware.
54
New users of quantization are encouraged to try out PyTorch 2 Export Quantization first; if it does not work well, they can try Eager Mode Quantization.
56
57The following table compares the differences between Eager Mode Quantization, FX Graph Mode Quantization and PyTorch 2 Export Quantization:
58
59+-----------------+-------------------+-------------------+-------------------------+
60|                 |Eager Mode         |FX Graph           |PyTorch 2 Export         |
61|                 |Quantization       |Mode               |Quantization             |
62|                 |                   |Quantization       |                         |
63+-----------------+-------------------+-------------------+-------------------------+
64|Release          |beta               |prototype          |prototype                |
65|Status           |                   |(maintenance)      |                         |
66+-----------------+-------------------+-------------------+-------------------------+
67|Operator         |Manual             |Automatic          |Automatic                |
68|Fusion           |                   |                   |                         |
69+-----------------+-------------------+-------------------+-------------------------+
70|Quant/DeQuant    |Manual             |Automatic          |Automatic                |
71|Placement        |                   |                   |                         |
72+-----------------+-------------------+-------------------+-------------------------+
73|Quantizing       |Supported          |Supported          |Supported                |
74|Modules          |                   |                   |                         |
75+-----------------+-------------------+-------------------+-------------------------+
76|Quantizing       |Manual             |Automatic          |Supported                |
77|Functionals/Torch|                   |                   |                         |
78|Ops              |                   |                   |                         |
79+-----------------+-------------------+-------------------+-------------------------+
80|Support for      |Limited Support    |Fully              |Fully Supported          |
81|Customization    |                   |Supported          |                         |
82+-----------------+-------------------+-------------------+-------------------------+
83|Quantization Mode|Post Training      |Post Training      |Defined by               |
84|Support          |Quantization:      |Quantization:      |Backend Specific         |
85|                 |Static, Dynamic,   |Static, Dynamic,   |Quantizer                |
86|                 |Weight Only        |Weight Only        |                         |
87|                 |                   |                   |                         |
88|                 |Quantization Aware |Quantization Aware |                         |
89|                 |Training:          |Training:          |                         |
90|                 |Static             |Static             |                         |
91+-----------------+-------------------+-------------------+-------------------------+
92|Input/Output     |``torch.nn.Module``|``torch.nn.Module``|``torch.fx.GraphModule`` |
93|Model Type       |                   |(May need some     |(captured by             |
|                 |                   |refactors to make  |``torch.export``)        |
95|                 |                   |the model          |                         |
96|                 |                   |compatible with FX |                         |
97|                 |                   |Graph Mode         |                         |
98|                 |                   |Quantization)      |                         |
99+-----------------+-------------------+-------------------+-------------------------+
100
101
102
103There are three types of quantization supported:
104
1051. dynamic quantization (weights quantized with activations read/stored in
106   floating point and quantized for compute)
1072. static quantization (weights quantized, activations quantized, calibration
108   required post training)
1093. static quantization aware training (weights quantized, activations quantized,
110   quantization numerics modeled during training)
111
112Please see our `Introduction to Quantization on PyTorch
113<https://pytorch.org/blog/introduction-to-quantization-on-pytorch/>`_ blog post
114for a more comprehensive overview of the tradeoffs between these quantization
115types.
116
117Operator coverage varies between dynamic and static quantization and is captured in the table below.
118
119+---------------------------+-------------------+--------------------+
120|                           |Static             | Dynamic            |
121|                           |Quantization       | Quantization       |
122+---------------------------+-------------------+--------------------+
123| | nn.Linear               | | Y               | | Y                |
124| | nn.Conv1d/2d/3d         | | Y               | | N                |
125+---------------------------+-------------------+--------------------+
126| | nn.LSTM                 | | Y (through      | | Y                |
127| |                         | | custom modules) | |                  |
128| | nn.GRU                  | | N               | | Y                |
129+---------------------------+-------------------+--------------------+
130| | nn.RNNCell              | | N               | | Y                |
131| | nn.GRUCell              | | N               | | Y                |
132| | nn.LSTMCell             | | N               | | Y                |
133+---------------------------+-------------------+--------------------+
134|nn.EmbeddingBag            | Y (activations    |                    |
135|                           | are in fp32)      | Y                  |
136+---------------------------+-------------------+--------------------+
137|nn.Embedding               | Y                 | Y                  |
138+---------------------------+-------------------+--------------------+
139| nn.MultiheadAttention     | Y (through        | Not supported      |
140|                           | custom modules)   |                    |
141+---------------------------+-------------------+--------------------+
142| Activations               | Broadly supported | Un-changed,        |
143|                           |                   | computations       |
144|                           |                   | stay in fp32       |
145+---------------------------+-------------------+--------------------+
146
147
148Eager Mode Quantization
149^^^^^^^^^^^^^^^^^^^^^^^
150For a general introduction to the quantization flow, including different types of quantization, please take a look at `General Quantization Flow`_.
151
152Post Training Dynamic Quantization
153~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
154
This is the simplest form of quantization to apply: the weights are
quantized ahead of time but the activations are dynamically quantized
during inference. It is used in situations where the model execution time
is dominated by loading weights from memory rather than computing the matrix
multiplications. This is true for LSTM and Transformer type models with
small batch sizes.
161
162Diagram::
163
164  # original model
165  # all tensors and computations are in floating point
166  previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
167                   /
168  linear_weight_fp32
169
170  # dynamically quantized model
171  # linear and LSTM weights are in int8
172  previous_layer_fp32 -- linear_int8_w_fp32_inp -- activation_fp32 -- next_layer_fp32
173                       /
174     linear_weight_int8
175
176PTDQ API Example::
177
178  import torch
179
180  # define a floating point model
181  class M(torch.nn.Module):
182      def __init__(self):
183          super().__init__()
184          self.fc = torch.nn.Linear(4, 4)
185
186      def forward(self, x):
187          x = self.fc(x)
188          return x
189
190  # create a model instance
191  model_fp32 = M()
192  # create a quantized model instance
193  model_int8 = torch.ao.quantization.quantize_dynamic(
194      model_fp32,  # the original model
195      {torch.nn.Linear},  # a set of layers to dynamically quantize
196      dtype=torch.qint8)  # the target dtype for quantized weights
197
198  # run the model
199  input_fp32 = torch.randn(4, 4, 4, 4)
200  res = model_int8(input_fp32)
201
202To learn more about dynamic quantization please see our `dynamic quantization tutorial
203<https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html>`_.
204
205Post Training Static Quantization
206~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
207
208Post Training Static Quantization (PTQ static) quantizes the weights and activations of the model.  It
209fuses activations into preceding layers where possible.  It requires
210calibration with a representative dataset to determine optimal quantization
211parameters for activations. Post Training Static Quantization is typically used when
212both memory bandwidth and compute savings are important with CNNs being a
213typical use case.
214
215We may need to modify the model before applying post training static quantization. Please see `Model Preparation for Eager Mode Static Quantization`_.
216
217Diagram::
218
219    # original model
220    # all tensors and computations are in floating point
221    previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
222                        /
223        linear_weight_fp32
224
225    # statically quantized model
226    # weights and activations are in int8
227    previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8
228                        /
229      linear_weight_int8
230
231PTSQ API Example::
232
233  import torch
234
235  # define a floating point model where some layers could be statically quantized
236  class M(torch.nn.Module):
237      def __init__(self):
238          super().__init__()
239          # QuantStub converts tensors from floating point to quantized
240          self.quant = torch.ao.quantization.QuantStub()
241          self.conv = torch.nn.Conv2d(1, 1, 1)
242          self.relu = torch.nn.ReLU()
243          # DeQuantStub converts tensors from quantized to floating point
244          self.dequant = torch.ao.quantization.DeQuantStub()
245
246      def forward(self, x):
247          # manually specify where tensors will be converted from floating
248          # point to quantized in the quantized model
249          x = self.quant(x)
250          x = self.conv(x)
251          x = self.relu(x)
252          # manually specify where tensors will be converted from quantized
253          # to floating point in the quantized model
254          x = self.dequant(x)
255          return x
256
257  # create a model instance
258  model_fp32 = M()
259
260  # model must be set to eval mode for static quantization logic to work
261  model_fp32.eval()
262
263  # attach a global qconfig, which contains information about what kind
264  # of observers to attach. Use 'x86' for server inference and 'qnnpack'
265  # for mobile inference. Other quantization configurations such as selecting
266  # symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
267  # can be specified here.
268  # Note: the old 'fbgemm' is still available but 'x86' is the recommended default
269  # for server inference.
270  # model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
271  model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')
272
273  # Fuse the activations to preceding layers, where applicable.
274  # This needs to be done manually depending on the model architecture.
275  # Common fusions include `conv + relu` and `conv + batchnorm + relu`
276  model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
277
278  # Prepare the model for static quantization. This inserts observers in
279  # the model that will observe activation tensors during calibration.
280  model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)
281
282  # calibrate the prepared model to determine quantization parameters for activations
283  # in a real world setting, the calibration would be done with a representative dataset
284  input_fp32 = torch.randn(4, 1, 4, 4)
285  model_fp32_prepared(input_fp32)
286
287  # Convert the observed model to a quantized model. This does several things:
288  # quantizes the weights, computes and stores the scale and bias value to be
289  # used with each activation tensor, and replaces key operators with quantized
290  # implementations.
291  model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
292
293  # run the model, relevant calculations will happen in int8
294  res = model_int8(input_fp32)
295
296To learn more about static quantization, please see the `static quantization tutorial
297<https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html>`_.
298
299Quantization Aware Training for Static Quantization
300~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
301
302Quantization Aware Training (QAT) models the effects of quantization during training
303allowing for higher accuracy compared to other quantization methods. We can do QAT for static, dynamic or weight only quantization.  During
304training, all calculations are done in floating point, with fake_quant modules
305modeling the effects of quantization by clamping and rounding to simulate the
306effects of INT8.  After model conversion, weights and
307activations are quantized, and activations are fused into the preceding layer
308where possible.  It is commonly used with CNNs and yields a higher accuracy
309compared to static quantization.
310
We may need to modify the model before applying quantization aware training. Please see `Model Preparation for Eager Mode Static Quantization`_.
312
313Diagram::
314
315  # original model
316  # all tensors and computations are in floating point
317  previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
318                        /
319      linear_weight_fp32
320
321  # model with fake_quants for modeling quantization numerics during training
322  previous_layer_fp32 -- fq -- linear_fp32 -- activation_fp32 -- fq -- next_layer_fp32
323                             /
324     linear_weight_fp32 -- fq
325
326  # quantized model
327  # weights and activations are in int8
328  previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8
329                       /
330     linear_weight_int8
331
332QAT API Example::
333
334  import torch
335
336  # define a floating point model where some layers could benefit from QAT
337  class M(torch.nn.Module):
338      def __init__(self):
339          super().__init__()
340          # QuantStub converts tensors from floating point to quantized
341          self.quant = torch.ao.quantization.QuantStub()
342          self.conv = torch.nn.Conv2d(1, 1, 1)
343          self.bn = torch.nn.BatchNorm2d(1)
344          self.relu = torch.nn.ReLU()
345          # DeQuantStub converts tensors from quantized to floating point
346          self.dequant = torch.ao.quantization.DeQuantStub()
347
348      def forward(self, x):
349          x = self.quant(x)
350          x = self.conv(x)
351          x = self.bn(x)
352          x = self.relu(x)
353          x = self.dequant(x)
354          return x
355
356  # create a model instance
357  model_fp32 = M()
358
359  # model must be set to eval for fusion to work
360  model_fp32.eval()
361
362  # attach a global qconfig, which contains information about what kind
363  # of observers to attach. Use 'x86' for server inference and 'qnnpack'
364  # for mobile inference. Other quantization configurations such as selecting
365  # symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
366  # can be specified here.
367  # Note: the old 'fbgemm' is still available but 'x86' is the recommended default
368  # for server inference.
  # model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
370  model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')
371
372  # fuse the activations to preceding layers, where applicable
373  # this needs to be done manually depending on the model architecture
374  model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32,
375      [['conv', 'bn', 'relu']])
376
  # Prepare the model for QAT. This inserts observers and fake_quants in
  # the model that will observe weight and activation tensors during calibration.
  # The model needs to be set to train mode for QAT logic to work.
380  model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused.train())
381
382  # run the training loop (not shown)
383  training_loop(model_fp32_prepared)
384
385  # Convert the observed model to a quantized model. This does several things:
386  # quantizes the weights, computes and stores the scale and bias value to be
387  # used with each activation tensor, fuses modules where appropriate,
388  # and replaces key operators with quantized implementations.
389  model_fp32_prepared.eval()
390  model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
391
  # run the model, relevant calculations will happen in int8
  input_fp32 = torch.randn(4, 1, 4, 4)
  res = model_int8(input_fp32)
394
395To learn more about quantization aware training, please see the `QAT
396tutorial
397<https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html>`_.
398
399Model Preparation for Eager Mode Static Quantization
400~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
401
It is currently necessary to make some modifications to the model definition
prior to Eager mode quantization. This is because quantization currently works on a
module-by-module basis. Specifically, for all quantization techniques, the user needs to:
405
4061. Convert any operations that require output requantization (and thus have
407   additional parameters) from functionals to module form (for example,
408   using ``torch.nn.ReLU`` instead of ``torch.nn.functional.relu``).
4092. Specify which parts of the model need to be quantized either by assigning
410   ``.qconfig`` attributes on submodules or by specifying ``qconfig_mapping``.
411   For example, setting ``model.conv1.qconfig = None`` means that the
   ``model.conv1`` layer will not be quantized, and setting
413   ``model.linear1.qconfig = custom_qconfig`` means that the quantization
414   settings for ``model.linear1`` will be using ``custom_qconfig`` instead
415   of the global qconfig.
416
417For static quantization techniques which quantize activations, the user needs
418to do the following in addition:
419
4201. Specify where activations are quantized and de-quantized. This is done using
421   :class:`~torch.ao.quantization.QuantStub` and
422   :class:`~torch.ao.quantization.DeQuantStub` modules.
4232. Use :class:`~torch.ao.nn.quantized.FloatFunctional` to wrap tensor operations
424   that require special handling for quantization into modules. Examples
425   are operations like ``add`` and ``cat`` which require special handling to
426   determine output quantization parameters.
4273. Fuse modules: combine operations/modules into a single module to obtain
428   higher accuracy and performance. This is done using the
429   :func:`~torch.ao.quantization.fuse_modules.fuse_modules` API, which takes in lists of modules
430   to be fused. We currently support the following fusions:
431   [Conv, Relu], [Conv, BatchNorm], [Conv, BatchNorm, Relu], [Linear, Relu]
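
Putting these requirements together, here is a minimal, hypothetical sketch of a model prepared for Eager mode static quantization (module names, shapes and the fusion list are illustrative only)::

  import torch

  class PreparedModel(torch.nn.Module):
      def __init__(self):
          super().__init__()
          # 1. mark where activations are converted to and from quantized values
          self.quant = torch.ao.quantization.QuantStub()
          self.dequant = torch.ao.quantization.DeQuantStub()
          # use module forms (torch.nn.ReLU) instead of torch.nn.functional.relu
          self.conv = torch.nn.Conv2d(3, 3, 1)
          self.relu = torch.nn.ReLU()
          # 2. wrap ops like add so their output quantization parameters can be observed
          self.skip_add = torch.ao.nn.quantized.FloatFunctional()

      def forward(self, x):
          x = self.quant(x)
          y = self.relu(self.conv(x))
          y = self.skip_add.add(y, x)
          return self.dequant(y)

  m = PreparedModel().eval()
  # 3. fuse modules before calling prepare/convert
  m_fused = torch.ao.quantization.fuse_modules(m, [["conv", "relu"]])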
432
433(Prototype - maintenance mode) FX Graph Mode Quantization
434^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
435
436There are multiple quantization types in post training quantization (weight only, dynamic and static) and the configuration is done through `qconfig_mapping` (an argument of the `prepare_fx` function).
437
438FXPTQ API Example::
439
440  import torch
441  from torch.ao.quantization import (
442    get_default_qconfig_mapping,
443    get_default_qat_qconfig_mapping,
444    QConfigMapping,
445  )
446  import torch.ao.quantization.quantize_fx as quantize_fx
447  import copy
448
449  model_fp = UserModel()
450
451  #
452  # post training dynamic/weight_only quantization
453  #
454
  # We need to deepcopy here if we want to keep model_fp unchanged after
  # quantization, since the quantization APIs change the input model
456  model_to_quantize = copy.deepcopy(model_fp)
457  model_to_quantize.eval()
458  qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_dynamic_qconfig)
  # a tuple of one or more example inputs is needed to trace the model
  example_inputs = (input_fp32,)
461  # prepare
462  model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
463  # no calibration needed when we only have dynamic/weight_only quantization
464  # quantize
465  model_quantized = quantize_fx.convert_fx(model_prepared)
466
467  #
468  # post training static quantization
469  #
470
471  model_to_quantize = copy.deepcopy(model_fp)
472  qconfig_mapping = get_default_qconfig_mapping("qnnpack")
473  model_to_quantize.eval()
474  # prepare
475  model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
476  # calibrate (not shown)
477  # quantize
478  model_quantized = quantize_fx.convert_fx(model_prepared)
479
480  #
481  # quantization aware training for static quantization
482  #
483
484  model_to_quantize = copy.deepcopy(model_fp)
485  qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
486  model_to_quantize.train()
487  # prepare
488  model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)
489  # training loop (not shown)
490  # quantize
491  model_quantized = quantize_fx.convert_fx(model_prepared)
492
493  #
494  # fusion
495  #
496  model_to_quantize = copy.deepcopy(model_fp)
497  model_fused = quantize_fx.fuse_fx(model_to_quantize)
498
499Please follow the tutorials below to learn more about FX Graph Mode Quantization:
500
501- `User Guide on Using FX Graph Mode Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_quant_guide.html>`_
502- `FX Graph Mode Post Training Static Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html>`_
503- `FX Graph Mode Post Training Dynamic Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_dynamic.html>`_
504
505(Prototype) PyTorch 2 Export Quantization
506^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

API Example::
508
  import torch
  from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
  from torch._export import capture_pre_autograd_graph
  from torch.ao.quantization.quantizer.xnnpack_quantizer import (
      XNNPACKQuantizer,
      get_symmetric_quantization_config,
  )
516
517  class M(torch.nn.Module):
518      def __init__(self):
519          super().__init__()
520          self.linear = torch.nn.Linear(5, 10)
521
      def forward(self, x):
          return self.linear(x)
524
525  # initialize a floating point model
526  float_model = M().eval()
527
528  # define calibration function
529  def calibrate(model, data_loader):
530      model.eval()
531      with torch.no_grad():
532          for image, target in data_loader:
533              model(image)
534
  # Step 1. program capture
  # NOTE: this API will be updated to the torch.export API in the future, but the captured
  # result should mostly stay the same
  example_inputs = (torch.randn(1, 5),)  # example input for the model above
  m = capture_pre_autograd_graph(float_model, *example_inputs)
539  # we get a model with aten ops
540
541  # Step 2. quantization
  # a backend developer will write their own Quantizer and expose methods to allow
  # users to express how they want the model to be quantized
545  quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
546  # or prepare_qat_pt2e for Quantization Aware Training
547  m = prepare_pt2e(m, quantizer)
548
549  # run calibration
550  # calibrate(m, sample_inference_data)
551  m = convert_pt2e(m)
552
553  # Step 3. lowering
554  # lower to target backend
555
556
557Please follow these tutorials to get started on PyTorch 2 Export Quantization:
558
559Modeling Users:
560
561- `PyTorch 2 Export Post Training Quantization <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html>`_
562- `PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_x86_inductor.html>`_
563- `PyTorch 2 Export Quantization Aware Training <https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html>`_
564
565Backend Developers (please check out all Modeling Users docs as well):
566
567- `How to Write a Quantizer for PyTorch 2 Export Quantization <https://pytorch.org/tutorials/prototype/pt2e_quantizer.html>`_
568
569
570Quantization Stack
571------------------------
Quantization is the process of converting a floating point model to a quantized model. So at a high level, the quantization stack can be split into two parts: 1) the building blocks or abstractions for a quantized model, and 2) the building blocks or abstractions for the quantization flow that converts a floating point model to a quantized model.
573
574Quantized Model
575^^^^^^^^^^^^^^^^^^^^^^^
576Quantized Tensor
577~~~~~~~~~~~~~~~~~
578In order to do quantization in PyTorch, we need to be able to represent
579quantized data in Tensors. A Quantized Tensor allows for storing
580quantized data (represented as int8/uint8/int32) along with quantization
581parameters like scale and zero\_point. Quantized Tensors allow for many
582useful operations making quantized arithmetic easy, in addition to
583allowing for serialization of data in a quantized format.
584
PyTorch supports both per tensor and per channel symmetric and asymmetric quantization. Per tensor means that all the values within the tensor are quantized the same way with the same quantization parameters. Per channel means that for each dimension, typically the channel dimension of a tensor, the values in the tensor are quantized with different quantization parameters. This allows for less error in converting tensors to quantized values, since outlier values only impact the channel they occur in, instead of the entire Tensor.
586
587The mapping is performed by converting the floating point tensors using
588
589.. image:: math-quantizer-equation.png
590   :width: 40%
591
Note that we ensure that zero in floating point is represented with no error
593after quantization, thereby ensuring that operations like padding do not cause
594additional quantization error.
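
In other words (restating the equation shown above), for a floating point value :math:`x` with ``scale`` :math:`s` and ``zero_point`` :math:`z`:

.. math::

   x_q = \mathrm{clamp}\left(\mathrm{round}(x / s) + z,\; q_{\mathrm{min}},\; q_{\mathrm{max}}\right)

   x \approx (x_q - z) \times s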
595
Here are a few key attributes of a quantized Tensor:

* QScheme (torch.qscheme): an enum that specifies the way we quantize the Tensor
599
600  * torch.per_tensor_affine
601  * torch.per_tensor_symmetric
602  * torch.per_channel_affine
603  * torch.per_channel_symmetric
604
605* dtype (torch.dtype): data type of the quantized Tensor
606
607  * torch.quint8
608  * torch.qint8
609  * torch.qint32
610  * torch.float16
611
612* quantization parameters (varies based on QScheme): parameters for the chosen way of quantization
613
614  * torch.per_tensor_affine would have quantization parameters of
615
616    * scale (float)
617    * zero_point (int)
618  * torch.per_channel_affine would have quantization parameters of
619
620    * per_channel_scales (list of float)
621    * per_channel_zero_points (list of int)
622    * axis (int)
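
For example, a quantized Tensor can be created from a floating point Tensor and these attributes can be inspected directly (the scale and zero point values here are illustrative)::

  import torch

  x = torch.randn(2, 3)
  # per tensor affine quantization to torch.quint8
  xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)
  print(xq.qscheme())       # torch.per_tensor_affine
  print(xq.q_scale())       # 0.1
  print(xq.q_zero_point())  # 128
  print(xq.int_repr())      # the underlying uint8 data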
623
624Quantize and Dequantize
625~~~~~~~~~~~~~~~~~~~~~~~
626The input and output of a model are floating point Tensors, but activations in the quantized model are quantized, so we need operators to convert between floating point and quantized Tensors.
627
628* Quantize (float -> quantized)
629
630  * torch.quantize_per_tensor(x, scale, zero_point, dtype)
631  * torch.quantize_per_channel(x, scales, zero_points, axis, dtype)
632  * torch.quantize_per_tensor_dynamic(x, dtype, reduce_range)
633  * to(torch.float16)
634
635* Dequantize (quantized -> float)
636
637  * quantized_tensor.dequantize() - calling dequantize on a torch.float16 Tensor will convert the Tensor back to torch.float
638  * torch.dequantize(x)
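
For example, quantizing per tensor and per channel, then dequantizing back to floating point (the scales and zero points here are illustrative; in real flows they come from observers)::

  import torch

  x = torch.randn(2, 3)

  # per tensor: one scale/zero_point for the whole Tensor
  xq = torch.quantize_per_tensor(x, scale=0.05, zero_point=0, dtype=torch.qint8)

  # per channel: one scale/zero_point per slice along `axis` (here, per row)
  scales = torch.tensor([0.05, 0.1])
  zero_points = torch.tensor([0, 0])
  xq_pc = torch.quantize_per_channel(x, scales, zero_points, axis=0, dtype=torch.qint8)

  # dequantize recovers an approximation of the original float values
  x_approx = xq.dequantize()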
639
640Quantized Operators/Modules
641~~~~~~~~~~~~~~~~~~~~~~~~~~~
* Quantized Operators are operators that take quantized Tensors as inputs and output quantized Tensors.
* Quantized Modules are PyTorch Modules that perform quantized operations. They are typically defined for weighted operations like linear and conv.
644
645Quantized Engine
646~~~~~~~~~~~~~~~~~~~~
When a quantized model is executed, the qengine (torch.backends.quantized.engine) specifies which backend is to be used for execution. It is important to ensure that the qengine is compatible with the quantized model in terms of the value ranges of quantized activations and weights.
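
For example, the available engines in the current build can be checked and an engine selected as follows (the contents of the list depend on how PyTorch was built)::

  import torch

  print(torch.backends.quantized.supported_engines)  # e.g. ['none', 'onednn', 'x86', 'fbgemm']
  # select the engine that matches how the model was quantized
  torch.backends.quantized.engine = 'x86'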
648
649Quantization Flow
650^^^^^^^^^^^^^^^^^^^^^^^
651
652Observer and FakeQuantize
653~~~~~~~~~~~~~~~~~~~~~~~~~~
* Observers are PyTorch Modules used to:

  * collect tensor statistics like the min value and max value of the Tensor passing through the observer
  * calculate quantization parameters based on the collected tensor statistics
* FakeQuantize modules are PyTorch Modules used to:

  * simulate quantization (performing quantize/dequantize) for a Tensor in the network
  * calculate quantization parameters based on the statistics collected by an observer, or learn the quantization parameters
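
As a small sketch, an Observer and a FakeQuantize module can also be used directly (normally the prepare step inserts them for you; the observer choice here is illustrative)::

  import torch
  from torch.ao.quantization.observer import MinMaxObserver
  from torch.ao.quantization.fake_quantize import FakeQuantize

  # Observer: collect statistics, then derive quantization parameters
  obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
  for _ in range(3):
      obs(torch.randn(2, 8))
  scale, zero_point = obs.calculate_qparams()

  # FakeQuantize: simulate quantize/dequantize, the output is still a float Tensor
  fq = FakeQuantize(observer=MinMaxObserver, quant_min=0, quant_max=255)
  y = fq(torch.randn(2, 8))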
662
663QConfig
664~~~~~~~~~~~
* QConfig is a namedtuple of Observer or FakeQuantize Module classes that are configurable with qscheme, dtype etc. It is used to configure how an operator should be observed
666
667  * Quantization configuration for an operator/module
668
669    * different types of Observer/FakeQuantize
670    * dtype
671    * qscheme
672    * quant_min/quant_max: can be used to simulate lower precision Tensors
673  * Currently supports configuration for activation and weight
674  * We insert input/weight/output observer based on the qconfig that is configured for a given operator or module
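
For example, a custom QConfig could be constructed like this (the observer classes and arguments are illustrative)::

  import torch
  from torch.ao.quantization import QConfig
  from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver

  my_qconfig = QConfig(
      # observer for activations: per tensor, affine, quint8
      activation=MinMaxObserver.with_args(dtype=torch.quint8, qscheme=torch.per_tensor_affine),
      # observer for weights: per channel, symmetric, qint8
      weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
  )
  # e.g. in Eager mode: model.linear1.qconfig = my_qconfig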
675
676General Quantization Flow
677~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In general, the flow is the following:
679
680* prepare
681
682  * insert Observer/FakeQuantize modules based on user specified qconfig
683
684* calibrate/train (depending on post training quantization or quantization aware training)
685
686  * allow Observers to collect statistics or FakeQuantize modules to learn the quantization parameters
687
688* convert
689
690  * convert a calibrated/trained model to a quantized model
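
A compact sketch of this flow, using the FX graph mode APIs as an example (``calibration_batches`` and the backend string are placeholders)::

  import copy
  import torch
  from torch.ao.quantization import get_default_qconfig_mapping
  from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

  def quantize_ptq_static(float_model, example_inputs, calibration_batches):
      model = copy.deepcopy(float_model).eval()
      qconfig_mapping = get_default_qconfig_mapping("x86")
      # prepare: insert Observer/FakeQuantize modules based on the qconfig mapping
      prepared = prepare_fx(model, qconfig_mapping, example_inputs)
      # calibrate: let the observers collect statistics
      with torch.no_grad():
          for batch in calibration_batches:
              prepared(batch)
      # convert: produce the quantized model
      return convert_fx(prepared)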
691
There are different modes of quantization; they can be classified in two ways:
693
694In terms of where we apply the quantization flow, we have:
695
6961. Post Training Quantization (apply quantization after training, quantization parameters are calculated based on sample calibration data)
6972. Quantization Aware Training (simulate quantization during training so that the quantization parameters can be learned together with the model using training data)
698
699And in terms of how we quantize the operators, we can have:
700
701- Weight Only Quantization (only weight is statically quantized)
702- Dynamic Quantization (weight is statically quantized, activation is dynamically quantized)
703- Static Quantization (both weight and activations are statically quantized)
704
705We can mix different ways of quantizing operators in the same quantization flow. For example, we can have post training quantization that has both statically and dynamically quantized operators.
706
707Quantization Support Matrix
708--------------------------------------
709Quantization Mode Support
710^^^^^^^^^^^^^^^^^^^^^^^^^^^
711+-----------------------------+------------------------------------------------------+----------------+----------------+------------+-----------------+
712|                             |Quantization                                          |Dataset         | Works Best For | Accuracy   |      Notes      |
713|                             |Mode                                                  |Requirement     |                |            |                 |
714+-----------------------------+---------------------------------+--------------------+----------------+----------------+------------+-----------------+
715|Post Training Quantization   |Dynamic/Weight Only Quantization |activation          |None            |LSTM, MLP,      |good        |Easy to use,     |
716|                             |                                 |dynamically         |                |Embedding,      |            |close to static  |
717|                             |                                 |quantized (fp16,    |                |Transformer     |            |quantization when|
718|                             |                                 |int8) or not        |                |                |            |performance is   |
719|                             |                                 |quantized, weight   |                |                |            |compute or memory|
720|                             |                                 |statically quantized|                |                |            |bound due to     |
|                             |                                 |(fp16, int8, int4)  |                |                |            |weights          |
722|                             +---------------------------------+--------------------+----------------+----------------+------------+-----------------+
723|                             |Static Quantization              |activation and      |calibration     |CNN             |good        |Provides best    |
724|                             |                                 |weights statically  |dataset         |                |            |perf, may have   |
725|                             |                                 |quantized (int8)    |                |                |            |big impact on    |
726|                             |                                 |                    |                |                |            |accuracy, good   |
|                             |                                 |                    |                |                |            |for hardware     |
728|                             |                                 |                    |                |                |            |that only support|
729|                             |                                 |                    |                |                |            |int8 computation |
730+-----------------------------+---------------------------------+--------------------+----------------+----------------+------------+-----------------+
731|                             |Dynamic Quantization             |activation and      |fine-tuning     |MLP, Embedding  |best        |Limited support  |
732|                             |                                 |weight are fake     |dataset         |                |            |for now          |
733|                             |                                 |quantized           |                |                |            |                 |
734|                             +---------------------------------+--------------------+----------------+----------------+------------+-----------------+
735|                             |Static Quantization              |activation and      |fine-tuning     |CNN, MLP,       |best        |Typically used   |
736|                             |                                 |weight are fake     |dataset         |Embedding       |            |when static      |
737|                             |                                 |quantized           |                |                |            |quantization     |
738|                             |                                 |                    |                |                |            |leads to bad     |
739|                             |                                 |                    |                |                |            |accuracy, and    |
740|                             |                                 |                    |                |                |            |used to close the|
741|                             |                                 |                    |                |                |            |accuracy gap     |
742|Quantization Aware Training  |                                 |                    |                |                |            |                 |
743+-----------------------------+---------------------------------+--------------------+----------------+----------------+------------+-----------------+
744
Please see our `Introduction to Quantization on PyTorch
746<https://pytorch.org/blog/introduction-to-quantization-on-pytorch/>`_ blog post
747for a more comprehensive overview of the tradeoffs between these quantization
748types.
749
750Quantization Flow Support
751^^^^^^^^^^^^^^^^^^^^^^^^^^^
752PyTorch provides two modes of quantization: Eager Mode Quantization and FX Graph Mode Quantization.
753
Eager Mode Quantization is a beta feature. Users need to perform fusion and specify where quantization and dequantization happen manually; it also only supports modules, not functionals.

FX Graph Mode Quantization is an automated quantization framework in PyTorch, and currently it's a prototype feature. It improves upon Eager Mode Quantization by adding support for functionals and automating the quantization process, although users might need to refactor the model to make it compatible with FX Graph Mode Quantization (symbolically traceable with ``torch.fx``). Note that FX Graph Mode Quantization is not expected to work on arbitrary models, since a model might not be symbolically traceable; we will integrate it into domain libraries like torchvision, and users will be able to quantize models similar to those in supported domain libraries with FX Graph Mode Quantization. For arbitrary models we'll provide general guidelines, but to actually make it work, users might need to be familiar with ``torch.fx``, especially with how to make a model symbolically traceable.

New users of quantization are encouraged to try out FX Graph Mode Quantization first; if it does not work, users may follow the guidelines in `using FX Graph Mode Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_quant_guide.html>`_ or fall back to eager mode quantization.
759
760The following table compares the differences between Eager Mode Quantization and FX Graph Mode Quantization:
761
762+-----------------+-------------------+-------------------+
763|                 |Eager Mode         |FX Graph           |
764|                 |Quantization       |Mode               |
765|                 |                   |Quantization       |
766+-----------------+-------------------+-------------------+
767|Release          |beta               |prototype          |
768|Status           |                   |                   |
769+-----------------+-------------------+-------------------+
770|Operator         |Manual             |Automatic          |
771|Fusion           |                   |                   |
772+-----------------+-------------------+-------------------+
773|Quant/DeQuant    |Manual             |Automatic          |
774|Placement        |                   |                   |
775+-----------------+-------------------+-------------------+
776|Quantizing       |Supported          |Supported          |
777|Modules          |                   |                   |
778+-----------------+-------------------+-------------------+
779|Quantizing       |Manual             |Automatic          |
780|Functionals/Torch|                   |                   |
781|Ops              |                   |                   |
782+-----------------+-------------------+-------------------+
783|Support for      |Limited Support    |Fully              |
784|Customization    |                   |Supported          |
785+-----------------+-------------------+-------------------+
786|Quantization Mode|Post Training      |Post Training      |
787|Support          |Quantization:      |Quantization:      |
788|                 |Static, Dynamic,   |Static, Dynamic,   |
789|                 |Weight Only        |Weight Only        |
790|                 |                   |                   |
791|                 |Quantization Aware |Quantization Aware |
792|                 |Training:          |Training:          |
793|                 |Static             |Static             |
794+-----------------+-------------------+-------------------+
795|Input/Output     |``torch.nn.Module``|``torch.nn.Module``|
796|Model Type       |                   |(May need some     |
797|                 |                   |refactors to make  |
798|                 |                   |the model          |
799|                 |                   |compatible with FX |
800|                 |                   |Graph Mode         |
801|                 |                   |Quantization)      |
802+-----------------+-------------------+-------------------+
803
804Backend/Hardware Support
805^^^^^^^^^^^^^^^^^^^^^^^^^^^
806+-----------------+---------------+------------+------------+------------+
807|Hardware         |Kernel Library |Eager Mode  |FX Graph    |Quantization|
808|                 |               |Quantization|Mode        |Mode Support|
809|                 |               |            |Quantization|            |
810+-----------------+---------------+------------+------------+------------+
811|server CPU       |fbgemm/onednn  |Supported                |All         |
812|                 |               |                         |Supported   |
813+-----------------+---------------+                         |            +
814|mobile CPU       |qnnpack/xnnpack|                         |            |
815|                 |               |                         |            |
816+-----------------+---------------+------------+------------+------------+
|server GPU       |TensorRT (early|Not         |Supported   |Static      |
|                 |prototype)     |supported   |            |Quantization|
|                 |               |(requires a |            |            |
|                 |               |graph)      |            |            |
821+-----------------+---------------+------------+------------+------------+
822
823Today, PyTorch supports the following backends for running quantized operators efficiently:
824
825* x86 CPUs with AVX2 support or higher (without AVX2 some operations have inefficient implementations), via `x86` optimized by `fbgemm <https://github.com/pytorch/FBGEMM>`_ and `onednn <https://github.com/oneapi-src/oneDNN>`_ (see the details at `RFC <https://github.com/pytorch/pytorch/issues/83888>`_)
826* ARM CPUs (typically found in mobile/embedded devices), via `qnnpack <https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native/quantized/cpu/qnnpack>`_
827* (early prototype) support for NVidia GPU via `TensorRT <https://developer.nvidia.com/tensorrt>`_ through `fx2trt` (to be open sourced)
828
829
830Note for native CPU backends
831~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
We expose both `x86` and `qnnpack` with the same native PyTorch quantized operators, so we need an additional flag to distinguish between them. The corresponding implementation of `x86` or `qnnpack` is chosen automatically based on the PyTorch build mode, though users have the option to override this by setting `torch.backends.quantized.engine` to `x86` or `qnnpack`.
833
834When preparing a quantized model, it is necessary to ensure that qconfig
835and the engine used for quantized computations match the backend on which
836the model will be executed. The qconfig controls the type of observers used
during the quantization passes. The qengine controls whether the `x86` or `qnnpack`
specific packing function is used when packing weights for
839linear and convolution functions and modules. For example:
840
841Default settings for x86::
842
843    # set the qconfig for PTQ
844    # Note: the old 'fbgemm' is still available but 'x86' is the recommended default on x86 CPUs
845    qconfig = torch.ao.quantization.get_default_qconfig('x86')
846    # or, set the qconfig for QAT
847    qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')
848    # set the qengine to control weight packing
849    torch.backends.quantized.engine = 'x86'
850
851Default settings for qnnpack::
852
853    # set the qconfig for PTQ
854    qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
855    # or, set the qconfig for QAT
856    qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
857    # set the qengine to control weight packing
858    torch.backends.quantized.engine = 'qnnpack'
859
860Operator Support
861^^^^^^^^^^^^^^^^^^^^
862
863Operator coverage varies between dynamic and static quantization and is captured in the table below.
864Note that for FX Graph Mode Quantization, the corresponding functionals are also supported.
865
866+---------------------------+-------------------+--------------------+
867|                           |Static             | Dynamic            |
868|                           |Quantization       | Quantization       |
869+---------------------------+-------------------+--------------------+
870| | nn.Linear               | | Y               | | Y                |
871| | nn.Conv1d/2d/3d         | | Y               | | N                |
872+---------------------------+-------------------+--------------------+
873| | nn.LSTM                 | | N               | | Y                |
874| | nn.GRU                  | | N               | | Y                |
875+---------------------------+-------------------+--------------------+
876| | nn.RNNCell              | | N               | | Y                |
877| | nn.GRUCell              | | N               | | Y                |
878| | nn.LSTMCell             | | N               | | Y                |
879+---------------------------+-------------------+--------------------+
880|nn.EmbeddingBag            | Y (activations    |                    |
881|                           | are in fp32)      | Y                  |
882+---------------------------+-------------------+--------------------+
883|nn.Embedding               | Y                 | Y                  |
884+---------------------------+-------------------+--------------------+
885|nn.MultiheadAttention      |Not Supported      | Not supported      |
886+---------------------------+-------------------+--------------------+
887|Activations                |Broadly supported  | Un-changed,        |
888|                           |                   | computations       |
889|                           |                   | stay in fp32       |
890+---------------------------+-------------------+--------------------+
891
892Note: this will be updated with some information generated from native backend_config_dict soon.
893
894Quantization API Reference
895---------------------------
896
897The :doc:`Quantization API Reference <quantization-support>` contains documentation
898of quantization APIs, such as quantization passes, quantized tensor operations,
899and supported quantized modules and functions.
900
901.. toctree::
902    :hidden:
903
904    quantization-support
905
906Quantization Backend Configuration
907----------------------------------
908
909The :doc:`Quantization Backend Configuration <quantization-backend-configuration>` contains documentation
910on how to configure the quantization workflows for various backends.
911
912.. toctree::
913    :hidden:
914
915    quantization-backend-configuration
916
917Quantization Accuracy Debugging
918-------------------------------
919
920The :doc:`Quantization Accuracy Debugging <quantization-accuracy-debugging>` contains documentation
921on how to debug quantization accuracy.
922
923.. toctree::
924    :hidden:
925
926    quantization-accuracy-debugging
927
928Quantization Customizations
929---------------------------
930
931While default implementations of observers to select the scale factor and bias
932based on observed tensor data are provided, developers can provide their own
933quantization functions. Quantization can be applied selectively to different
934parts of the model or configured differently for different parts of the model.
935
936We also provide support for per channel quantization for **conv1d()**, **conv2d()**,
937**conv3d()** and **linear()**.
938
939Quantization workflows work by adding (e.g. adding observers as
940``.observer`` submodule) or replacing (e.g. converting ``nn.Conv2d`` to
941``nn.quantized.Conv2d``) submodules in the model's module hierarchy. It
942means that the model stays a regular ``nn.Module``-based instance throughout the
943process and thus can work with the rest of PyTorch APIs.
944
945Quantization Custom Module API
946^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
947
Both Eager mode and FX graph mode quantization APIs provide a hook for the user
to specify modules quantized in a custom way, with user defined logic for
observation and quantization. The user needs to specify:
951
9521. The Python type of the source fp32 module (existing in the model)
9532. The Python type of the observed module (provided by user). This module needs
954   to define a `from_float` function which defines how the observed module is
955   created from the original fp32 module.
9563. The Python type of the quantized module (provided by user). This module needs
957   to define a `from_observed` function which defines how the quantized module is
958   created from the observed module.
9594. A configuration describing (1), (2), (3) above, passed to the quantization APIs.
960
961
962The framework will then do the following:
963
9641. during the `prepare` module swaps, it will convert every module of type
965   specified in (1) to the type specified in (2), using the `from_float` function of
966   the class in (2).
9672. during the `convert` module swaps, it will convert every module of type
968   specified in (2) to the type specified in (3), using the `from_observed` function
969   of the class in (3).
970
971Currently, there is a requirement that `ObservedCustomModule` will have a single
972Tensor output, and an observer will be added by the framework (not by the user)
973on that output. The observer will be stored under the `activation_post_process` key
974as an attribute of the custom module instance. Relaxing these restrictions may
975be done at a future time.
976
977Custom API Example::
978
979  import torch
980  import torch.ao.nn.quantized as nnq
981  from torch.ao.quantization import QConfigMapping
982  import torch.ao.quantization.quantize_fx
983
984  # original fp32 module to replace
985  class CustomModule(torch.nn.Module):
986      def __init__(self):
987          super().__init__()
988          self.linear = torch.nn.Linear(3, 3)
989
990      def forward(self, x):
991          return self.linear(x)
992
993  # custom observed module, provided by user
994  class ObservedCustomModule(torch.nn.Module):
995      def __init__(self, linear):
996          super().__init__()
997          self.linear = linear
998
999      def forward(self, x):
1000          return self.linear(x)
1001
1002      @classmethod
1003      def from_float(cls, float_module):
1004          assert hasattr(float_module, 'qconfig')
1005          observed = cls(float_module.linear)
1006          observed.qconfig = float_module.qconfig
1007          return observed
1008
1009  # custom quantized module, provided by user
1010  class StaticQuantCustomModule(torch.nn.Module):
1011      def __init__(self, linear):
1012          super().__init__()
1013          self.linear = linear
1014
1015      def forward(self, x):
1016          return self.linear(x)
1017
1018      @classmethod
1019      def from_observed(cls, observed_module):
1020          assert hasattr(observed_module, 'qconfig')
1021          assert hasattr(observed_module, 'activation_post_process')
1022          observed_module.linear.activation_post_process = \
1023              observed_module.activation_post_process
1024          quantized = cls(nnq.Linear.from_float(observed_module.linear))
1025          return quantized
1026
1027  #
1028  # example API call (Eager mode quantization)
1029  #
1030
1031  m = torch.nn.Sequential(CustomModule()).eval()
1032  prepare_custom_config_dict = {
1033      "float_to_observed_custom_module_class": {
1034          CustomModule: ObservedCustomModule
1035      }
1036  }
1037  convert_custom_config_dict = {
1038      "observed_to_quantized_custom_module_class": {
1039          ObservedCustomModule: StaticQuantCustomModule
1040      }
1041  }
1042  m.qconfig = torch.ao.quantization.default_qconfig
1043  mp = torch.ao.quantization.prepare(
1044      m, prepare_custom_config_dict=prepare_custom_config_dict)
1045  # calibration (not shown)
1046  mq = torch.ao.quantization.convert(
1047      mp, convert_custom_config_dict=convert_custom_config_dict)
1048  #
1049  # example API call (FX graph mode quantization)
1050  #
1051  m = torch.nn.Sequential(CustomModule()).eval()
1052  qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_qconfig)
1053  prepare_custom_config_dict = {
1054      "float_to_observed_custom_module_class": {
1055          "static": {
1056              CustomModule: ObservedCustomModule,
1057          }
1058      }
1059  }
1060  convert_custom_config_dict = {
1061      "observed_to_quantized_custom_module_class": {
1062          "static": {
1063              ObservedCustomModule: StaticQuantCustomModule,
1064          }
1065      }
1066  }
  mp = torch.ao.quantization.quantize_fx.prepare_fx(
      m, qconfig_mapping, (torch.randn(3, 3),), prepare_custom_config=prepare_custom_config_dict)
1069  # calibration (not shown)
1070  mq = torch.ao.quantization.quantize_fx.convert_fx(
1071      mp, convert_custom_config=convert_custom_config_dict)
1072
1073Best Practices
1074--------------
1075
1. If you are using the ``x86`` backend, we need to use 7 bits instead of 8 bits. Make sure you reduce the range for ``quant_min`` and ``quant_max``:

   * if ``dtype`` is ``torch.quint8``, set a custom ``quant_min`` of ``0`` and ``quant_max`` of ``127`` (``255`` / ``2``);
   * if ``dtype`` is ``torch.qint8``, set a custom ``quant_min`` of ``-64`` (``-128`` / ``2``) and ``quant_max`` of ``63`` (``127`` / ``2``).

   These values are already set correctly if you call ``torch.ao.quantization.get_default_qconfig(backend)`` or ``torch.ao.quantization.get_default_qat_qconfig(backend)``
   to get the default ``qconfig`` for the ``x86`` or ``qnnpack`` backend. A minimal sketch of constructing such a ``qconfig`` by hand is shown after this list.
1081
2. If the ``onednn`` backend is selected, 8 bits will be used for activations in the default qconfig mapping ``torch.ao.quantization.get_default_qconfig_mapping('onednn')``
   and the default qconfig ``torch.ao.quantization.get_default_qconfig('onednn')``. This backend is recommended for CPUs with Vector Neural Network Instruction (VNNI)
   support. Otherwise, set ``reduce_range`` to ``True`` on the activation's observer to get better accuracy on CPUs without VNNI support; see the sketch after this list.
1085
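The following is a minimal sketch of both settings (the observer choices are illustrative assumptions rather than the only valid ones); the resulting ``QConfig`` objects would then be assigned to ``model.qconfig`` in Eager mode or set on a ``QConfigMapping`` in FX graph mode::

  import torch
  from torch.ao.quantization import (
      QConfig,
      MinMaxObserver,
      HistogramObserver,
      default_per_channel_weight_observer,
  )

  # (1) x86: restrict the activation range to 7 bits via quant_min/quant_max,
  # mirroring what the default x86 qconfig does
  x86_qconfig = QConfig(
      activation=MinMaxObserver.with_args(
          dtype=torch.quint8, quant_min=0, quant_max=127),
      weight=default_per_channel_weight_observer)

  # (2) onednn on a CPU without VNNI: enable reduce_range on the activation observer
  onednn_no_vnni_qconfig = QConfig(
      activation=HistogramObserver.with_args(reduce_range=True),
      weight=default_per_channel_weight_observer)
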
1086Frequently Asked Questions
1087--------------------------
1088
1. How can I do quantized inference on GPU?

   We don't have official GPU support yet, but this is an area of active development; you can find more information
   `here <https://github.com/pytorch/pytorch/issues/87395>`_.
1093
10942. Where can I get ONNX support for my quantized model?
1095
1096   If you get errors exporting the model (using APIs under ``torch.onnx``), you may open an issue in the PyTorch repository. Prefix the issue title with ``[ONNX]`` and tag the issue as ``module: onnx``.
1097
1098   If you encounter issues with ONNX Runtime, open an issue at `GitHub - microsoft/onnxruntime <https://github.com/microsoft/onnxruntime/issues/>`_.
1099
3. How can I use quantization with LSTMs?

   LSTM is supported through our custom module API in both Eager mode and FX graph mode quantization. Examples can be found at:

   * Eager Mode: `pytorch/test_quantized_op.py TestQuantizedOps.test_custom_module_lstm <https://github.com/pytorch/pytorch/blob/9b88dcf248e717ca6c3f8c5e11f600825547a561/test/quantization/core/test_quantized_op.py#L2782>`_
   * FX Graph Mode: `pytorch/test_quantize_fx.py TestQuantizeFx.test_static_lstm <https://github.com/pytorch/pytorch/blob/9b88dcf248e717ca6c3f8c5e11f600825547a561/test/quantization/fx/test_quantize_fx.py#L4116>`_
1105
1106Common Errors
1107---------------------------------------
1108
1109Passing a non-quantized Tensor into a quantized kernel
1110^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1111
1112If you see an error similar to::
1113
1114  RuntimeError: Could not run 'quantized::some_operator' with arguments from the 'CPU' backend...
1115
1116This means that you are trying to pass a non-quantized Tensor to a quantized
1117kernel. A common workaround is to use ``torch.ao.quantization.QuantStub`` to
1118quantize the tensor.  This needs to be done manually in Eager mode quantization.
1119An e2e example::
1120
1121  class M(torch.nn.Module):
1122      def __init__(self):
1123          super().__init__()
1124          self.quant = torch.ao.quantization.QuantStub()
1125          self.conv = torch.nn.Conv2d(1, 1, 1)
1126
1127      def forward(self, x):
1128          # during the convert step, this will be replaced with a
1129          # `quantize_per_tensor` call
1130          x = self.quant(x)
1131          x = self.conv(x)
1132          return x
1133
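A sketch of how the module above would typically be prepared and converted (the backend choice and calibration input here are illustrative assumptions)::

  m = M().eval()
  m.qconfig = torch.ao.quantization.get_default_qconfig('x86')
  mp = torch.ao.quantization.prepare(m)
  mp(torch.randn(4, 1, 8, 8))  # calibration with representative data
  mq = torch.ao.quantization.convert(mp)
  # `mq.quant` is now a Quantize module, so `mq.conv` receives a quantized Tensor
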
1134Passing a quantized Tensor into a non-quantized kernel
1135^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1136
1137If you see an error similar to::
1138
1139  RuntimeError: Could not run 'aten::thnn_conv2d_forward' with arguments from the 'QuantizedCPU' backend.
1140
1141This means that you are trying to pass a quantized Tensor to a non-quantized
1142kernel. A common workaround is to use ``torch.ao.quantization.DeQuantStub`` to
1143dequantize the tensor.  This needs to be done manually in Eager mode quantization.
1144An e2e example::
1145
1146  class M(torch.nn.Module):
1147      def __init__(self):
1148          super().__init__()
1149          self.quant = torch.ao.quantization.QuantStub()
1150          self.conv1 = torch.nn.Conv2d(1, 1, 1)
1151          # this module will not be quantized (see `qconfig = None` logic below)
1152          self.conv2 = torch.nn.Conv2d(1, 1, 1)
1153          self.dequant = torch.ao.quantization.DeQuantStub()
1154
1155      def forward(self, x):
1156          # during the convert step, this will be replaced with a
1157          # `quantize_per_tensor` call
1158          x = self.quant(x)
1159          x = self.conv1(x)
1160          # during the convert step, this will be replaced with a
1161          # `dequantize` call
1162          x = self.dequant(x)
1163          x = self.conv2(x)
1164          return x
1165
1166  m = M()
1167  m.qconfig = some_qconfig
1168  # turn off quantization for conv2
1169  m.conv2.qconfig = None
1170
1171Saving and Loading Quantized models
1172^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1173
1174When calling ``torch.load`` on a quantized model, if you see an error like::
1175
1176  AttributeError: 'LinearPackedParams' object has no attribute '_modules'
1177
1178This is because directly saving and loading a quantized model using ``torch.save`` and ``torch.load``
1179is not supported. To save/load quantized models, the following ways can be used:
1180
11811. Saving/Loading the quantized model state_dict
1182
1183An example::
1184
  import io
  import torch
  from torch import nn
  from torch.ao.quantization import default_qconfig, QConfigMapping
  from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

  class M(torch.nn.Module):
      def __init__(self):
          super().__init__()
          self.linear = nn.Linear(5, 5)
          self.relu = nn.ReLU()

      def forward(self, x):
          x = self.linear(x)
          x = self.relu(x)
          return x

  m = M().eval()
  qconfig_mapping = QConfigMapping().set_global(default_qconfig)
  example_inputs = (torch.rand(5, 5),)
  prepared_orig = prepare_fx(m, qconfig_mapping, example_inputs)
  prepared_orig(*example_inputs)  # calibration
  quantized_orig = convert_fx(prepared_orig)

  # Save/load using state_dict
  b = io.BytesIO()
  torch.save(quantized_orig.state_dict(), b)

  m2 = M().eval()
  prepared = prepare_fx(m2, qconfig_mapping, example_inputs)
  quantized = convert_fx(prepared)
  b.seek(0)
  quantized.load_state_dict(torch.load(b))
1210
12112. Saving/Loading scripted quantized models using ``torch.jit.save`` and ``torch.jit.load``
1212
1213An example::
1214
  # Note: using the same model M and imports from the previous example
  m = M().eval()
  qconfig_mapping = QConfigMapping().set_global(default_qconfig)
  example_inputs = (torch.rand(5, 5),)
  prepared_orig = prepare_fx(m, qconfig_mapping, example_inputs)
  prepared_orig(*example_inputs)  # calibration
  quantized_orig = convert_fx(prepared_orig)

  # save/load using scripted model
  scripted = torch.jit.script(quantized_orig)
  b = io.BytesIO()
  torch.jit.save(scripted, b)
  b.seek(0)
  scripted_quantized = torch.jit.load(b)
1227
1228Symbolic Trace Error when using FX Graph Mode Quantization
1229^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Symbolic traceability is a requirement for `(Prototype - maintenance mode) FX Graph Mode Quantization`_, so if you pass a PyTorch model that is not symbolically traceable to `torch.ao.quantization.prepare_fx` or `torch.ao.quantization.prepare_qat_fx`, you might see an error like the following::
1231
1232  torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
1233
Please take a look at `Limitations of Symbolic Tracing <https://pytorch.org/docs/2.0/fx.html#limitations-of-symbolic-tracing>`_ and the `User Guide on Using FX Graph Mode Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_quant_guide.html>`_ to work around the problem.
1235
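For example, data-dependent control flow in ``forward`` is one common source of this error (a minimal sketch, not taken from a real model)::

  import torch

  class NotTraceable(torch.nn.Module):
      def forward(self, x):
          # during symbolic tracing, `x` is a Proxy that cannot be converted
          # to bool, so this branch raises the TraceError shown above
          if x.sum() > 0:
              return x + 1
          return x - 1

  # torch.ao.quantization.quantize_fx.prepare_fx(NotTraceable().eval(), qconfig_mapping, example_inputs)
  # would fail with the error above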
1236
1237.. torch.ao is missing documentation. Since part of it is mentioned here, adding them here for now.
1238.. They are here for tracking purposes until they are more permanently fixed.
1239.. py:module:: torch.ao
1240.. py:module:: torch.ao.nn
1241.. py:module:: torch.ao.nn.quantizable
1242.. py:module:: torch.ao.nn.quantizable.modules
1243.. py:module:: torch.ao.nn.quantized
1244.. py:module:: torch.ao.nn.quantized.reference
1245.. py:module:: torch.ao.nn.quantized.reference.modules
1246.. py:module:: torch.ao.nn.sparse
1247.. py:module:: torch.ao.nn.sparse.quantized
1248.. py:module:: torch.ao.nn.sparse.quantized.dynamic
1249.. py:module:: torch.ao.ns
1250.. py:module:: torch.ao.ns.fx
1251.. py:module:: torch.ao.quantization.backend_config
1252.. py:module:: torch.ao.pruning
1253.. py:module:: torch.ao.pruning.scheduler
1254.. py:module:: torch.ao.pruning.sparsifier
1255.. py:module:: torch.ao.nn.intrinsic.modules.fused
1256.. py:module:: torch.ao.nn.intrinsic.qat.modules.conv_fused
1257.. py:module:: torch.ao.nn.intrinsic.qat.modules.linear_fused
1258.. py:module:: torch.ao.nn.intrinsic.qat.modules.linear_relu
1259.. py:module:: torch.ao.nn.intrinsic.quantized.dynamic.modules.linear_relu
1260.. py:module:: torch.ao.nn.intrinsic.quantized.modules.bn_relu
1261.. py:module:: torch.ao.nn.intrinsic.quantized.modules.conv_add
1262.. py:module:: torch.ao.nn.intrinsic.quantized.modules.conv_relu
1263.. py:module:: torch.ao.nn.intrinsic.quantized.modules.linear_relu
1264.. py:module:: torch.ao.nn.qat.dynamic.modules.linear
1265.. py:module:: torch.ao.nn.qat.modules.conv
1266.. py:module:: torch.ao.nn.qat.modules.embedding_ops
1267.. py:module:: torch.ao.nn.qat.modules.linear
1268.. py:module:: torch.ao.nn.quantizable.modules.activation
1269.. py:module:: torch.ao.nn.quantizable.modules.rnn
1270.. py:module:: torch.ao.nn.quantized.dynamic.modules.conv
1271.. py:module:: torch.ao.nn.quantized.dynamic.modules.linear
1272.. py:module:: torch.ao.nn.quantized.dynamic.modules.rnn
1273.. py:module:: torch.ao.nn.quantized.modules.activation
1274.. py:module:: torch.ao.nn.quantized.modules.batchnorm
1275.. py:module:: torch.ao.nn.quantized.modules.conv
1276.. py:module:: torch.ao.nn.quantized.modules.dropout
1277.. py:module:: torch.ao.nn.quantized.modules.embedding_ops
1278.. py:module:: torch.ao.nn.quantized.modules.functional_modules
1279.. py:module:: torch.ao.nn.quantized.modules.linear
1280.. py:module:: torch.ao.nn.quantized.modules.normalization
1281.. py:module:: torch.ao.nn.quantized.modules.rnn
1282.. py:module:: torch.ao.nn.quantized.modules.utils
1283.. py:module:: torch.ao.nn.quantized.reference.modules.conv
1284.. py:module:: torch.ao.nn.quantized.reference.modules.linear
1285.. py:module:: torch.ao.nn.quantized.reference.modules.rnn
1286.. py:module:: torch.ao.nn.quantized.reference.modules.sparse
1287.. py:module:: torch.ao.nn.quantized.reference.modules.utils
1288.. py:module:: torch.ao.nn.sparse.quantized.dynamic.linear
1289.. py:module:: torch.ao.nn.sparse.quantized.linear
1290.. py:module:: torch.ao.nn.sparse.quantized.utils
1291.. py:module:: torch.ao.ns.fx.graph_matcher
1292.. py:module:: torch.ao.ns.fx.graph_passes
1293.. py:module:: torch.ao.ns.fx.mappings
1294.. py:module:: torch.ao.ns.fx.n_shadows_utils
1295.. py:module:: torch.ao.ns.fx.ns_types
1296.. py:module:: torch.ao.ns.fx.pattern_utils
1297.. py:module:: torch.ao.ns.fx.qconfig_multi_mapping
1298.. py:module:: torch.ao.ns.fx.utils
1299.. py:module:: torch.ao.ns.fx.weight_utils
1300.. py:module:: torch.ao.pruning.scheduler.base_scheduler
1301.. py:module:: torch.ao.pruning.scheduler.cubic_scheduler
1302.. py:module:: torch.ao.pruning.scheduler.lambda_scheduler
1303.. py:module:: torch.ao.pruning.sparsifier.base_sparsifier
1304.. py:module:: torch.ao.pruning.sparsifier.nearly_diagonal_sparsifier
1305.. py:module:: torch.ao.pruning.sparsifier.utils
1306.. py:module:: torch.ao.pruning.sparsifier.weight_norm_sparsifier
1307.. py:module:: torch.ao.quantization.backend_config.backend_config
1308.. py:module:: torch.ao.quantization.backend_config.executorch
1309.. py:module:: torch.ao.quantization.backend_config.fbgemm
1310.. py:module:: torch.ao.quantization.backend_config.native
1311.. py:module:: torch.ao.quantization.backend_config.observation_type
1312.. py:module:: torch.ao.quantization.backend_config.onednn
1313.. py:module:: torch.ao.quantization.backend_config.qnnpack
1314.. py:module:: torch.ao.quantization.backend_config.tensorrt
1315.. py:module:: torch.ao.quantization.backend_config.utils
1316.. py:module:: torch.ao.quantization.backend_config.x86
1317.. py:module:: torch.ao.quantization.fake_quantize
1318.. py:module:: torch.ao.quantization.fuser_method_mappings
1319.. py:module:: torch.ao.quantization.fuse_modules
1320.. py:module:: torch.ao.quantization.fx.convert
1321.. py:module:: torch.ao.quantization.fx.custom_config
1322.. py:module:: torch.ao.quantization.fx.fuse
1323.. py:module:: torch.ao.quantization.fx.fuse_handler
1324.. py:module:: torch.ao.quantization.fx.graph_module
1325.. py:module:: torch.ao.quantization.fx.lower_to_fbgemm
1326.. py:module:: torch.ao.quantization.fx.lower_to_qnnpack
1327.. py:module:: torch.ao.quantization.fx.lstm_utils
1328.. py:module:: torch.ao.quantization.fx.match_utils
1329.. py:module:: torch.ao.quantization.fx.pattern_utils
1330.. py:module:: torch.ao.quantization.fx.prepare
1331.. py:module:: torch.ao.quantization.fx.qconfig_mapping_utils
1332.. py:module:: torch.ao.quantization.fx.quantize_handler
1333.. py:module:: torch.ao.quantization.fx.tracer
1334.. py:module:: torch.ao.quantization.fx.utils
1335.. py:module:: torch.ao.quantization.observer
1336.. py:module:: torch.ao.quantization.pt2e.duplicate_dq_pass
1337.. py:module:: torch.ao.quantization.pt2e.export_utils
1338.. py:module:: torch.ao.quantization.pt2e.graph_utils
1339.. py:module:: torch.ao.quantization.pt2e.port_metadata_pass
1340.. py:module:: torch.ao.quantization.pt2e.prepare
1341.. py:module:: torch.ao.quantization.pt2e.qat_utils
1342.. py:module:: torch.ao.quantization.pt2e.representation.rewrite
1343.. py:module:: torch.ao.quantization.pt2e.utils
1344.. py:module:: torch.ao.quantization.qconfig
1345.. py:module:: torch.ao.quantization.qconfig_mapping
1346.. py:module:: torch.ao.quantization.quant_type
1347.. py:module:: torch.ao.quantization.quantization_mappings
1348.. py:module:: torch.ao.quantization.quantize_fx
1349.. py:module:: torch.ao.quantization.quantize_jit
1350.. py:module:: torch.ao.quantization.quantize_pt2e
1351.. py:module:: torch.ao.quantization.quantizer.composable_quantizer
1352.. py:module:: torch.ao.quantization.quantizer.embedding_quantizer
1353.. py:module:: torch.ao.quantization.quantizer.quantizer
1354.. py:module:: torch.ao.quantization.quantizer.utils
1355.. py:module:: torch.ao.quantization.quantizer.x86_inductor_quantizer
1356.. py:module:: torch.ao.quantization.quantizer.xnnpack_quantizer
1357.. py:module:: torch.ao.quantization.quantizer.xnnpack_quantizer_utils
1358.. py:module:: torch.ao.quantization.stubs
1359.. py:module:: torch.ao.quantization.utils
1360.. py:module:: torch.nn.intrinsic.modules.fused
1361.. py:module:: torch.nn.intrinsic.qat.modules.conv_fused
1362.. py:module:: torch.nn.intrinsic.qat.modules.linear_fused
1363.. py:module:: torch.nn.intrinsic.qat.modules.linear_relu
1364.. py:module:: torch.nn.intrinsic.quantized.dynamic.modules.linear_relu
1365.. py:module:: torch.nn.intrinsic.quantized.modules.bn_relu
1366.. py:module:: torch.nn.intrinsic.quantized.modules.conv_relu
1367.. py:module:: torch.nn.intrinsic.quantized.modules.linear_relu
1368.. py:module:: torch.nn.qat.dynamic.modules.linear
1369.. py:module:: torch.nn.qat.modules.conv
1370.. py:module:: torch.nn.qat.modules.embedding_ops
1371.. py:module:: torch.nn.qat.modules.linear
1372.. py:module:: torch.nn.quantizable.modules.activation
1373.. py:module:: torch.nn.quantizable.modules.rnn
1374.. py:module:: torch.nn.quantized.dynamic.modules.conv
1375.. py:module:: torch.nn.quantized.dynamic.modules.linear
1376.. py:module:: torch.nn.quantized.dynamic.modules.rnn
1377.. py:module:: torch.nn.quantized.functional
1378.. py:module:: torch.nn.quantized.modules.activation
1379.. py:module:: torch.nn.quantized.modules.batchnorm
1380.. py:module:: torch.nn.quantized.modules.conv
1381.. py:module:: torch.nn.quantized.modules.dropout
1382.. py:module:: torch.nn.quantized.modules.embedding_ops
1383.. py:module:: torch.nn.quantized.modules.functional_modules
1384.. py:module:: torch.nn.quantized.modules.linear
1385.. py:module:: torch.nn.quantized.modules.normalization
1386.. py:module:: torch.nn.quantized.modules.rnn
1387.. py:module:: torch.nn.quantized.modules.utils
1388.. py:module:: torch.quantization.fake_quantize
1389.. py:module:: torch.quantization.fuse_modules
1390.. py:module:: torch.quantization.fuser_method_mappings
1391.. py:module:: torch.quantization.fx.convert
1392.. py:module:: torch.quantization.fx.fuse
1393.. py:module:: torch.quantization.fx.fusion_patterns
1394.. py:module:: torch.quantization.fx.graph_module
1395.. py:module:: torch.quantization.fx.match_utils
1396.. py:module:: torch.quantization.fx.pattern_utils
1397.. py:module:: torch.quantization.fx.prepare
1398.. py:module:: torch.quantization.fx.quantization_patterns
1399.. py:module:: torch.quantization.fx.quantization_types
1400.. py:module:: torch.quantization.fx.utils
1401.. py:module:: torch.quantization.observer
1402.. py:module:: torch.quantization.qconfig
1403.. py:module:: torch.quantization.quant_type
1404.. py:module:: torch.quantization.quantization_mappings
1405.. py:module:: torch.quantization.quantize
1406.. py:module:: torch.quantization.quantize_fx
1407.. py:module:: torch.quantization.quantize_jit
1408.. py:module:: torch.quantization.stubs
1409.. py:module:: torch.quantization.utils
1410