1# FX Graph Mode Quantization Design Doc
2<!---
3```
4float_model QConfigMapping BackendConfig
5 \ | /
6 \ | /
7 \ | /
8(prepare_fx/prepare_qat_fx) /
9—-------------------------------------------------------
10| Fuse |
11| QAT Module Swap |
12| Insert Observers |
13—-------------------------------------------------------
14 |
15 Calibrate/Train
16 |
17(convert_fx) |
18—--------------------------------------------------------
19| Convert |
20| Lowering |
21—--------------------------------------------------------
22 |
23 Quantized Model
24```
25-->
26
27```mermaid
28---
29title: High Level FX Graph Mode Quantization Flow
30---
31flowchart TD
32 classDef nofs fill:none,stroke:none
33 classDef sub fill:#D6EAF8,stroke:none
34 float_model:::nofs --> prepare_fx:::sub
35 QConfigMapping:::nofs --> prepare_fx
36 BackendConfig:::nofs --> prepare_fx
37 subgraph prepare_fx["`_(prepare_fx/prepare_qat_fx)_`"]
38 Fuse:::nofs --> swap[QAT Module Swap]:::nofs --> obs[Insert Observers]:::nofs
39 end
40 prepare_fx --> Calibrate/Train:::nofs --> convert_fx:::sub
41 subgraph convert_fx["`_(convert_fx)_`"]
42 Convert:::nofs --> Lowering:::nofs
43 end
44 convert_fx --> qm[Quantized Model]:::nofs
45```
46
Please refer to [TODO: link] for definitions of the terminology used in this doc.
48
49## Overview
The FX graph representation is close to python/eager mode: it preserves many python/eager mode constructs such as modules, functionals and torch ops. As a result, the implementation reuses building blocks and utilities from eager mode quantization, including QConfig, QConfig propagation (might be removed), fused modules, QAT modules, quantized modules and the QAT module swapping utility. The overall flow also matches eager mode quantization exactly; the only difference is that transformations like fusion and inserting stubs are fully automated and controlled by QConfigMapping and BackendConfig.
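
To make the overall flow concrete, here is a minimal usage sketch of the public API described above (the toy model, calibration loop and the `"fbgemm"` backend choice are illustrative placeholders, not part of the design):
```
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# toy float model (stands in for a real model)
float_model = torch.nn.Sequential(torch.nn.Linear(5, 10), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 5),)

# prepare_fx: fuse, (for QAT) swap to QAT modules, insert observers
qconfig_mapping = get_default_qconfig_mapping("fbgemm")
prepared = prepare_fx(float_model, qconfig_mapping, example_inputs)

# calibrate with representative data (placeholder loop)
with torch.no_grad():
    for _ in range(10):
        prepared(torch.randn(1, 5))

# convert_fx: convert to a reference quantized model, then lower
quantized = convert_fx(prepared)
```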
51
52## High Level Flow with Simple Example
53
54`prepare_fx`:
55```
56Floating Point Model --> (1.1 `_fuse_fx`) --> Fused Model
57 --> (1.2 QAT Module Swap) --> Model with QAT modules
58 --> (1.3 Insert Observers) --> Prepared Model
59```
60
61`convert_fx`:
62```
Prepared Model --> (3.1 `convert_to_reference`) --> Reference Quantized Model
               --> (3.2 Lower to Native Backend) --> Quantized Model
65```
66
In the following sections we first give a detailed description of each step, and then discuss the corresponding settings in BackendConfig. Throughout this doc we follow the terminology defined in the (draft) README.md of quantization syntax transforms.
68
69### 0. Original Model
70
71```
72class LinearReLUModule(torch.nn.Module):
73 def __init__(self) -> None:
74 super().__init__()
75 self.linear = torch.nn.Linear(5, 10).float()
76 self.relu = torch.nn.ReLU()
77
78 def forward(self, x):
79 return self.relu(self.linear(x))
80```
81
82### 1.1 Fusion
83```
84fused: GraphModule(
85 (linear): LinearReLU(
86 (0): Linear(in_features=5, out_features=10, bias=True)
87 (1): ReLU()
88 )
89)
90
91def forward(self, x):
92 linear = self.linear(x); x = None
93 return linear
94```
95
What we did in this example:

* Identify (Linear - ReLU) subgraphs by searching through the model graph
* For each identified subgraph, replace the `root_node` (typically the weighted module in the pattern, e.g. Linear) with a fused module by calling the fuser_method for this pattern; the fused module is a sequential of a few modules, e.g. nni.LinearReLU is a sequential of the linear and relu modules (see the illustration right after this list)
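
A minimal sketch of what the fused module is (just an illustration of the data structure, not part of the pass itself): `nni.LinearReLU` is an `nn.Sequential` subclass that holds the original linear and relu, so it is numerically identical to the unfused pair.
```
import torch
import torch.ao.nn.intrinsic as nni

linear = torch.nn.Linear(5, 10)
relu = torch.nn.ReLU()
fused = nni.LinearReLU(linear, relu)  # a Sequential holding (linear, relu)

x = torch.randn(2, 5)
assert torch.equal(fused(x), relu(linear(x)))
```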
100
101`backend_config` configurations relevant to this step are:
102
103```
104def fuse_linear_relu(is_qat, linear, relu):
105 return nni.LinearReLU(linear, relu)
106
107BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU))
108 .set_fuser_method(fuse_linear_relu)
109 ._set_root_node_getter(my_root_node_getter)
110 ._set_extra_inputs_getter(my_extra_inputs_getter)
111```
112
113
`BackendPatternConfig` takes in a pattern that specifies the fusion pattern we want to search for. The pattern format can be found in https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md
115
`set_dtype_configs`: dtype_configs are used to check against the qconfig for the pattern, to see whether the qconfig is supported by the target backend. They are currently not used in fusion, but we may add this check in the future (or remove the check and always fuse these patterns).
`set_fuser_method`: specifies the fuser method to use for the pattern. A fuser method takes the matched objects and fuses them into a fused module.
`_set_root_node_getter`: sets a function that takes a node pattern and returns the root node in the pattern.
`_set_extra_inputs_getter`: all input args of the root node will be copied over to the fused module; if there are extra inputs, this function returns a list of extra inputs given the pattern.
120
121Example usage of `root_node_getter` and `extra_input_getter`: https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6
122
123### 1.2 QAT Module Swap
124```
125GraphModule(
126 (linear): LinearReLU(
127 in_features=5, out_features=10, bias=True
128 (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
129 )
130)
131
132def forward(self, x):
133 linear = self.linear(x); x = None
134 return linear
135```
136
In this step we swap each fused module to its QAT module, for example, swap nn.intrinsic.LinearReLU instances to nn.intrinsic.qat.LinearReLU modules, where we fake quantize the weight of the linear.
For modules that have corresponding QAT modules, we call the eager mode `convert` function with a mapping from float module to QAT module, which swaps all float modules (and fused modules) with QAT modules. This step is exactly the same as in eager mode quantization, just called inside the `prepare_fx/prepare_qat_fx` function.
139
140`backend_config` configurations relevant in this step are:
141```
142BackendPatternConfig(nni.LinearReLU)
143 .set_qat_module(nniqat.LinearReLU)
144```
145
The pattern used to initialize BackendPatternConfig is the class type of the original or fused floating point module.
`set_qat_module` sets the QAT module class corresponding to the module class specified in the pattern.
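
To illustrate what the swap does for one module (a hedged sketch; FX graph mode quantization performs this automatically for every matched module), the QAT class's `from_float` constructor takes a float (fused) module that carries a `qconfig` and returns the QAT counterpart with fake quantization attached to the weight:
```
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
from torch.ao.quantization import get_default_qat_qconfig

float_fused = nni.LinearReLU(torch.nn.Linear(5, 10), torch.nn.ReLU())
float_fused.qconfig = get_default_qat_qconfig("fbgemm")

qat_fused = nniqat.LinearReLU.from_float(float_fused)
print(qat_fused.weight_fake_quant)  # fake quant module observing the linear weight
```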
148
149### 1.3 QuantDeQuantStub and Observer/FakeQuantize Insertion
150```
151GraphModule(
152 (activation_post_process_0): MinMaxObserver(min_val=inf, max_val=-inf)
153 (linear): LinearReLU(
154 (0): Linear(in_features=5, out_features=10, bias=True)
155 (1): ReLU()
156 )
157 (activation_post_process_1): MinMaxObserver(min_val=inf, max_val=-inf)
158)
159
160def forward(self, x):
161 activation_post_process_0 = self.activation_post_process_0(x); x = None
162 linear = self.linear(activation_post_process_0); activation_post_process_0 = None
163 activation_post_process_1 = self.activation_post_process_1(linear); linear = None
164 return activation_post_process_1
165```
166
167Note: activation_post_process_0 and activation_post_process_1 will be updated with QuantDeQuantStub
168
QuantDeQuantStubs are inserted based on the `qconfig_mapping` provided by users. We also have a backend_config that specifies the configs that are supported by the backend. In this step, we will:
* Check whether the `qconfig_mapping` is compatible with the `backend_config`; if the user requested a qconfig that is not compatible with `backend_config`, we do not insert observers for the operator, and the config is simply ignored (an example `qconfig_mapping` is shown right after this list)
* Insert observers for the input and output of the subgraph, based on the `qconfig_mapping` (what the user requested) and the `backend_config` (how the operator should be observed in a backend)
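
For reference, a `qconfig_mapping` is typically built with the `QConfigMapping` API; a small sketch (the module name `"head"` is a hypothetical example):
```
import torch
from torch.ao.quantization import QConfigMapping, get_default_qconfig

qconfig_mapping = (
    QConfigMapping()
    .set_global(get_default_qconfig("fbgemm"))   # default for all ops
    .set_object_type(torch.nn.LayerNorm, None)   # keep all LayerNorms in fp32
    .set_module_name("head", None)               # keep the module named "head" in fp32
)
```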
172
173Detailed walkthrough for this step in `prepare_qat_fx` (inserting QDQStub and FakeQuantize modules):
174Note: We could also insert QStub and DQStub in this step when users request to change the interface dtype for the model, standalone module or custom modules.
175```
176# fused and qat swapped model
177# graph 1:
178input - qat_linear_relu - output
179 |
180 FakeQuantize
181(need to be updated with QDQStub + FakeQuantize)
182 |
183 weight
184
185# qconfig_mapping (simplified, shown as dict)
186{'qat_linear_relu': QConfig(
187 weight=MinMaxObserver.with_args(dtype=torch.qint8),
188 activation=HistogramObserver.with_args(dtype=torch.quint8),
189)}
190
191# backend_config (simplified)
192{
  'pattern': nniqat.LinearReLU,
194 'dtype_configs': [{input: torch.quint8, output: torch.quint8, weight: torch.qint8}],
195}
196```
197
step 1: assign a qconfig to each op (please see [TODO: link] for details)
199
200step 2: determine which qconfigs are valid according to the backend configuration (please see [TODO: link] for details)
201(we should add a warning here)
202
step 3: for subgraphs with validated qconfigs, insert the qstub/dqstub/qdqstub as needed
204
To explain what happens in this step, let's first define some terms. We view the computation graph shown above as a Graph consisting of nodes and edges, where each node is an FX Node that represents some computation (for example linear), and each edge is a connection between two nodes; each edge can be viewed both as the output of the previous Node and as the input of the next Node.
206
The end goal for this step is to insert QDQStubs at edges so that we produce a graph of the quantized reference model, where each QDQStub represents a quantize operator followed by a dequantize operator.
208
209```
210# graph 2:
211input - QDQStub1 (FakeQuantize) - qat_linear_relu - QDQStub2 (FakeQuantize) - output
212 |
213 FakeQuantize
214 (need to be updated with QDQStub + FakeQuantize)
215 |
216 weight
217```
218Note: weight + FakeQuantize is a part of qat_linear_relu
219
The overall logic to insert QDQStub1 and QDQStub2 in place is the following:
0. For each node in the original graph, we compute the target_dtype for its input and output based on the qconfig. For graph 1, configured with the qconfig_mapping above, we have:
222```
223# node_name_to_target_dtype_info =
224# {
225# # this is placeholder node in FX Graph
226# "input" : {"input_activation": torch.float32, "output_activation": torch.float32},
227# "qat_linear_relu": {"input_activation": torch.quint8, "output_activation": torch.quint8, "weight": ...}
228# # this is the return node in FX Graph
229# "output": {"input_activation": torch.float32, "output_activation": torch.float32}
230# }
231```
Note: this map is generated before we insert any qdqstub into graph 1, and it does not change in the process.
233
1. Inserting QDQStub1 (for the input of qat_linear_relu)
   We look at the edge between the `input` Node and the `qat_linear_relu` Node and decide whether we need to insert a
   QDQStub at this edge, which would serve as an input argument for the `qat_linear_relu` Node (and also as an output for the `input` Node).
   To decide whether to insert a QDQStub here, we figure out:

   (1). The target dtype for the output of the `input` Node, which is torch.float32

   (2). The target dtype for the input of the `qat_linear_relu` Node, which is torch.quint8

   There is a mismatch here and (2) is a quantized dtype, so we need to insert a QDQStub at this edge.
   We also need to attach an observer/fakequant module to the QDQStub we insert here.
2. Inserting QDQStub2 (for the output of qat_linear_relu)
   The logic for inserting a QDQStub for the output is much simpler, since we assume all modules/functions in the graph produce fp32 output
   by default (we can add additional checks and extend this to work for other dtypes once type inference is ready).
   We just need to look at the target output dtype of the qat_linear_relu Node: if it is a quantized dtype (quint8, qint8, float16),
   we insert a QDQStub here.
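
A simplified sketch of the decision rule from steps 1 and 2 above (the helper names and the dtype set are illustrative, not the actual implementation):
```
import torch

QUANTIZED_DTYPES = {torch.quint8, torch.qint8, torch.float16}

def needs_input_qdqstub(producer_output_dtype, consumer_input_dtype):
    # step 1: insert on the edge when the dtypes disagree and the consumer
    # expects a quantized dtype
    return (producer_output_dtype != consumer_input_dtype
            and consumer_input_dtype in QUANTIZED_DTYPES)

def needs_output_qdqstub(target_output_dtype):
    # step 2: modules/functions are assumed to produce fp32 by default,
    # so insert whenever the target output dtype is quantized
    return target_output_dtype in QUANTIZED_DTYPES

# edge `input` -> `qat_linear_relu`: fp32 meets quint8, so QDQStub1 is inserted
assert needs_input_qdqstub(torch.float32, torch.quint8)
# output of `qat_linear_relu` is targeted at quint8, so QDQStub2 is inserted
assert needs_output_qdqstub(torch.quint8)
```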
250
Question: How do we avoid inserting duplicate QDQStubs?
E.g. when a single input is used by multiple ops:
253```
254input — linear1 —-
255 \--- linear2 —
256```
how do we make sure we only insert one QDQStub for the input of both linear1 and linear2?
258```
259input - QDQStub — linear1 -
260 \ —- linear2 -
261```
262
The way we do it right now is: before we insert a QDQStub, we look at all users of the `input` Node and make sure there is no existing QDQStub
with the same target_dtype. That is, if we already inserted a QDQStub with dtype quint8 for linear1, and linear2 is also connected to the same input, then when we request another QDQStub with dtype quint8 while processing the linear2 Node, we detect that the desired QDQStub already exists and do nothing.
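
The de-duplication check can be sketched as follows (purely illustrative; the `Node`/`QDQStubNode` classes below are stand-ins for the real FX node bookkeeping):
```
from dataclasses import dataclass, field

import torch

@dataclass
class Node:                      # stand-in for an FX Node
    name: str
    users: list = field(default_factory=list)

@dataclass
class QDQStubNode(Node):         # stand-in for an inserted QDQStub
    target_dtype: object = None

def get_or_insert_input_qdqstub(input_node, target_dtype):
    # Reuse an existing QDQStub with the same target dtype if one of the
    # users of `input_node` is already such a stub; otherwise insert one.
    for user in input_node.users:
        if isinstance(user, QDQStubNode) and user.target_dtype == target_dtype:
            return user
    stub = QDQStubNode(name=f"qdqstub_of_{input_node.name}", target_dtype=target_dtype)
    input_node.users.append(stub)
    return stub

inp = Node("input")
stub_for_linear1 = get_or_insert_input_qdqstub(inp, torch.quint8)
stub_for_linear2 = get_or_insert_input_qdqstub(inp, torch.quint8)
assert stub_for_linear1 is stub_for_linear2  # shared, not duplicated
```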
265
Question: What is the logic for keeping the output in float32?
267Let’s say the output of `qat_linear_relu` Node is configured as float32, both in qconfig_mapping and backend_config:
268```
269# qconfig_mapping (simplified, shown as dict)
270{'qat_linear_relu': QConfig(
271 weight=MinMaxObserver.with_args(dtype=torch.qint8),
272 input_activation=HistogramObserver.with_args(dtype=torch.quint8),
273 output_activation=PlaceholderObserver.with_args(dtype=torch.float32),
274)}
275
276# backend_config (simplified)
277{
  'pattern': nniqat.LinearReLU,
279 'dtype_configs': [{input: torch.quint8, output: torch.float32, weight: torch.qint8}],
280}
281```
282
What we do here is: when we are trying to insert the output QDQStub for `qat_linear_relu`, we look at the target output dtype for this node (node_name_to_target_dtype_info["qat_linear_relu"]["output_activation"]), find that it is float, which is not a quantized dtype, and therefore
do nothing here.
Note that this does not prevent operators following `qat_linear_relu` from inserting a QDQStub at the output of `qat_linear_relu`, since we are dealing with an `edge` of the graph here, and an `edge` is connected to two nodes, which means
the output of `qat_linear_relu` will also be the input of a node following `qat_linear_relu`.
287
288`backend_config` configurations used in this step:
289```
BackendPatternConfig(nniqat.LinearReLU)
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
    .set_dtype_configs([
        DTypeConfig(input_dtype=torch.quint8, output_dtype=torch.quint8, weight_dtype=torch.qint8, bias_dtype=torch.float32)
    ])
295```
296
The pattern in this case is the same as before; it defines the pattern for the subgraph we are dealing with.
298
`set_observation_type`: sets the observation type for the pattern; currently there are only two types:
300
301`OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` means the output observer instance will be different from the input, which is the most common type of observer placement.
302
303`OUTPUT_SHARE_OBSERVER_WITH_INPUT` means the output observer is shared with input, they will be the same instance. This is useful for operators like cat.
304
`set_dtype_configs`: sets a list of supported (activation, weight, bias, etc.) dtype combinations for qconfigs for the pattern. Note that we represent different modes of quantization (static/dynamic/`weight_only`) purely through these combinations; for example, fbgemm static quantization can be represented as:
306```
307{
308 "input_activation": torch.quint8,
309 "weight": torch.qint8,
310 "output_activation": torch.quint8
311}
312```
313
314Note: the dtype config will be used to configure the support for dynamic quantization as well
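
For example, static vs. dynamic int8 quantization for the same pattern could be expressed purely through dtype configs with the public `DTypeConfig` API (a sketch; the exact set a backend supports is up to that backend):
```
import torch
from torch.ao.quantization.backend_config import DTypeConfig

# static: quantized activation in and out, quantized weight
static_int8 = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

# dynamic: activation quantized on the fly, fp32 output, quantized weight
dynamic_int8 = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)
```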
315
316Note: we may extend this to support more fine grained configurations of args, kwargs, attributes and outputs in the future
317
Note: we are referring to observers here, which is an implementation detail; we could change this to talk about quantization parameters instead, e.g. `QParamsType.OUTPUT_USE_DIFFERENT_QPARAMS_AS_INPUT` and `QParamsType.OUTPUT_USE_SAME_QPARAMS_AS_INPUT`
319
320### 2. Calibration/Training
After we insert observers, we run the model to calibrate the observers or to fine-tune the model. This step is identical to eager mode quantization. Afterwards, the observer/fakequantize modules contain sufficient information to determine quantization parameters from the observed data.
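
For instance, calibration is just running representative data through the prepared model under `torch.no_grad()`, after which quantization parameters can be derived from each observer (a sketch; the toy model and calibration data are placeholders, and the observer attribute name follows the printout in section 1.3):
```
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx

model = torch.nn.Sequential(torch.nn.Linear(5, 10), torch.nn.ReLU()).eval()
prepared = prepare_fx(model, get_default_qconfig_mapping("fbgemm"), (torch.randn(1, 5),))

# calibration: run representative data through the prepared model
with torch.no_grad():
    for _ in range(10):
        prepared(torch.randn(2, 5))  # placeholder calibration batches

# each observer now holds statistics from which scale/zero_point can be computed
scale, zero_point = prepared.activation_post_process_0.calculate_qparams()
```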
322
323### 3.1 Conversion to Reference Quantized Model
324```
325quantized: GraphModule(
326 (linear): LinearReLU(
327 (0): QuantizedLinear(Reference)(in_features=5, out_features=10, bias=True)
328 (1): ReLU()
329 )
330)
331
332def forward(self, x):
333 linear_input_scale_0 = self.linear_input_scale_0
334 linear_input_zero_point_0 = self.linear_input_zero_point_0
335 quantize_per_tensor = torch.quantize_per_tensor(x, linear_input_scale_0, linear_input_zero_point_0, torch.quint8); x = linear_input_scale_0 = linear_input_zero_point_0 = None
336 dequantize = quantize_per_tensor.dequantize(); quantize_per_tensor = None
337 linear = self.linear(dequantize); dequantize = None
338 linear_scale_0 = self.linear_scale_0
339 linear_zero_point_0 = self.linear_zero_point_0
340 quantize_per_tensor_1 = torch.quantize_per_tensor(linear, linear_scale_0, linear_zero_point_0, torch.quint8); linear = linear_scale_0 = linear_zero_point_0 = None
341 dequantize_1 = quantize_per_tensor_1.dequantize(); quantize_per_tensor_1 = None
342 return dequantize_1
343```
344
After we insert observers, we need to convert the model to a reference quantized model. A reference quantized model is a model that uses reference patterns to represent quantized operators; this serves as the standard interface for quantized operators between PyTorch quantization and backend lowering passes. For more details, please take a look at this [RFC](https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md). This pass is pretty straightforward; what we do is:
346
(1). for each QDQStub (with attached Observer or FakeQuantize modules) in the graph, we convert it to calls to the quantize and dequantize functions based on the attributes of the attached Observer and FakeQuantize modules (e.g. qscheme, dtype, etc.)
348
(2). for weighted modules like linear/conv, we convert them to the corresponding reference quantized modules.
350
351Example:
352```
353# graph 1
354input - QDQStub1 (FakeQuantize) - qat_linear_relu - QDQStub2 (FakeQuantize) - output
355 |
356 FakeQuantize
357 (need to be updated with QDQStub + FakeQuantize)
358 |
359 Weight
360
361Note: weight + FakeQuantize is a part of qat_linear_relu module
362
363# graph 2
364input - quantize - dequantize - reference_linear_relu - quantize - dequantize - output
365 |
366 dequantize
367 |
368 quantize
369 |
370 weight
371```
372Note: weight + quantize + dequantize is a part of reference_linear_relu module
373
374To decide which quantize node we want to use, we’ll look at:
375
376(1). dtype of attached Observer/FakeQuantize module
377
378(2). qscheme of attached Observer/FakeQuantize module
379
380(3). (optionally) other attributes of attached Observer/FakeQuantize module
381
The quantize operators we can choose from right now are: quantize_per_tensor, quantize_per_channel, to, and quantize_per_tensor_dynamic.
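
As a sketch of how observer attributes map to a quantize call (not the converter's actual code): a per-tensor affine activation observer with dtype quint8 maps to `torch.quantize_per_tensor`, while a per-channel symmetric weight observer maps to `torch.quantize_per_channel`:
```
import torch
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver

x = torch.randn(4, 8)
act_obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
act_obs(x)
scale, zero_point = act_obs.calculate_qparams()
xq = torch.quantize_per_tensor(x, float(scale), int(zero_point), torch.quint8)

w = torch.randn(16, 8)
w_obs = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8,
                                 qscheme=torch.per_channel_symmetric)
w_obs(w)
w_scales, w_zero_points = w_obs.calculate_qparams()
wq = torch.quantize_per_channel(w, w_scales, w_zero_points, axis=0, dtype=torch.qint8)
```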
383
`backend_config` configurations used in this step:
```
BackendPatternConfig(nniqat.LinearReLU)
    .set_root_module(nn.Linear)
    .set_reference_quantized_module_for_root(nnqr.Linear)
    .set_fused_module(nni.LinearReLU)
```
391
The pattern in this case is the same as before; it defines the pattern for the subgraph we are dealing with.
393
`set_root_module`: sets the module class for the root of the pattern, e.g. nn.Linear for nni.LinearReLU/nniqat.LinearReLU; used to identify the module that needs to be swapped to a reference quantized module.
395
`set_reference_quantized_module_for_root`: sets the corresponding reference quantized module class for the root module class, e.g. when root_module is nn.Linear, this will be nn.quantized.reference.Linear; used to swap the root module to a reference quantized module.
397
Note: we only swap the `root_module` here. For example, in the current example the original module is `nniqat.LinearReLU`; when we convert weighted modules (step (2)), we first convert `nniqat.LinearReLU` to a float module, in this case the fused LinearReLU module `nni.LinearReLU`, and then swap the root_module (`nn.Linear`) with the reference quantized module (`nnqr.Linear`), so we end up with a `nni.LinearReLU` module, which is a sequential module of a `nnqr.Linear` and `nn.ReLU`.
399
400Basically, the corresponding reference quantized module for both `nniqat.LinearReLU` and `nni.LinearReLU` would be a `nni.LinearReLU` Sequential module (originally `nn.Linear` + `nn.ReLU`) with `nn.Linear` being replaced by `nnqr.Linear`: `nni.LinearReLU(nnqr.Linear, nn.ReLU)`.
401
`set_fused_module`: the corresponding fused module class for the pattern, used to identify fused modules that need to be converted to reference quantized modules.
403
404### 3.2 Lower to PyTorch Native Backend
405```
406GraphModule(
407 (linear): QuantizedLinearReLU(in_features=5, out_features=10, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
408)
409
410def forward(self, x):
411 linear_input_scale_0 = self.linear_input_scale_0
412 linear_input_zero_point_0 = self.linear_input_zero_point_0
413 quantize_per_tensor = torch.quantize_per_tensor(x, linear_input_scale_0, linear_input_zero_point_0, torch.quint8); x = linear_input_scale_0 = linear_input_zero_point_0 = None
414 linear = self.linear(quantize_per_tensor); quantize_per_tensor = None
415 dequantize_1 = linear.dequantize(); linear = None
416 return dequantize_1
417```
418
Currently, PyTorch has two native quantized backends: fbgemm and qnnpack, so we need a lowering pass to lower the reference quantized model to a model that uses native quantized operators in PyTorch. What this pass does is:
420
1. Recognize reference patterns like "dequantize - `float_op` - quantize" in the graph and replace them with quantized modules (under the torch.nn.quantized namespace) or operators (under the torch.ops.quantized namespace, or the torch namespace).
In general there are three types of patterns:
423
424* Static quantization:
425```
426dequantize -> float_op -> quantize_per_tensor
427```
428
429* Dynamic quantization:
430```
431quantize_per_tensor_dynamic -> dequantize -> float_op
432```
433
434* Weight only quantization:
435```
436 input - float_op - output
437 weight - quantize_per_tensor - dequantize /
438```
439
2. Prepack and fold the weights for quantized linear and quantized conv operators (see the sketch after this list)
3. The lowering pass also keeps some patterns for quantized operators unfused, since the user may explicitly request some operators to stay in float by configuring their qconfig to be None
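
To make the prepacking step concrete, here is a hedged sketch of the kind of call the lowered graph ends up making for the linear-relu pattern, using the native `torch.ops.quantized` operators (the scale/zero_point values are placeholders):
```
import torch

torch.backends.quantized.engine = "fbgemm"  # or "qnnpack", depending on the platform

xq = torch.quantize_per_tensor(torch.randn(2, 5), 0.1, 0, torch.quint8)
wq = torch.quantize_per_tensor(torch.randn(10, 5), 0.05, 0, torch.qint8)

# prepack the quantized weight (and optional fp32 bias) once, at convert time
packed = torch.ops.quantized.linear_prepack(wq, None)

# the lowered module then calls the fused quantized op with the output qparams
out = torch.ops.quantized.linear_relu(xq, packed, 0.2, 0)  # placeholder output scale/zp
```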
442
There are no configurations related to lowering in `backend_config`, since it is the backend developer's responsibility to implement the lowering pass and each backend developer may have their own configurations for it. So, end to end, `backend_config` together with qconfig_mapping controls what Reference Quantized Model is produced by FX Graph Mode Quantization, not the lowered model.
444
However, for some operator-based backends, like the current PyTorch native backends fbgemm and qnnpack, we could interpret `backend_config` in terms of configurations for operators as well, e.g. configuring `input_dtype=quint8`, `weight_dtype=qint8`, `output_dtype=torch.quint8` for nn.Linear says that the quantized linear takes a `quint8` activation and a `qint8` weight as input and outputs a `quint8` activation. But there is no guarantee that this interpretation will always work in the future, especially when we add new flavors of quantized operators.
446
447## Extensibility
448
449FX graph mode quantization can be extended to work with different backends, which may have different sets of supported quantized operator patterns and different requirements for each pattern. For more detail, please refer to the [BackendConfig README](/torch/ao/quantization/backend_config/README.md).
450