.. _quantization-doc:

Quantization
============

.. automodule:: torch.ao.quantization
.. automodule:: torch.ao.quantization.fx

.. warning ::
    Quantization is in beta and subject to change.

Introduction to Quantization
----------------------------

Quantization refers to techniques for performing computations and storing
tensors at lower bitwidths than floating point precision. A quantized model
executes some or all of the operations on tensors with reduced precision rather than
full precision (floating point) values. This allows for a more compact model representation and
the use of high performance vectorized operations on many hardware platforms.
PyTorch supports INT8 quantization, which, compared to typical FP32 models, allows for
a 4x reduction in model size and a 4x reduction in memory bandwidth
requirements. Hardware support for INT8 computations is typically 2 to 4
times faster compared to FP32 compute. Quantization is primarily a technique to
speed up inference, and only the forward pass is supported for quantized
operators.

PyTorch supports multiple approaches to quantizing a deep learning model. In
most cases the model is trained in FP32 and then the model is converted to
INT8. In addition, PyTorch also supports quantization aware training, which
models quantization errors in both the forward and backward passes using
fake-quantization modules. Note that the entire computation is carried out in
floating point. At the end of quantization aware training, PyTorch provides
conversion functions to convert the trained model into lower precision.

At a lower level, PyTorch provides a way to represent quantized tensors and
perform operations with them. They can be used to directly construct models
that perform all or part of the computation in lower precision. Higher-level
APIs are provided that incorporate typical workflows of converting an FP32 model
to lower precision with minimal accuracy loss.

Quantization API Summary
-----------------------------

PyTorch provides three different modes of quantization: Eager Mode Quantization, FX Graph Mode Quantization (maintenance) and PyTorch 2 Export Quantization.

Eager Mode Quantization is a beta feature. Users need to perform fusion and specify where quantization and dequantization happen manually; it also only supports modules, not functionals.

FX Graph Mode Quantization is an automated quantization workflow in PyTorch. It is currently a prototype feature and is in maintenance mode now that we have PyTorch 2 Export Quantization. It improves upon Eager Mode Quantization by adding support for functionals and automating the quantization process, although users might need to refactor the model to make it compatible with FX Graph Mode Quantization (symbolically traceable with ``torch.fx``). Note that FX Graph Mode Quantization is not expected to work on arbitrary models, since a model might not be symbolically traceable. We will integrate it into domain libraries like torchvision, and users will be able to quantize models similar to the ones in supported domain libraries with FX Graph Mode Quantization. For arbitrary models we provide general guidelines, but to actually make it work users might need to be familiar with ``torch.fx``, especially with how to make a model symbolically traceable.

PyTorch 2 Export Quantization is the new full graph mode quantization workflow, released as a prototype feature in PyTorch 2.1.
With PyTorch 2, we are moving to a better solution for full program capture (torch.export), since it can capture a higher percentage (88.8% on 14K models) of models compared to torch.fx.symbolic_trace (72.7% on 14K models), the program capture solution used by FX Graph Mode Quantization. torch.export still has limitations around some python constructs and requires user involvement to support dynamism in the exported model, but overall it is an improvement over the previous program capture solution. PyTorch 2 Export Quantization is built for models captured by torch.export, with the flexibility and productivity of both modeling users and backend developers in mind. The main features are:

(1) a programmable API for configuring how a model is quantized, which can scale to many more use cases;
(2) a simplified UX for modeling users and backend developers, since they only need to interact with a single object (Quantizer) to express the user's intention about how to quantize a model and what the backend supports;
(3) an optional reference quantized model representation that can represent quantized computation with integer operations, mapping more closely to the actual quantized computation that happens in hardware.

New users of quantization are encouraged to try out PyTorch 2 Export Quantization first; if it does not work well, they can try Eager Mode Quantization.

The following table compares the differences between Eager Mode Quantization, FX Graph Mode Quantization and PyTorch 2 Export Quantization:

+-----------------+-------------------+-------------------+-------------------------+
|                 |Eager Mode         |FX Graph           |PyTorch 2 Export         |
|                 |Quantization       |Mode               |Quantization             |
|                 |                   |Quantization       |                         |
+-----------------+-------------------+-------------------+-------------------------+
|Release          |beta               |prototype          |prototype                |
|Status           |                   |(maintenance)      |                         |
+-----------------+-------------------+-------------------+-------------------------+
|Operator         |Manual             |Automatic          |Automatic                |
|Fusion           |                   |                   |                         |
+-----------------+-------------------+-------------------+-------------------------+
|Quant/DeQuant    |Manual             |Automatic          |Automatic                |
|Placement        |                   |                   |                         |
+-----------------+-------------------+-------------------+-------------------------+
|Quantizing       |Supported          |Supported          |Supported                |
|Modules          |                   |                   |                         |
+-----------------+-------------------+-------------------+-------------------------+
|Quantizing       |Manual             |Automatic          |Supported                |
|Functionals/Torch|                   |                   |                         |
|Ops              |                   |                   |                         |
+-----------------+-------------------+-------------------+-------------------------+
|Support for      |Limited Support    |Fully              |Fully Supported          |
|Customization    |                   |Supported          |                         |
+-----------------+-------------------+-------------------+-------------------------+
|Quantization Mode|Post Training      |Post Training      |Defined by               |
|Support          |Quantization:      |Quantization:      |Backend Specific         |
|                 |Static, Dynamic,   |Static, Dynamic,   |Quantizer                |
|                 |Weight Only        |Weight Only        |                         |
|                 |                   |                   |                         |
|                 |Quantization Aware |Quantization Aware |                         |
|                 |Training:          |Training:          |                         |
|                 |Static             |Static             |                         |
+-----------------+-------------------+-------------------+-------------------------+
|Input/Output     |``torch.nn.Module``|``torch.nn.Module``|``torch.fx.GraphModule`` |
|Model Type       |                   |(May need some     |(captured by             |
|                 |                   |refactors to make  |``torch.export``)        |
|                 |                   |the model          |                         |
|                 |                   |compatible with FX |                         |
|                 |                   |Graph Mode         |                         |
|                 |                   |Quantization)      |                         |
+-----------------+-------------------+-------------------+-------------------------+

There are three types of quantization supported:

1. dynamic quantization (weights quantized, with activations read/stored in
   floating point and quantized for compute)
2. static quantization (weights quantized, activations quantized, calibration
   required post training)
3. static quantization aware training (weights quantized, activations quantized,
   quantization numerics modeled during training)

Please see our `Introduction to Quantization on PyTorch
<https://pytorch.org/blog/introduction-to-quantization-on-pytorch/>`_ blog post
for a more comprehensive overview of the tradeoffs between these quantization
types.

Operator coverage varies between dynamic and static quantization and is captured in the table below.

+---------------------------+-------------------+--------------------+
|                           |Static             | Dynamic            |
|                           |Quantization       | Quantization       |
+---------------------------+-------------------+--------------------+
| | nn.Linear               | | Y               | | Y                |
| | nn.Conv1d/2d/3d         | | Y               | | N                |
+---------------------------+-------------------+--------------------+
| | nn.LSTM                 | | Y (through      | | Y                |
| |                         | |  custom modules)| |                  |
| | nn.GRU                  | | N               | | Y                |
+---------------------------+-------------------+--------------------+
| | nn.RNNCell              | | N               | | Y                |
| | nn.GRUCell              | | N               | | Y                |
| | nn.LSTMCell             | | N               | | Y                |
+---------------------------+-------------------+--------------------+
|nn.EmbeddingBag            | Y (activations    |                    |
|                           | are in fp32)      | Y                  |
+---------------------------+-------------------+--------------------+
|nn.Embedding               | Y                 | Y                  |
+---------------------------+-------------------+--------------------+
| nn.MultiheadAttention     | Y (through        | Not supported      |
|                           | custom modules)   |                    |
+---------------------------+-------------------+--------------------+
| Activations               | Broadly supported | Unchanged,         |
|                           |                   | computations       |
|                           |                   | stay in fp32       |
+---------------------------+-------------------+--------------------+


Eager Mode Quantization
^^^^^^^^^^^^^^^^^^^^^^^

For a general introduction to the quantization flow, including different types of quantization, please take a look at `General Quantization Flow`_.

Post Training Dynamic Quantization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This is the simplest form of quantization to apply: the weights are
quantized ahead of time but the activations are dynamically quantized
during inference. It is used in situations where the model execution time
is dominated by loading weights from memory rather than computing the matrix
multiplications. This is true for LSTM and Transformer type models with
small batch size.


Diagram::

    # original model
    # all tensors and computations are in floating point
    previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                          /
       linear_weight_fp32

    # dynamically quantized model
    # linear and LSTM weights are in int8
    previous_layer_fp32 -- linear_int8_w_fp32_inp -- activation_fp32 -- next_layer_fp32
                         /
       linear_weight_int8

PTDQ API Example::

    import torch

    # define a floating point model
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(4, 4)

        def forward(self, x):
            x = self.fc(x)
            return x

    # create a model instance
    model_fp32 = M()
    # create a quantized model instance
    model_int8 = torch.ao.quantization.quantize_dynamic(
        model_fp32,  # the original model
        {torch.nn.Linear},  # a set of layers to dynamically quantize
        dtype=torch.qint8)  # the target dtype for quantized weights

    # run the model
    input_fp32 = torch.randn(4, 4, 4, 4)
    res = model_int8(input_fp32)

To learn more about dynamic quantization please see our `dynamic quantization tutorial
<https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html>`_.

Post Training Static Quantization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Post Training Static Quantization (PTQ static) quantizes the weights and activations of the model. It
fuses activations into preceding layers where possible. It requires
calibration with a representative dataset to determine the optimal quantization
parameters for activations. Post Training Static Quantization is typically used when
both memory bandwidth and compute savings are important, with CNNs being a
typical use case.

We may need to modify the model before applying post training static quantization. Please see `Model Preparation for Eager Mode Static Quantization`_.

Diagram::

    # original model
    # all tensors and computations are in floating point
    previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                          /
       linear_weight_fp32

    # statically quantized model
    # weights and activations are in int8
    previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8
                         /
       linear_weight_int8

PTSQ API Example::

    import torch

    # define a floating point model where some layers could be statically quantized
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # QuantStub converts tensors from floating point to quantized
            self.quant = torch.ao.quantization.QuantStub()
            self.conv = torch.nn.Conv2d(1, 1, 1)
            self.relu = torch.nn.ReLU()
            # DeQuantStub converts tensors from quantized to floating point
            self.dequant = torch.ao.quantization.DeQuantStub()

        def forward(self, x):
            # manually specify where tensors will be converted from floating
            # point to quantized in the quantized model
            x = self.quant(x)
            x = self.conv(x)
            x = self.relu(x)
            # manually specify where tensors will be converted from quantized
            # to floating point in the quantized model
            x = self.dequant(x)
            return x

    # create a model instance
    model_fp32 = M()

    # model must be set to eval mode for static quantization logic to work
    model_fp32.eval()


    # attach a global qconfig, which contains information about what kind
    # of observers to attach. Use 'x86' for server inference and 'qnnpack'
    # for mobile inference. Other quantization configurations such as selecting
    # symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
    # can be specified here.
    # Note: the old 'fbgemm' is still available but 'x86' is the recommended default
    # for server inference.
    # model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
    model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')

    # Fuse the activations to preceding layers, where applicable.
    # This needs to be done manually depending on the model architecture.
    # Common fusions include `conv + relu` and `conv + batchnorm + relu`
    model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])

    # Prepare the model for static quantization. This inserts observers in
    # the model that will observe activation tensors during calibration.
    model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)

    # calibrate the prepared model to determine quantization parameters for activations
    # in a real world setting, the calibration would be done with a representative dataset
    input_fp32 = torch.randn(4, 1, 4, 4)
    model_fp32_prepared(input_fp32)

    # Convert the observed model to a quantized model. This does several things:
    # quantizes the weights, computes and stores the scale and bias value to be
    # used with each activation tensor, and replaces key operators with quantized
    # implementations.
    model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

    # run the model, relevant calculations will happen in int8
    res = model_int8(input_fp32)

To learn more about static quantization, please see the `static quantization tutorial
<https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html>`_.

Quantization Aware Training for Static Quantization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Quantization Aware Training (QAT) models the effects of quantization during training,
allowing for higher accuracy compared to other quantization methods. We can do QAT for static, dynamic or weight only quantization. During
training, all calculations are done in floating point, with fake_quant modules
modeling the effects of quantization by clamping and rounding to simulate the
effects of INT8. After model conversion, weights and
activations are quantized, and activations are fused into the preceding layer
where possible. It is commonly used with CNNs and yields a higher accuracy
compared to static quantization.

We may need to modify the model before applying quantization aware training. Please see `Model Preparation for Eager Mode Static Quantization`_.


Diagram::

    # original model
    # all tensors and computations are in floating point
    previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                          /
       linear_weight_fp32

    # model with fake_quants for modeling quantization numerics during training
    previous_layer_fp32 -- fq -- linear_fp32 -- activation_fp32 -- fq -- next_layer_fp32
                               /
       linear_weight_fp32 -- fq

    # quantized model
    # weights and activations are in int8
    previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8
                         /
       linear_weight_int8

QAT API Example::

    import torch

    # define a floating point model where some layers could benefit from QAT
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # QuantStub converts tensors from floating point to quantized
            self.quant = torch.ao.quantization.QuantStub()
            self.conv = torch.nn.Conv2d(1, 1, 1)
            self.bn = torch.nn.BatchNorm2d(1)
            self.relu = torch.nn.ReLU()
            # DeQuantStub converts tensors from quantized to floating point
            self.dequant = torch.ao.quantization.DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.conv(x)
            x = self.bn(x)
            x = self.relu(x)
            x = self.dequant(x)
            return x

    # create a model instance
    model_fp32 = M()

    # model must be set to eval mode for fusion to work
    model_fp32.eval()

    # attach a global qconfig, which contains information about what kind
    # of observers to attach. Use 'x86' for server inference and 'qnnpack'
    # for mobile inference. Other quantization configurations such as selecting
    # symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
    # can be specified here.
    # Note: the old 'fbgemm' is still available but 'x86' is the recommended default
    # for server inference.
    # model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
    model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')

    # fuse the activations to preceding layers, where applicable
    # this needs to be done manually depending on the model architecture
    model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32,
        [['conv', 'bn', 'relu']])

    # Prepare the model for QAT. This inserts observers and fake_quants in
    # the model that will observe weight and activation tensors during calibration.
    # The model needs to be set to train mode for the QAT logic to work.
    model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused.train())

    # run the training loop (not shown)
    training_loop(model_fp32_prepared)

    # Convert the observed model to a quantized model. This does several things:
    # quantizes the weights, computes and stores the scale and bias value to be
    # used with each activation tensor, fuses modules where appropriate,
    # and replaces key operators with quantized implementations.
    model_fp32_prepared.eval()
    model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

    # run the model, relevant calculations will happen in int8
    res = model_int8(input_fp32)

To learn more about quantization aware training, please see the `QAT
tutorial
<https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html>`_.


Model Preparation for Eager Mode Static Quantization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

It is currently necessary to make some modifications to the model definition
prior to Eager mode quantization. This is because quantization currently works on a module
by module basis. Specifically, for all quantization techniques, the user needs to:

1. Convert any operations that require output requantization (and thus have
   additional parameters) from functionals to module form (for example,
   using ``torch.nn.ReLU`` instead of ``torch.nn.functional.relu``).
2. Specify which parts of the model need to be quantized, either by assigning
   ``.qconfig`` attributes on submodules or by specifying ``qconfig_mapping``.
   For example, setting ``model.conv1.qconfig = None`` means that the
   ``model.conv1`` layer will not be quantized, and setting
   ``model.linear1.qconfig = custom_qconfig`` means that the quantization
   settings for ``model.linear1`` will use ``custom_qconfig`` instead
   of the global qconfig.

For static quantization techniques which quantize activations, the user needs
to do the following in addition:

1. Specify where activations are quantized and de-quantized. This is done using
   :class:`~torch.ao.quantization.QuantStub` and
   :class:`~torch.ao.quantization.DeQuantStub` modules.
2. Use :class:`~torch.ao.nn.quantized.FloatFunctional` to wrap tensor operations
   that require special handling for quantization into modules. Examples
   are operations like ``add`` and ``cat`` which require special handling to
   determine output quantization parameters (see the sketch after this list).
3. Fuse modules: combine operations/modules into a single module to obtain
   higher accuracy and performance. This is done using the
   :func:`~torch.ao.quantization.fuse_modules.fuse_modules` API, which takes in lists of modules
   to be fused. We currently support the following fusions:
   [Conv, Relu], [Conv, BatchNorm], [Conv, BatchNorm, Relu], [Linear, Relu]

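
For example, a skip connection that adds two activations can be wrapped with ``FloatFunctional`` so that the addition gets its own observer and output quantization parameters during ``prepare``. This is a minimal sketch (the module and tensor shapes are only illustrative)::

    import torch

    class AddBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(2, 2, 1)
            # FloatFunctional gives the add op a module identity so that an
            # observer can be attached to it by prepare()
            self.skip_add = torch.ao.nn.quantized.FloatFunctional()

        def forward(self, x):
            # use self.skip_add.add(a, b) instead of a + b
            return self.skip_add.add(self.conv(x), x)
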

(Prototype - maintenance mode) FX Graph Mode Quantization
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

There are multiple quantization types in post training quantization (weight only, dynamic and static), and the configuration is done through `qconfig_mapping` (an argument of the `prepare_fx` function).

FXPTQ API Example::

    import torch
    from torch.ao.quantization import (
        get_default_qconfig_mapping,
        get_default_qat_qconfig_mapping,
        QConfigMapping,
    )
    import torch.ao.quantization.quantize_fx as quantize_fx
    import copy

    model_fp = UserModel()

    #
    # post training dynamic/weight_only quantization
    #

    # we need to deepcopy if we still want to keep model_fp unchanged after quantization since quantization apis change the input model
    model_to_quantize = copy.deepcopy(model_fp)
    model_to_quantize.eval()
    qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_dynamic_qconfig)
    # a tuple of one or more example inputs is needed to trace the model
    example_inputs = (input_fp32,)
    # prepare
    model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
    # no calibration needed when we only have dynamic/weight_only quantization
    # quantize
    model_quantized = quantize_fx.convert_fx(model_prepared)

    #
    # post training static quantization
    #

    model_to_quantize = copy.deepcopy(model_fp)
    qconfig_mapping = get_default_qconfig_mapping("qnnpack")
    model_to_quantize.eval()
    # prepare
    model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
    # calibrate (not shown)
    # quantize
    model_quantized = quantize_fx.convert_fx(model_prepared)

    #
    # quantization aware training for static quantization
    #

    model_to_quantize = copy.deepcopy(model_fp)
    qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
    model_to_quantize.train()
    # prepare
    model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)
    # training loop (not shown)
    # quantize
    model_quantized = quantize_fx.convert_fx(model_prepared)

    #
    # fusion
    #
    model_to_quantize = copy.deepcopy(model_fp)
    model_fused = quantize_fx.fuse_fx(model_to_quantize)

Please follow the tutorials below to learn more about FX Graph Mode Quantization:

- `User Guide on Using FX Graph Mode Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_quant_guide.html>`_
- `FX Graph Mode Post Training Static Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html>`_
- `FX Graph Mode Post Training Dynamic Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_dynamic.html>`_


(Prototype) PyTorch 2 Export Quantization
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

API Example::

    import torch
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    # initialize a floating point model
    float_model = M().eval()

    # define calibration function
    def calibrate(model, data_loader):
        model.eval()
        with torch.no_grad():
            for image, target in data_loader:
                model(image)

    # Step 1. program capture
    # NOTE: this API will be updated to the torch.export API in the future, but the captured
    # result should mostly stay the same
    example_inputs = (torch.randn(1, 5),)
    m = capture_pre_autograd_graph(float_model, *example_inputs)
    # we get a model with aten ops

    # Step 2. quantization
    # backend developers will write their own Quantizer and expose methods to allow
    # users to express how they want the model to be quantized
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    # or prepare_qat_pt2e for Quantization Aware Training
    m = prepare_pt2e(m, quantizer)

    # run calibration
    # calibrate(m, sample_inference_data)
    m = convert_pt2e(m)

    # Step 3. lowering
    # lower to target backend

Please follow these tutorials to get started on PyTorch 2 Export Quantization:

Modeling Users:

- `PyTorch 2 Export Post Training Quantization <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html>`_
- `PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_x86_inductor.html>`_
- `PyTorch 2 Export Quantization Aware Training <https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html>`_

Backend Developers (please check out all Modeling Users docs as well):

- `How to Write a Quantizer for PyTorch 2 Export Quantization <https://pytorch.org/tutorials/prototype/pt2e_quantizer.html>`_


Quantization Stack
------------------------

Quantization is the process of converting a floating point model to a quantized model. So at a high level the quantization stack can be split into two parts:

1. the building blocks or abstractions for a quantized model
2. the building blocks or abstractions for the quantization flow that converts a floating point model to a quantized model

Quantized Model
^^^^^^^^^^^^^^^^^^^^^^^

Quantized Tensor
~~~~~~~~~~~~~~~~~

In order to do quantization in PyTorch, we need to be able to represent
quantized data in Tensors. A Quantized Tensor allows for storing
quantized data (represented as int8/uint8/int32) along with quantization
parameters like scale and zero_point. Quantized Tensors allow for many
useful operations making quantized arithmetic easy, in addition to
allowing for serialization of data in a quantized format.

PyTorch supports both per tensor and per channel symmetric and asymmetric quantization. Per tensor means that all the values within the tensor are quantized the same way with the same quantization parameters. Per channel means that, for a given dimension (typically the channel dimension of a tensor), the values in each slice of the tensor along that dimension are quantized with their own quantization parameters. This allows for less error when converting tensors to quantized values, since outlier values only impact the channel they are in, instead of the entire Tensor.

The mapping is performed by converting the floating point tensors using

.. image:: math-quantizer-equation.png
   :width: 40%

Note that we ensure that zero in floating point is represented with no error
after quantization, thereby ensuring that operations like padding do not cause
additional quantization error.

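
For example, quantizing a float tensor produces a quantized tensor that carries its scale and zero_point, and a floating point value of exactly zero round-trips without error. A minimal sketch (the values here are only illustrative)::

    import torch

    x = torch.tensor([-1.0, 0.0, 0.5, 1.0])
    # per tensor affine quantization with a chosen scale and zero_point
    xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=10, dtype=torch.quint8)

    print(xq.qscheme())       # torch.per_tensor_affine
    print(xq.q_scale())       # 0.1
    print(xq.q_zero_point())  # 10
    print(xq.int_repr())      # the underlying uint8 values

    # dequantizing recovers floating point values; 0.0 is represented exactly
    print(xq.dequantize())    # tensor([-1.0000, 0.0000, 0.5000, 1.0000])
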

Here are a few key attributes of a quantized Tensor:

* QScheme (torch.qscheme): an enum that specifies the way we quantize the Tensor

  * torch.per_tensor_affine
  * torch.per_tensor_symmetric
  * torch.per_channel_affine
  * torch.per_channel_symmetric

* dtype (torch.dtype): data type of the quantized Tensor

  * torch.quint8
  * torch.qint8
  * torch.qint32
  * torch.float16

* quantization parameters (varies based on QScheme): parameters for the chosen way of quantization

  * torch.per_tensor_affine would have quantization parameters of

    * scale (float)
    * zero_point (int)

  * torch.per_channel_affine would have quantization parameters of

    * per_channel_scales (list of float)
    * per_channel_zero_points (list of int)
    * axis (int)

Quantize and Dequantize
~~~~~~~~~~~~~~~~~~~~~~~

The input and output of a model are floating point Tensors, but activations in the quantized model are quantized, so we need operators to convert between floating point and quantized Tensors.

* Quantize (float -> quantized)

  * torch.quantize_per_tensor(x, scale, zero_point, dtype)
  * torch.quantize_per_channel(x, scales, zero_points, axis, dtype)
  * torch.quantize_per_tensor_dynamic(x, dtype, reduce_range)
  * to(torch.float16)

* Dequantize (quantized -> float)

  * quantized_tensor.dequantize() - calling dequantize on a torch.float16 Tensor will convert the Tensor back to torch.float
  * torch.dequantize(x)

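
As a small illustration of the per channel variant, the sketch below quantizes a weight-like tensor along its output-channel axis and reads back the per-channel parameters (the shapes and values are only illustrative)::

    import torch

    w = torch.randn(3, 4)  # e.g. a linear weight with 3 output channels
    scales = torch.tensor([0.1, 0.05, 0.2])
    zero_points = torch.zeros(3, dtype=torch.int64)

    # per channel quantization along axis 0 (one scale/zero_point per output channel)
    wq = torch.quantize_per_channel(w, scales, zero_points, axis=0, dtype=torch.qint8)

    print(wq.qscheme())                    # torch.per_channel_affine
    print(wq.q_per_channel_scales())       # tensor([0.1000, 0.0500, 0.2000], ...)
    print(wq.q_per_channel_zero_points())  # tensor([0, 0, 0])
    print(wq.q_per_channel_axis())         # 0

    # dequantize back to floating point
    w_fp = wq.dequantize()
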

Quantized Operators/Modules
~~~~~~~~~~~~~~~~~~~~~~~~~~~

* Quantized operators are operators that take quantized Tensors as inputs and output quantized Tensors.
* Quantized modules are PyTorch Modules that perform quantized operations. They are typically defined for weighted operations like linear and conv.

Quantized Engine
~~~~~~~~~~~~~~~~~~~~

When a quantized model is executed, the qengine (torch.backends.quantized.engine) specifies which backend is to be used for execution. It is important to ensure that the qengine is compatible with the quantized model in terms of the value range of quantized activations and weights.

Quantization Flow
^^^^^^^^^^^^^^^^^^^^^^^

Observer and FakeQuantize
~~~~~~~~~~~~~~~~~~~~~~~~~~

* Observers are PyTorch Modules used to:

  * collect tensor statistics, like the min value and max value of the Tensor passing through the observer
  * calculate quantization parameters based on the collected tensor statistics

* FakeQuantize modules are PyTorch Modules used to:

  * simulate quantization (performing quantize/dequantize) for a Tensor in the network
  * calculate quantization parameters based on the statistics collected by an observer, or learn the quantization parameters directly

QConfig
~~~~~~~~~~~

* QConfig is a namedtuple of Observer or FakeQuantize Module classes that can be configured with qscheme, dtype etc. It is used to configure how an operator should be observed.

  * Quantization configuration for an operator/module

    * different types of Observer/FakeQuantize
    * dtype
    * qscheme
    * quant_min/quant_max: can be used to simulate lower precision Tensors

  * Currently supports configuration for activation and weight
  * We insert input/weight/output observers based on the qconfig that is configured for a given operator or module; a short example follows below.

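
For example, a custom QConfig can be assembled from observer classes using ``with_args`` to pin down the dtype and qscheme. This is a minimal sketch (the particular observer choices are just for illustration)::

    import torch
    from torch.ao.quantization import QConfig
    from torch.ao.quantization.observer import (
        MinMaxObserver,
        PerChannelMinMaxObserver,
    )

    # observe activations per tensor (affine, uint8) and weights per channel (symmetric, int8)
    my_qconfig = QConfig(
        activation=MinMaxObserver.with_args(
            dtype=torch.quint8, qscheme=torch.per_tensor_affine
        ),
        weight=PerChannelMinMaxObserver.with_args(
            dtype=torch.qint8, qscheme=torch.per_channel_symmetric
        ),
    )

    # attach it to a model so that prepare() will insert the corresponding observers
    model = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1), torch.nn.ReLU()).eval()
    model.qconfig = my_qconfig
    prepared = torch.ao.quantization.prepare(model)
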

General Quantization Flow
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In general, the flow is the following:

* prepare

  * insert Observer/FakeQuantize modules based on the user-specified qconfig

* calibrate/train (depending on post training quantization or quantization aware training)

  * allow Observers to collect statistics or FakeQuantize modules to learn the quantization parameters

* convert

  * convert a calibrated/trained model to a quantized model

There are different modes of quantization; they can be classified in two ways.

In terms of where we apply the quantization flow, we have:

1. Post Training Quantization (apply quantization after training; quantization parameters are calculated based on sample calibration data)
2. Quantization Aware Training (simulate quantization during training so that the quantization parameters can be learned together with the model using training data)

And in terms of how we quantize the operators, we can have:

- Weight Only Quantization (only the weight is statically quantized)
- Dynamic Quantization (the weight is statically quantized, the activation is dynamically quantized)
- Static Quantization (both weight and activations are statically quantized)

We can mix different ways of quantizing operators in the same quantization flow. For example, we can have post training quantization that has both statically and dynamically quantized operators; one way to express this is sketched below.

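
One way to express such a mix in FX Graph Mode Quantization is to assign different qconfigs to different module types in a single ``QConfigMapping``. A minimal sketch, assuming a toy model with both convolution and linear layers::

    import torch
    from torch.ao.quantization import (
        QConfigMapping,
        get_default_qconfig,
        default_dynamic_qconfig,
    )
    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 3, 1),
        torch.nn.Flatten(),
        torch.nn.Linear(48, 10),
    ).eval()

    # statically quantize convolutions, dynamically quantize linears
    qconfig_mapping = (
        QConfigMapping()
        .set_object_type(torch.nn.Conv2d, get_default_qconfig("x86"))
        .set_object_type(torch.nn.Linear, default_dynamic_qconfig)
    )

    example_inputs = (torch.randn(1, 3, 4, 4),)
    prepared = prepare_fx(model, qconfig_mapping, example_inputs)
    prepared(*example_inputs)          # calibration for the static part
    quantized = convert_fx(prepared)
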

Quantization Support Matrix
--------------------------------------

Quantization Mode Support
^^^^^^^^^^^^^^^^^^^^^^^^^^^

+-----------------------------+------------------------------------------------------+----------------+----------------+------------+-----------------+
|                             |Quantization                                          |Dataset         | Works Best For | Accuracy   | Notes           |
|                             |Mode                                                  |Requirement     |                |            |                 |
+-----------------------------+---------------------------------+--------------------+----------------+----------------+------------+-----------------+
|Post Training Quantization   |Dynamic/Weight Only Quantization |activation          |None            |LSTM, MLP,      |good        |Easy to use,     |
|                             |                                 |dynamically         |                |Embedding,      |            |close to static  |
|                             |                                 |quantized (fp16,    |                |Transformer     |            |quantization when|
|                             |                                 |int8) or not        |                |                |            |performance is   |
|                             |                                 |quantized, weight   |                |                |            |compute or memory|
|                             |                                 |statically quantized|                |                |            |bound due to     |
|                             |                                 |(fp16, int8, int4)  |                |                |            |weights          |
|                             +---------------------------------+--------------------+----------------+----------------+------------+-----------------+
|                             |Static Quantization              |activation and      |calibration     |CNN             |good        |Provides best    |
|                             |                                 |weights statically  |dataset         |                |            |perf, may have   |
|                             |                                 |quantized (int8)    |                |                |            |big impact on    |
|                             |                                 |                    |                |                |            |accuracy, good   |
|                             |                                 |                    |                |                |            |for hardware that|
|                             |                                 |                    |                |                |            |only supports    |
|                             |                                 |                    |                |                |            |int8 computation |
+-----------------------------+---------------------------------+--------------------+----------------+----------------+------------+-----------------+
|                             |Dynamic Quantization             |activation and      |fine-tuning     |MLP, Embedding  |best        |Limited support  |
|                             |                                 |weight are fake     |dataset         |                |            |for now          |
|                             |                                 |quantized           |                |                |            |                 |
|                             +---------------------------------+--------------------+----------------+----------------+------------+-----------------+
|                             |Static Quantization              |activation and      |fine-tuning     |CNN, MLP,       |best        |Typically used   |
|                             |                                 |weight are fake     |dataset         |Embedding       |            |when static      |
|                             |                                 |quantized           |                |                |            |quantization     |
|                             |                                 |                    |                |                |            |leads to bad     |
|                             |                                 |                    |                |                |            |accuracy, and    |
|                             |                                 |                    |                |                |            |used to close the|
|                             |                                 |                    |                |                |            |accuracy gap     |
|Quantization Aware Training  |                                 |                    |                |                |            |                 |
+-----------------------------+---------------------------------+--------------------+----------------+----------------+------------+-----------------+

Please see our `Introduction to Quantization on PyTorch
<https://pytorch.org/blog/introduction-to-quantization-on-pytorch/>`_ blog post
for a more comprehensive overview of the tradeoffs between these quantization
types.

Quantization Flow Support
^^^^^^^^^^^^^^^^^^^^^^^^^^^

PyTorch provides two modes of quantization: Eager Mode Quantization and FX Graph Mode Quantization.

Eager Mode Quantization is a beta feature. Users need to perform fusion and specify where quantization and dequantization happen manually; it also only supports modules, not functionals.

FX Graph Mode Quantization is an automated quantization framework in PyTorch, and currently it's a prototype feature. It improves upon Eager Mode Quantization by adding support for functionals and automating the quantization process, although users might need to refactor the model to make it compatible with FX Graph Mode Quantization (symbolically traceable with ``torch.fx``). Note that FX Graph Mode Quantization is not expected to work on arbitrary models, since a model might not be symbolically traceable. We will integrate it into domain libraries like torchvision, and users will be able to quantize models similar to the ones in supported domain libraries with FX Graph Mode Quantization. For arbitrary models we provide general guidelines, but to actually make it work users might need to be familiar with ``torch.fx``, especially with how to make a model symbolically traceable.

New users of quantization are encouraged to try out FX Graph Mode Quantization first; if it does not work, users may try to follow the guidelines in `using FX Graph Mode Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_quant_guide.html>`_ or fall back to eager mode quantization.

The following table compares the differences between Eager Mode Quantization and FX Graph Mode Quantization:

+-----------------+-------------------+-------------------+
|                 |Eager Mode         |FX Graph           |
|                 |Quantization       |Mode               |
|                 |                   |Quantization       |
+-----------------+-------------------+-------------------+
|Release          |beta               |prototype          |
|Status           |                   |                   |
+-----------------+-------------------+-------------------+
|Operator         |Manual             |Automatic          |
|Fusion           |                   |                   |
+-----------------+-------------------+-------------------+
|Quant/DeQuant    |Manual             |Automatic          |
|Placement        |                   |                   |
+-----------------+-------------------+-------------------+
|Quantizing       |Supported          |Supported          |
|Modules          |                   |                   |
+-----------------+-------------------+-------------------+
|Quantizing       |Manual             |Automatic          |
|Functionals/Torch|                   |                   |
|Ops              |                   |                   |
+-----------------+-------------------+-------------------+
|Support for      |Limited Support    |Fully              |
|Customization    |                   |Supported          |
+-----------------+-------------------+-------------------+
|Quantization Mode|Post Training      |Post Training      |
|Support          |Quantization:      |Quantization:      |
|                 |Static, Dynamic,   |Static, Dynamic,   |
|                 |Weight Only        |Weight Only        |
|                 |                   |                   |
|                 |Quantization Aware |Quantization Aware |
|                 |Training:          |Training:          |
|                 |Static             |Static             |
+-----------------+-------------------+-------------------+
|Input/Output     |``torch.nn.Module``|``torch.nn.Module``|
|Model Type       |                   |(May need some     |
|                 |                   |refactors to make  |
|                 |                   |the model          |
|                 |                   |compatible with FX |
|                 |                   |Graph Mode         |
|                 |                   |Quantization)      |
+-----------------+-------------------+-------------------+

Backend/Hardware Support
^^^^^^^^^^^^^^^^^^^^^^^^^^^

+-----------------+---------------+------------+------------+------------+
|Hardware         |Kernel Library |Eager Mode  |FX Graph    |Quantization|
|                 |               |Quantization|Mode        |Mode Support|
|                 |               |            |Quantization|            |
+-----------------+---------------+------------+------------+------------+
|server CPU       |fbgemm/onednn  |Supported                |All         |
|                 |               |                         |Supported   |
+-----------------+---------------+                         |            +
|mobile CPU       |qnnpack/xnnpack|                         |            |
|                 |               |                         |            |
+-----------------+---------------+------------+------------+------------+
|server GPU       |TensorRT (early|Not         |Supported   |Static      |
|                 |prototype)     |supported:  |            |Quantization|
|                 |               |requires a  |            |            |
|                 |               |graph       |            |            |
+-----------------+---------------+------------+------------+------------+

Today, PyTorch supports the following backends for running quantized operators efficiently:

* x86 CPUs with AVX2 support or higher (without AVX2 some operations have
  inefficient implementations), via `x86`, optimized by
  `fbgemm <https://github.com/pytorch/FBGEMM>`_ and `onednn <https://github.com/oneapi-src/oneDNN>`_ (see the details at `RFC <https://github.com/pytorch/pytorch/issues/83888>`_)
* ARM CPUs (typically found in mobile/embedded devices), via
  `qnnpack <https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native/quantized/cpu/qnnpack>`_
* (early prototype) support for NVidia GPU via `TensorRT <https://developer.nvidia.com/tensorrt>`_ through `fx2trt` (to be open sourced)

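
To see which of these backends are available in a given PyTorch build, and which one is currently selected, you can inspect the quantized engine settings; a small sketch (the engine list printed depends on your build)::

    import torch

    # engines compiled into this build, e.g. ['x86', 'fbgemm', 'onednn', 'qnnpack', 'none']
    print(torch.backends.quantized.supported_engines)

    # the engine currently used to run quantized operators
    print(torch.backends.quantized.engine)
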

Note for native CPU backends
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We expose both `x86` and `qnnpack` with the same native PyTorch quantized operators, so we need an additional flag to distinguish between them. The corresponding implementation of `x86` or `qnnpack` is chosen automatically based on the PyTorch build mode, though users have the option to override this by setting `torch.backends.quantized.engine` to `x86` or `qnnpack`.

When preparing a quantized model, it is necessary to ensure that the qconfig
and the engine used for quantized computations match the backend on which
the model will be executed. The qconfig controls the type of observers used
during the quantization passes. The qengine controls whether the `x86` or `qnnpack`
specific packing function is used when packing weights for
linear and convolution functions and modules. For example:

Default settings for x86::

    # set the qconfig for PTQ
    # Note: the old 'fbgemm' is still available but 'x86' is the recommended default on x86 CPUs
    qconfig = torch.ao.quantization.get_default_qconfig('x86')
    # or, set the qconfig for QAT
    qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')
    # set the qengine to control weight packing
    torch.backends.quantized.engine = 'x86'

Default settings for qnnpack::

    # set the qconfig for PTQ
    qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
    # or, set the qconfig for QAT
    qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
    # set the qengine to control weight packing
    torch.backends.quantized.engine = 'qnnpack'

Operator Support
^^^^^^^^^^^^^^^^^^^^

Operator coverage varies between dynamic and static quantization and is captured in the table below.
Note that for FX Graph Mode Quantization, the corresponding functionals are also supported.


+---------------------------+-------------------+--------------------+
|                           |Static             | Dynamic            |
|                           |Quantization       | Quantization       |
+---------------------------+-------------------+--------------------+
| | nn.Linear               | | Y               | | Y                |
| | nn.Conv1d/2d/3d         | | Y               | | N                |
+---------------------------+-------------------+--------------------+
| | nn.LSTM                 | | N               | | Y                |
| | nn.GRU                  | | N               | | Y                |
+---------------------------+-------------------+--------------------+
| | nn.RNNCell              | | N               | | Y                |
| | nn.GRUCell              | | N               | | Y                |
| | nn.LSTMCell             | | N               | | Y                |
+---------------------------+-------------------+--------------------+
|nn.EmbeddingBag            | Y (activations    |                    |
|                           | are in fp32)      | Y                  |
+---------------------------+-------------------+--------------------+
|nn.Embedding               | Y                 | Y                  |
+---------------------------+-------------------+--------------------+
|nn.MultiheadAttention      |Not Supported      | Not supported      |
+---------------------------+-------------------+--------------------+
|Activations                |Broadly supported  | Unchanged,         |
|                           |                   | computations       |
|                           |                   | stay in fp32       |
+---------------------------+-------------------+--------------------+

Note: this will be updated with some information generated from native backend_config_dict soon.

Quantization API Reference
---------------------------

The :doc:`Quantization API Reference <quantization-support>` contains documentation
of quantization APIs, such as quantization passes, quantized tensor operations,
and supported quantized modules and functions.

.. toctree::
   :hidden:

   quantization-support

Quantization Backend Configuration
----------------------------------

The :doc:`Quantization Backend Configuration <quantization-backend-configuration>` contains documentation
on how to configure the quantization workflows for various backends.

.. toctree::
   :hidden:

   quantization-backend-configuration

Quantization Accuracy Debugging
-------------------------------

The :doc:`Quantization Accuracy Debugging <quantization-accuracy-debugging>` contains documentation
on how to debug quantization accuracy.

.. toctree::
   :hidden:

   quantization-accuracy-debugging

Quantization Customizations
---------------------------

While default implementations of observers to select the scale factor and bias
based on observed tensor data are provided, developers can provide their own
quantization functions. Quantization can be applied selectively to different
parts of the model or configured differently for different parts of the model.

We also provide support for per channel quantization for **conv1d()**, **conv2d()**,
**conv3d()** and **linear()**.

Quantization workflows work by adding (e.g. adding observers as
``.observer`` submodules) or replacing (e.g. converting ``nn.Conv2d`` to
``nn.quantized.Conv2d``) submodules in the model's module hierarchy. This
means that the model stays a regular ``nn.Module``-based instance throughout the
process and thus can work with the rest of the PyTorch APIs.

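
The sketch below makes this concrete: after ``prepare`` the modules carry observer submodules, and after ``convert`` they are replaced by their quantized counterparts (the tiny model and calibration data here are only illustrative)::

    import torch

    model = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1), torch.nn.ReLU()).eval()
    model.qconfig = torch.ao.quantization.get_default_qconfig('x86')

    prepared = torch.ao.quantization.prepare(model)
    print(prepared)   # observers are attached as submodules of Conv2d/ReLU

    # calibrate with representative data (a single random batch here)
    prepared(torch.randn(1, 1, 4, 4))

    converted = torch.ao.quantization.convert(prepared)
    print(converted)  # nn.Conv2d has been replaced by a quantized Conv2d module
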

Quantization Custom Module API
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Both Eager mode and FX graph mode quantization APIs provide a hook for the user
to specify a module that is quantized in a custom way, with user defined logic for
observation and quantization. The user needs to specify:

1. The Python type of the source fp32 module (existing in the model)
2. The Python type of the observed module (provided by user). This module needs
   to define a `from_float` function which defines how the observed module is
   created from the original fp32 module.
3. The Python type of the quantized module (provided by user). This module needs
   to define a `from_observed` function which defines how the quantized module is
   created from the observed module.
4. A configuration describing (1), (2), (3) above, passed to the quantization APIs.


The framework will then do the following:

1. during the `prepare` module swaps, it will convert every module of the type
   specified in (1) to the type specified in (2), using the `from_float` function of
   the class in (2).
2. during the `convert` module swaps, it will convert every module of the type
   specified in (2) to the type specified in (3), using the `from_observed` function
   of the class in (3).

Currently, there is a requirement that `ObservedCustomModule` will have a single
Tensor output, and an observer will be added by the framework (not by the user)
on that output. The observer will be stored under the `activation_post_process` key
as an attribute of the custom module instance. Relaxing these restrictions may
be done at a future time.

Custom API Example::

    import torch
    import torch.ao.nn.quantized as nnq
    from torch.ao.quantization import QConfigMapping
    import torch.ao.quantization.quantize_fx

    # original fp32 module to replace
    class CustomModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            return self.linear(x)

    # custom observed module, provided by user
    class ObservedCustomModule(torch.nn.Module):
        def __init__(self, linear):
            super().__init__()
            self.linear = linear

        def forward(self, x):
            return self.linear(x)

        @classmethod
        def from_float(cls, float_module):
            assert hasattr(float_module, 'qconfig')
            observed = cls(float_module.linear)
            observed.qconfig = float_module.qconfig
            return observed

    # custom quantized module, provided by user
    class StaticQuantCustomModule(torch.nn.Module):
        def __init__(self, linear):
            super().__init__()
            self.linear = linear

        def forward(self, x):
            return self.linear(x)

        @classmethod
        def from_observed(cls, observed_module):
            assert hasattr(observed_module, 'qconfig')
            assert hasattr(observed_module, 'activation_post_process')
            observed_module.linear.activation_post_process = \
                observed_module.activation_post_process
            quantized = cls(nnq.Linear.from_float(observed_module.linear))
            return quantized

    #
    # example API call (Eager mode quantization)
    #

    m = torch.nn.Sequential(CustomModule()).eval()
    prepare_custom_config_dict = {
        "float_to_observed_custom_module_class": {
            CustomModule: ObservedCustomModule
        }
    }
    convert_custom_config_dict = {
        "observed_to_quantized_custom_module_class": {
            ObservedCustomModule: StaticQuantCustomModule
        }
    }
    m.qconfig = torch.ao.quantization.default_qconfig
    mp = torch.ao.quantization.prepare(
        m, prepare_custom_config_dict=prepare_custom_config_dict)
    # calibration (not shown)
    mq = torch.ao.quantization.convert(
        mp, convert_custom_config_dict=convert_custom_config_dict)

    #
    # example API call (FX graph mode quantization)
    #
    m = torch.nn.Sequential(CustomModule()).eval()
    qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_qconfig)
    prepare_custom_config_dict = {
        "float_to_observed_custom_module_class": {
            "static": {
                CustomModule: ObservedCustomModule,
            }
        }
    }
    convert_custom_config_dict = {
        "observed_to_quantized_custom_module_class": {
            "static": {
                ObservedCustomModule: StaticQuantCustomModule,
            }
        }
    }
    example_inputs = (torch.randn(3, 3),)
    mp = torch.ao.quantization.quantize_fx.prepare_fx(
        m, qconfig_mapping, example_inputs, prepare_custom_config=prepare_custom_config_dict)
    # calibration (not shown)
    mq = torch.ao.quantization.quantize_fx.convert_fx(
        mp, convert_custom_config=convert_custom_config_dict)

Best Practices
--------------

1. If you are using the ``x86`` backend, you need to use 7 bits instead of 8 bits. Make sure you reduce the range of ``quant_min`` and ``quant_max``:
   if ``dtype`` is ``torch.quint8``, make sure to set a custom ``quant_min`` to be ``0`` and ``quant_max`` to be ``127`` (``255`` / ``2``);
   if ``dtype`` is ``torch.qint8``, make sure to set a custom ``quant_min`` to be ``-64`` (``-128`` / ``2``) and ``quant_max`` to be ``63`` (``127`` / ``2``). This is already set correctly if
   you call the `torch.ao.quantization.get_default_qconfig(backend)` or `torch.ao.quantization.get_default_qat_qconfig(backend)` functions to get the default ``qconfig`` for the
   ``x86`` or ``qnnpack`` backend. A sketch of an explicitly reduced-range qconfig follows below.

2. If the ``onednn`` backend is selected, 8 bits for activations will be used in the default qconfig mapping ``torch.ao.quantization.get_default_qconfig_mapping('onednn')``
   and default qconfig ``torch.ao.quantization.get_default_qconfig('onednn')``. It is recommended to be used on CPUs with Vector Neural Network Instruction (VNNI)
   support. Otherwise, set ``reduce_range`` to True for the activation's observer to get better accuracy on CPUs without VNNI support.

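
As an illustration, this is a minimal sketch of a qconfig that explicitly restricts activations to 7 bits; the particular observer choices below are only an example, and the defaults returned by ``get_default_qconfig('x86')`` already handle this for you::

    import torch
    from torch.ao.quantization import QConfig
    from torch.ao.quantization.observer import (
        HistogramObserver,
        PerChannelMinMaxObserver,
    )

    # activation observer restricted to the 7 bit range [0, 127] for quint8
    reduced_range_qconfig = QConfig(
        activation=HistogramObserver.with_args(
            dtype=torch.quint8, quant_min=0, quant_max=127
        ),
        # weights keep the usual per channel symmetric int8 settings
        weight=PerChannelMinMaxObserver.with_args(
            dtype=torch.qint8, qscheme=torch.per_channel_symmetric
        ),
    )
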

Frequently Asked Questions
--------------------------

1. How can I do quantized inference on GPU?

   We don't have official GPU support yet, but this is an area of active development. You can find more information
   `here <https://github.com/pytorch/pytorch/issues/87395>`_.

2. Where can I get ONNX support for my quantized model?

   If you get errors exporting the model (using APIs under ``torch.onnx``), you may open an issue in the PyTorch repository. Prefix the issue title with ``[ONNX]`` and tag the issue as ``module: onnx``.

   If you encounter issues with ONNX Runtime, open an issue at `GitHub - microsoft/onnxruntime <https://github.com/microsoft/onnxruntime/issues/>`_.

3. How can I use quantization with LSTMs?

   LSTM is supported through our custom module API in both eager mode and FX graph mode quantization. Examples can be found at
   Eager Mode: `pytorch/test_quantized_op.py TestQuantizedOps.test_custom_module_lstm <https://github.com/pytorch/pytorch/blob/9b88dcf248e717ca6c3f8c5e11f600825547a561/test/quantization/core/test_quantized_op.py#L2782>`_
   FX Graph Mode: `pytorch/test_quantize_fx.py TestQuantizeFx.test_static_lstm <https://github.com/pytorch/pytorch/blob/9b88dcf248e717ca6c3f8c5e11f600825547a561/test/quantization/fx/test_quantize_fx.py#L4116>`_

Common Errors
---------------------------------------

Passing a non-quantized Tensor into a quantized kernel
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

If you see an error similar to::

    RuntimeError: Could not run 'quantized::some_operator' with arguments from the 'CPU' backend...

This means that you are trying to pass a non-quantized Tensor to a quantized
kernel. A common workaround is to use ``torch.ao.quantization.QuantStub`` to
quantize the tensor. This needs to be done manually in Eager mode quantization.
An e2e example::

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = torch.ao.quantization.QuantStub()
            self.conv = torch.nn.Conv2d(1, 1, 1)

        def forward(self, x):
            # during the convert step, this will be replaced with a
            # `quantize_per_tensor` call
            x = self.quant(x)
            x = self.conv(x)
            return x

Passing a quantized Tensor into a non-quantized kernel
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

If you see an error similar to::

    RuntimeError: Could not run 'aten::thnn_conv2d_forward' with arguments from the 'QuantizedCPU' backend.

This means that you are trying to pass a quantized Tensor to a non-quantized
kernel. A common workaround is to use ``torch.ao.quantization.DeQuantStub`` to
dequantize the tensor. This needs to be done manually in Eager mode quantization.
An e2e example::

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = torch.ao.quantization.QuantStub()
            self.conv1 = torch.nn.Conv2d(1, 1, 1)
            # this module will not be quantized (see `qconfig = None` logic below)
            self.conv2 = torch.nn.Conv2d(1, 1, 1)
            self.dequant = torch.ao.quantization.DeQuantStub()

        def forward(self, x):
            # during the convert step, this will be replaced with a
            # `quantize_per_tensor` call
            x = self.quant(x)
            x = self.conv1(x)
            # during the convert step, this will be replaced with a
            # `dequantize` call
            x = self.dequant(x)
            x = self.conv2(x)
            return x

    m = M()
    m.qconfig = some_qconfig
    # turn off quantization for conv2
    m.conv2.qconfig = None

Saving and Loading Quantized models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When calling ``torch.load`` on a quantized model, if you see an error like::

    AttributeError: 'LinearPackedParams' object has no attribute '_modules'

This is because directly saving and loading a quantized model using ``torch.save`` and ``torch.load``
is not supported. To save/load quantized models, the following ways can be used:


1. Saving/Loading the quantized model state_dict

An example::

    import io
    import torch
    import torch.nn as nn
    from torch.ao.quantization import default_qconfig
    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(5, 5)
            self.relu = nn.ReLU()

        def forward(self, x):
            x = self.linear(x)
            x = self.relu(x)
            return x

    m = M().eval()
    prepare_orig = prepare_fx(m, {'': default_qconfig}, example_inputs=(torch.rand(5, 5),))
    prepare_orig(torch.rand(5, 5))
    quantized_orig = convert_fx(prepare_orig)

    # Save/load using state_dict
    b = io.BytesIO()
    torch.save(quantized_orig.state_dict(), b)

    m2 = M().eval()
    prepared = prepare_fx(m2, {'': default_qconfig}, example_inputs=(torch.rand(5, 5),))
    quantized = convert_fx(prepared)
    b.seek(0)
    quantized.load_state_dict(torch.load(b))

2. Saving/Loading scripted quantized models using ``torch.jit.save`` and ``torch.jit.load``

An example::

    # Note: using the same model M from the previous example
    m = M().eval()
    prepare_orig = prepare_fx(m, {'': default_qconfig}, example_inputs=(torch.rand(5, 5),))
    prepare_orig(torch.rand(5, 5))
    quantized_orig = convert_fx(prepare_orig)

    # save/load using scripted model
    scripted = torch.jit.script(quantized_orig)
    b = io.BytesIO()
    torch.jit.save(scripted, b)
    b.seek(0)
    scripted_quantized = torch.jit.load(b)

Symbolic Trace Error when using FX Graph Mode Quantization
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Symbolic traceability is a requirement for `(Prototype - maintenance mode) FX Graph Mode Quantization`_, so if you pass a PyTorch model that is not symbolically traceable to `torch.ao.quantization.prepare_fx` or `torch.ao.quantization.prepare_qat_fx`, you might see an error like the following::

    torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow

Please take a look at `Limitations of Symbolic Tracing <https://pytorch.org/docs/2.0/fx.html#limitations-of-symbolic-tracing>`_ and use the `User Guide on Using FX Graph Mode Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_quant_guide.html>`_ to work around the problem.


.. torch.ao is missing documentation. Since part of it is mentioned here, adding them here for now.
.. They are here for tracking purposes until they are more permanently fixed.
.. py:module:: torch.ao
.. py:module:: torch.ao.nn
.. py:module:: torch.ao.nn.quantizable
.. py:module:: torch.ao.nn.quantizable.modules
.. py:module:: torch.ao.nn.quantized
.. py:module:: torch.ao.nn.quantized.reference
.. py:module:: torch.ao.nn.quantized.reference.modules
.. py:module:: torch.ao.nn.sparse
.. py:module:: torch.ao.nn.sparse.quantized
.. py:module:: torch.ao.nn.sparse.quantized.dynamic
.. py:module:: torch.ao.ns
.. py:module:: torch.ao.ns.fx
.. py:module:: torch.ao.quantization.backend_config
.. py:module:: torch.ao.pruning
.. py:module:: torch.ao.pruning.scheduler
.. py:module:: torch.ao.pruning.sparsifier
.. py:module:: torch.ao.nn.intrinsic.modules.fused
.. py:module:: torch.ao.nn.intrinsic.qat.modules.conv_fused
.. py:module:: torch.ao.nn.intrinsic.qat.modules.linear_fused
.. py:module:: torch.ao.nn.intrinsic.qat.modules.linear_relu
.. py:module:: torch.ao.nn.intrinsic.quantized.dynamic.modules.linear_relu
.. py:module:: torch.ao.nn.intrinsic.quantized.modules.bn_relu
.. py:module:: torch.ao.nn.intrinsic.quantized.modules.conv_add

.. torch.ao is missing documentation. Since part of it is mentioned here, adding them here for now.
.. They are here for tracking purposes until they are more permanently fixed.
.. py:module:: torch.ao
.. py:module:: torch.ao.nn
.. py:module:: torch.ao.nn.quantizable
.. py:module:: torch.ao.nn.quantizable.modules
.. py:module:: torch.ao.nn.quantized
.. py:module:: torch.ao.nn.quantized.reference
.. py:module:: torch.ao.nn.quantized.reference.modules
.. py:module:: torch.ao.nn.sparse
.. py:module:: torch.ao.nn.sparse.quantized
.. py:module:: torch.ao.nn.sparse.quantized.dynamic
.. py:module:: torch.ao.ns
.. py:module:: torch.ao.ns.fx
.. py:module:: torch.ao.quantization.backend_config
.. py:module:: torch.ao.pruning
.. py:module:: torch.ao.pruning.scheduler
.. py:module:: torch.ao.pruning.sparsifier
.. py:module:: torch.ao.nn.intrinsic.modules.fused
.. py:module:: torch.ao.nn.intrinsic.qat.modules.conv_fused
.. py:module:: torch.ao.nn.intrinsic.qat.modules.linear_fused
.. py:module:: torch.ao.nn.intrinsic.qat.modules.linear_relu
.. py:module:: torch.ao.nn.intrinsic.quantized.dynamic.modules.linear_relu
.. py:module:: torch.ao.nn.intrinsic.quantized.modules.bn_relu
.. py:module:: torch.ao.nn.intrinsic.quantized.modules.conv_add
.. py:module:: torch.ao.nn.intrinsic.quantized.modules.conv_relu
.. py:module:: torch.ao.nn.intrinsic.quantized.modules.linear_relu
.. py:module:: torch.ao.nn.qat.dynamic.modules.linear
.. py:module:: torch.ao.nn.qat.modules.conv
.. py:module:: torch.ao.nn.qat.modules.embedding_ops
.. py:module:: torch.ao.nn.qat.modules.linear
.. py:module:: torch.ao.nn.quantizable.modules.activation
.. py:module:: torch.ao.nn.quantizable.modules.rnn
.. py:module:: torch.ao.nn.quantized.dynamic.modules.conv
.. py:module:: torch.ao.nn.quantized.dynamic.modules.linear
.. py:module:: torch.ao.nn.quantized.dynamic.modules.rnn
.. py:module:: torch.ao.nn.quantized.modules.activation
.. py:module:: torch.ao.nn.quantized.modules.batchnorm
.. py:module:: torch.ao.nn.quantized.modules.conv
.. py:module:: torch.ao.nn.quantized.modules.dropout
.. py:module:: torch.ao.nn.quantized.modules.embedding_ops
.. py:module:: torch.ao.nn.quantized.modules.functional_modules
.. py:module:: torch.ao.nn.quantized.modules.linear
.. py:module:: torch.ao.nn.quantized.modules.normalization
.. py:module:: torch.ao.nn.quantized.modules.rnn
.. py:module:: torch.ao.nn.quantized.modules.utils
.. py:module:: torch.ao.nn.quantized.reference.modules.conv
.. py:module:: torch.ao.nn.quantized.reference.modules.linear
.. py:module:: torch.ao.nn.quantized.reference.modules.rnn
.. py:module:: torch.ao.nn.quantized.reference.modules.sparse
.. py:module:: torch.ao.nn.quantized.reference.modules.utils
.. py:module:: torch.ao.nn.sparse.quantized.dynamic.linear
.. py:module:: torch.ao.nn.sparse.quantized.linear
.. py:module:: torch.ao.nn.sparse.quantized.utils
.. py:module:: torch.ao.ns.fx.graph_matcher
.. py:module:: torch.ao.ns.fx.graph_passes
.. py:module:: torch.ao.ns.fx.mappings
.. py:module:: torch.ao.ns.fx.n_shadows_utils
.. py:module:: torch.ao.ns.fx.ns_types
.. py:module:: torch.ao.ns.fx.pattern_utils
.. py:module:: torch.ao.ns.fx.qconfig_multi_mapping
.. py:module:: torch.ao.ns.fx.utils
.. py:module:: torch.ao.ns.fx.weight_utils
.. py:module:: torch.ao.pruning.scheduler.base_scheduler
.. py:module:: torch.ao.pruning.scheduler.cubic_scheduler
.. py:module:: torch.ao.pruning.scheduler.lambda_scheduler
.. py:module:: torch.ao.pruning.sparsifier.base_sparsifier
.. py:module:: torch.ao.pruning.sparsifier.nearly_diagonal_sparsifier
.. py:module:: torch.ao.pruning.sparsifier.utils
.. py:module:: torch.ao.pruning.sparsifier.weight_norm_sparsifier
.. py:module:: torch.ao.quantization.backend_config.backend_config
.. py:module:: torch.ao.quantization.backend_config.executorch
.. py:module:: torch.ao.quantization.backend_config.fbgemm
.. py:module:: torch.ao.quantization.backend_config.native
.. py:module:: torch.ao.quantization.backend_config.observation_type
.. py:module:: torch.ao.quantization.backend_config.onednn
.. py:module:: torch.ao.quantization.backend_config.qnnpack
.. py:module:: torch.ao.quantization.backend_config.tensorrt
.. py:module:: torch.ao.quantization.backend_config.utils
.. py:module:: torch.ao.quantization.backend_config.x86
.. py:module:: torch.ao.quantization.fake_quantize
.. py:module:: torch.ao.quantization.fuser_method_mappings
.. py:module:: torch.ao.quantization.fuse_modules
.. py:module:: torch.ao.quantization.fx.convert
.. py:module:: torch.ao.quantization.fx.custom_config
.. py:module:: torch.ao.quantization.fx.fuse
.. py:module:: torch.ao.quantization.fx.fuse_handler
.. py:module:: torch.ao.quantization.fx.graph_module
.. py:module:: torch.ao.quantization.fx.lower_to_fbgemm
.. py:module:: torch.ao.quantization.fx.lower_to_qnnpack
.. py:module:: torch.ao.quantization.fx.lstm_utils
.. py:module:: torch.ao.quantization.fx.match_utils
.. py:module:: torch.ao.quantization.fx.pattern_utils
.. py:module:: torch.ao.quantization.fx.prepare
.. py:module:: torch.ao.quantization.fx.qconfig_mapping_utils
.. py:module:: torch.ao.quantization.fx.quantize_handler
.. py:module:: torch.ao.quantization.fx.tracer
.. py:module:: torch.ao.quantization.fx.utils
.. py:module:: torch.ao.quantization.observer
.. py:module:: torch.ao.quantization.pt2e.duplicate_dq_pass
.. py:module:: torch.ao.quantization.pt2e.export_utils
.. py:module:: torch.ao.quantization.pt2e.graph_utils
.. py:module:: torch.ao.quantization.pt2e.port_metadata_pass
.. py:module:: torch.ao.quantization.pt2e.prepare
.. py:module:: torch.ao.quantization.pt2e.qat_utils
.. py:module:: torch.ao.quantization.pt2e.representation.rewrite
.. py:module:: torch.ao.quantization.pt2e.utils
.. py:module:: torch.ao.quantization.qconfig
.. py:module:: torch.ao.quantization.qconfig_mapping
.. py:module:: torch.ao.quantization.quant_type
.. py:module:: torch.ao.quantization.quantization_mappings
.. py:module:: torch.ao.quantization.quantize_fx
.. py:module:: torch.ao.quantization.quantize_jit
.. py:module:: torch.ao.quantization.quantize_pt2e
.. py:module:: torch.ao.quantization.quantizer.composable_quantizer
.. py:module:: torch.ao.quantization.quantizer.embedding_quantizer
.. py:module:: torch.ao.quantization.quantizer.quantizer
.. py:module:: torch.ao.quantization.quantizer.utils
.. py:module:: torch.ao.quantization.quantizer.x86_inductor_quantizer
.. py:module:: torch.ao.quantization.quantizer.xnnpack_quantizer
.. py:module:: torch.ao.quantization.quantizer.xnnpack_quantizer_utils
.. py:module:: torch.ao.quantization.stubs
.. py:module:: torch.ao.quantization.utils
.. py:module:: torch.nn.intrinsic.modules.fused
.. py:module:: torch.nn.intrinsic.qat.modules.conv_fused
.. py:module:: torch.nn.intrinsic.qat.modules.linear_fused
.. py:module:: torch.nn.intrinsic.qat.modules.linear_relu
.. py:module:: torch.nn.intrinsic.quantized.dynamic.modules.linear_relu
.. py:module:: torch.nn.intrinsic.quantized.modules.bn_relu
.. py:module:: torch.nn.intrinsic.quantized.modules.conv_relu
.. py:module:: torch.nn.intrinsic.quantized.modules.linear_relu
.. py:module:: torch.nn.qat.dynamic.modules.linear
.. py:module:: torch.nn.qat.modules.conv
.. py:module:: torch.nn.qat.modules.embedding_ops
.. py:module:: torch.nn.qat.modules.linear
.. py:module:: torch.nn.quantizable.modules.activation
.. py:module:: torch.nn.quantizable.modules.rnn
.. py:module:: torch.nn.quantized.dynamic.modules.conv
.. py:module:: torch.nn.quantized.dynamic.modules.linear
.. py:module:: torch.nn.quantized.dynamic.modules.rnn
.. py:module:: torch.nn.quantized.functional
.. py:module:: torch.nn.quantized.modules.activation
.. py:module:: torch.nn.quantized.modules.batchnorm
.. py:module:: torch.nn.quantized.modules.conv
.. py:module:: torch.nn.quantized.modules.dropout
.. py:module:: torch.nn.quantized.modules.embedding_ops
.. py:module:: torch.nn.quantized.modules.functional_modules
.. py:module:: torch.nn.quantized.modules.linear
.. py:module:: torch.nn.quantized.modules.normalization
.. py:module:: torch.nn.quantized.modules.rnn
.. py:module:: torch.nn.quantized.modules.utils
.. py:module:: torch.quantization.fake_quantize
.. py:module:: torch.quantization.fuse_modules
.. py:module:: torch.quantization.fuser_method_mappings
.. py:module:: torch.quantization.fx.convert
.. py:module:: torch.quantization.fx.fuse
.. py:module:: torch.quantization.fx.fusion_patterns
.. py:module:: torch.quantization.fx.graph_module
.. py:module:: torch.quantization.fx.match_utils
.. py:module:: torch.quantization.fx.pattern_utils
.. py:module:: torch.quantization.fx.prepare
.. py:module:: torch.quantization.fx.quantization_patterns
.. py:module:: torch.quantization.fx.quantization_types
.. py:module:: torch.quantization.fx.utils
.. py:module:: torch.quantization.observer
.. py:module:: torch.quantization.qconfig
.. py:module:: torch.quantization.quant_type
.. py:module:: torch.quantization.quantization_mappings
.. py:module:: torch.quantization.quantize
.. py:module:: torch.quantization.quantize_fx
.. py:module:: torch.quantization.quantize_jit
.. py:module:: torch.quantization.stubs
.. py:module:: torch.quantization.utils