## BackendConfig Overview

BackendConfig allows PyTorch quantization to work with different backend or kernel libraries. These backends may have different sets of supported quantized operator patterns, and the same operator patterns may require different handling across different backends. To make quantization work with different backends and allow maximum flexibility, we have made all the parts of the quantization flow configurable with BackendConfig. Currently, it is only used by FX graph mode quantization. For more details on how it integrates with the FX graph mode quantization flow, refer to this [README](/torch/ao/quantization/fx/README.md).

BackendConfig configures quantization behavior in terms of operator patterns. For each operator pattern, we need to specify the supported data types for the input and output activations, weights, and biases, as well as the QAT modules, the reference quantized modules, etc., which will be used in module swapping during the quantization passes.

Quantized backends can differ in the following aspects:
* Quantization scheme (symmetric vs. asymmetric, per-channel vs. per-tensor)
* Data type (float32, float16, int8, uint8, bfloat16, etc.) for input/output/weight/bias
* Quantized (and fused) mapping: Some quantized operators may have different numerics compared to a naive (dequant - float_op - quant) reference implementation. For weighted operators, such as conv and linear, we need to be able to specify custom reference modules and a mapping from the float modules
* QAT mapping: For weighted operators, we need to swap them with the Quantization Aware Training (QAT) versions that add fake quantization to the weights

As an example, here is what fbgemm looks like:

| Aspect                                    | fbgemm                                                                |
|-------------------------------------------|-----------------------------------------------------------------------|
| Quantization Scheme                       | activation: per tensor, weight: per tensor or per channel             |
| Data Type                                 | activation: quint8 (with qmin/qmax range restrictions), weight: qint8 |
| Quantized and Fused Operators and Mapping | e.g. torch.nn.Conv2d -> torch.ao.nn.quantized.reference.Conv2d        |
| QAT Module Mapping                        | e.g. torch.nn.Conv2d -> torch.ao.nn.qat.Conv2d                        |

Instead of hardcoding the fusion mappings, float to reference quantized module mappings, fusion patterns, etc., we derive everything from the BackendConfig throughout the code base. This allows PyTorch quantization to work with all first-party (fbgemm and qnnpack) and third-party backends (TensorRT, executorch, etc.) that may differ from native backends in different aspects. With the recent addition of xnnpack, integrated as part of the qnnpack backend in PyTorch, the BackendConfig is needed to define the new constraints required for xnnpack quantized operators.

## Pattern Specification

The operator patterns used in BackendConfig are float modules, functional operators, PyTorch operators, or a tuple combination of the above. For example:
* torch.nn.Linear
* torch.nn.functional.linear
* torch.add
* operator.add
* (torch.nn.functional.linear, torch.nn.functional.relu)
* (torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU)

Tuple patterns are treated as sequential patterns, and currently only tuples of 2 or 3 elements are supported.
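
The sequential interpretation of tuple patterns can be illustrated with a toy matcher. This is a simplified sketch, not the FX matcher: `Node` and `matches_sequential` are hypothetical names, ops are plain strings, and every node has exactly one input. It anchors at the last op of the pattern and walks back through each node's input.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Node:
    """Hypothetical stand-in for a graph node: an op name and the node it consumes."""

    op: str
    arg: Optional["Node"] = None


def matches_sequential(pattern: Tuple[str, ...], last: Node) -> bool:
    """Return True if `pattern` matches the chain of ops ending at `last`.

    The pattern lists ops in execution order, so its last element must be the
    op of `last`, the second-to-last element the op of `last.arg`, and so on.
    """
    if not 2 <= len(pattern) <= 3:
        raise ValueError("only tuple patterns of 2 or 3 elements are supported")
    node: Optional[Node] = last
    for expected in reversed(pattern):
        if node is None or node.op != expected:
            return False
        node = node.arg
    return True


# Conv2d -> BatchNorm2d -> ReLU, i.e. relu(bn(conv(x)))
conv = Node("Conv2d")
bn = Node("BatchNorm2d", conv)
relu = Node("ReLU", bn)

print(matches_sequential(("Conv2d", "BatchNorm2d", "ReLU"), relu))  # True
print(matches_sequential(("Conv2d", "ReLU"), relu))  # False
```

Because matching follows the data flow backward, `("Conv2d", "ReLU")` does not match this chain: the input of the ReLU is the BatchNorm2d, not the Conv2d.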

### Advanced Pattern Specification

The above format should satisfy the vast majority of use cases. However, it does not handle more complex scenarios such as graph patterns. For these use cases, the BackendConfig API offers an alternative "reverse nested tuple" pattern format, enabled through `BackendPatternConfig()._set_pattern_complex_format(...)`. Note that this format is deprecated and will be replaced in a future version of PyTorch.
```
operator = module_type | functional | torch op | native op | MatchAllNode
Pattern = (operator, Pattern, Pattern, ...) | operator
```
where the first item of each Pattern is the operator, and the rest are the patterns for the arguments of the operator.
For example, the pattern `(nn.ReLU, (operator.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))` would match the following graph:
```
tensor_1            tensor_2
 |                    |
 *(MatchAllNode)  nn.Conv2d
 |                    |
 |             nn.BatchNorm2d
 \                  /
  -- operator.add --
         |
      nn.ReLU
```

During prepare and convert, we match the last node, which will be the anchor point of the match, and we can retrieve the whole graph by tracing back from that node. For example, in the graph above, we match the `nn.ReLU` node, and `node.args[0]` is the `operator.add` node.

## BackendConfig Implementation

The BackendConfig consists of a list of BackendPatternConfigs, each of which defines the specifications and requirements for an operator pattern.
Here is an example usage:

```
import torch
from torch.ao.quantization.backend_config import (
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    ObservationType,
)

weighted_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float)

def fuse_conv2d_relu(is_qat, conv, relu):
    """Return a fused ConvReLU2d from individual conv and relu modules."""
    return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu)

# For quantizing Linear
linear_config = BackendPatternConfig(torch.nn.Linear) \
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
    .add_dtype_config(weighted_int8_dtype_config) \
    .set_root_module(torch.nn.Linear) \
    .set_qat_module(torch.ao.nn.qat.Linear) \
    .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)

# For fusing Conv2d + ReLU into ConvReLU2d
conv_relu_config = BackendPatternConfig((torch.nn.Conv2d, torch.nn.ReLU)) \
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
    .add_dtype_config(weighted_int8_dtype_config) \
    .set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \
    .set_fuser_method(fuse_conv2d_relu)

# For quantizing ConvReLU2d
fused_conv_relu_config = BackendPatternConfig(torch.ao.nn.intrinsic.ConvReLU2d) \
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
    .add_dtype_config(weighted_int8_dtype_config) \
    .set_root_module(torch.nn.Conv2d) \
    .set_qat_module(torch.ao.nn.intrinsic.qat.ConvReLU2d) \
    .set_reference_quantized_module(torch.ao.nn.quantized.reference.Conv2d)

backend_config = BackendConfig("my_backend") \
    .set_backend_pattern_config(linear_config) \
    .set_backend_pattern_config(conv_relu_config) \
    .set_backend_pattern_config(fused_conv_relu_config)
```

### Observer Insertion

Relevant APIs:
* `set_observation_type`

During the prepare phase, we insert observers (or QuantDeQuantStubs in the future) into the graph for this operator pattern based on the observation type, which specifies whether to use different observers for the inputs and the outputs of the pattern. For more detail, see `torch.ao.quantization.backend_config.ObservationType`.

### Reference Quantized Patterns

Relevant APIs:
* `set_root_module`
* `set_reference_quantized_module`

During the convert phase, when we construct the reference quantized model, the root modules (e.g. `torch.nn.Linear` for `nni.LinearReLU` or `nniqat.LinearReLU`) will be swapped with the corresponding reference quantized modules (e.g. `torch.ao.nn.quantized.reference.Linear`). This allows custom backends to specify custom reference quantized module implementations to match the numerics of their lowered operators. Since this is a one-to-one mapping, both the root module and the reference quantized module must be specified in the same BackendPatternConfig in order for the conversion to take place.

### Fusion

Relevant APIs:
* `set_fuser_method`
* `set_fused_module`
* `_set_root_node_getter`
* `_set_extra_inputs_getter`

As an optimization, operator patterns such as (`torch.nn.Linear`, `torch.nn.ReLU`) may be fused into `nni.LinearReLU`. This is performed during the prepare phase according to the function specified in `set_fuser_method`, which replaces the pattern with the fused module. During the convert phase, these fused modules (identified by `set_fused_module`) will then be converted to the reference quantized versions of the modules.
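
To make the prepare-time fusion step more concrete, here is a toy sketch that fuses matching runs in a flat list of modules instead of rewriting an FX graph. `Conv2d`, `ReLU`, `ConvReLU2d` and `fuse_sequence` are hypothetical stand-ins defined locally so the sketch runs without PyTorch; only the fuser method's calling convention `(is_qat, *modules)` is taken from the `fuse_conv2d_relu` example in this README.

```python
from typing import Callable, List, Sequence, Tuple, Type


# Hypothetical stand-ins for torch.nn.Conv2d, torch.nn.ReLU and
# torch.ao.nn.intrinsic.ConvReLU2d, so the sketch runs without PyTorch.
class Conv2d:
    pass


class ReLU:
    pass


class ConvReLU2d:
    def __init__(self, conv: Conv2d, relu: ReLU) -> None:
        self.conv = conv
        self.relu = relu


def fuse_conv2d_relu(is_qat: bool, conv: Conv2d, relu: ReLU) -> ConvReLU2d:
    """Fuser method: receives is_qat followed by the matched modules in order."""
    return ConvReLU2d(conv, relu)


def fuse_sequence(
    modules: Sequence[object],
    pattern: Tuple[Type, ...],
    fuser_method: Callable[..., object],
    is_qat: bool = False,
) -> List[object]:
    """Replace each run of modules whose types match `pattern` with the fused module."""
    fused: List[object] = []
    i = 0
    n = len(pattern)
    while i < len(modules):
        window = modules[i : i + n]
        if len(window) == n and all(isinstance(m, t) for m, t in zip(window, pattern)):
            fused.append(fuser_method(is_qat, *window))
            i += n  # skip past the modules that were just fused
        else:
            fused.append(modules[i])
            i += 1
    return fused


model = [Conv2d(), ReLU(), Conv2d()]
print([type(m).__name__ for m in fuse_sequence(model, (Conv2d, ReLU), fuse_conv2d_relu)])
# ['ConvReLU2d', 'Conv2d']
```

The trailing `Conv2d` is left alone because it is not followed by a `ReLU`; in the real flow, such an unfused module would instead be matched by a standalone `Conv2d` pattern config, if the backend provides one.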

In FX graph mode quantization, we replace the corresponding nodes in the graph using two helper functions set by the user: `root_node_getter`, which returns the root node (typically the weighted module in the pattern like `torch.nn.Linear`) to replace the matched pattern in the graph, and `extra_inputs_getter`, which returns a list of extra input arguments that will be appended to the existing arguments of the fused module (copied over from the root node). See [this snippet](https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6) for an example usage.

### Data Type Restrictions

Relevant APIs:
* `add_dtype_config`
* `set_dtype_configs`

DTypeConfig specifies a set of supported data types for input/output/weight/bias along with the associated constraints, if any. There are two ways of specifying `input_dtype`, `output_dtype`, and `weight_dtype`: either as simple `torch.dtype`s or as `DTypeWithConstraints`, e.g.:

```
import torch
from torch.ao.quantization.backend_config import DTypeConfig, DTypeWithConstraints

dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float)

dtype_config_with_constraints = DTypeConfig(
    input_dtype=DTypeWithConstraints(
        dtype=torch.quint8,
        quant_min_lower_bound=0,
        quant_max_upper_bound=255,
        scale_min_lower_bound=2 ** -12,
    ),
    output_dtype=DTypeWithConstraints(
        dtype=torch.quint8,
        quant_min_lower_bound=0,
        quant_max_upper_bound=255,
        scale_min_lower_bound=2 ** -12,
    ),
    weight_dtype=DTypeWithConstraints(
        dtype=torch.qint8,
        quant_min_lower_bound=-128,
        quant_max_upper_bound=127,
        scale_min_lower_bound=2 ** -12,
    ),
    bias_dtype=torch.float)
```

During the prepare phase of quantization, we compare the data types specified in these DTypeConfigs to the ones specified in the matching QConfig for a given operator pattern. If the QConfig does not match any of the DTypeConfigs specified for the operator pattern (because the data types differ or the constraints are not satisfied), we simply ignore the QConfig and skip quantizing this pattern.

#### Quantization range

The user's QConfig may specify `quant_min` and `quant_max`, which are min and max restrictions on the quantized values. Here the BackendConfig sets a lower bound for `quant_min` (`quant_min_lower_bound`) and an upper bound for `quant_max` (`quant_max_upper_bound`) to represent the limits of the backend. If a QConfig exceeds these limits in either direction, it will be treated as violating this constraint.

#### Scale range

Similarly, the user's QConfig may specify a minimum value for the quantization scale (currently exposed as `eps` but will change in the future to better reflect the semantics). Here the BackendConfig sets a lower bound for this minimum scale (`scale_min_lower_bound`) to represent the limits of the backend. If a QConfig's min scale value falls below this limit, the QConfig will be treated as violating this constraint. Note that `scale_max_upper_bound` is currently not used, because there is no corresponding mechanism to enforce this on the observer yet.

#### Fixed quantization parameters

For ops with fixed quantization parameters such as `torch.nn.Sigmoid` or `torch.nn.Tanh`, the BackendConfig can specify the specific scale and zero point values as constraints on the input and output activations. The user's QConfigs for these ops must use `FixedQParamsObserver` or `FixedQParamsFakeQuantize` for their activations with matching scale and zero point values; otherwise, these QConfigs will be ignored.
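
The constraint rules above can be summarized in a small pure-Python sketch. `ConstraintsSketch` and `ActivationRequest` are hypothetical stand-ins: the field names of the former mirror `DTypeWithConstraints` (the `scale_exact_match` and `zero_point_exact_match` names for fixed qparams are an assumption here), while the latter is an invented summary of a user's observer settings, not a real QConfig. Dtypes are plain strings so the sketch runs without PyTorch, and `scale_max_upper_bound` is omitted since it is not enforced.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class ConstraintsSketch:
    """Hypothetical stand-in for DTypeWithConstraints. None means "no constraint"."""

    dtype: str
    quant_min_lower_bound: Optional[int] = None
    quant_max_upper_bound: Optional[int] = None
    scale_min_lower_bound: Optional[float] = None
    scale_exact_match: Optional[float] = None  # assumed field name
    zero_point_exact_match: Optional[int] = None  # assumed field name


@dataclass
class ActivationRequest:
    """Hypothetical summary of what a user's activation observer asks for."""

    dtype: str
    quant_min: int
    quant_max: int
    eps: float  # the minimum scale the observer may produce
    fixed_scale: Optional[float] = None  # only set for fixed-qparams observers
    fixed_zero_point: Optional[int] = None


def satisfies(c: ConstraintsSketch, q: ActivationRequest) -> bool:
    """Return True if the request stays within every constraint that is set."""
    if q.dtype != c.dtype:
        return False
    # Quantization range: quant_min/quant_max may not go past the backend limits.
    if c.quant_min_lower_bound is not None and q.quant_min < c.quant_min_lower_bound:
        return False
    if c.quant_max_upper_bound is not None and q.quant_max > c.quant_max_upper_bound:
        return False
    # Scale range: the observer's minimum scale may not fall below the backend's.
    if c.scale_min_lower_bound is not None and q.eps < c.scale_min_lower_bound:
        return False
    # Fixed qparams: scale and zero point must match exactly.
    if c.scale_exact_match is not None and q.fixed_scale != c.scale_exact_match:
        return False
    if c.zero_point_exact_match is not None and q.fixed_zero_point != c.zero_point_exact_match:
        return False
    return True


def any_option_accepts(options: Sequence[ConstraintsSketch], q: ActivationRequest) -> bool:
    """The pattern is only skipped if none of the backend's options accept the request."""
    return any(satisfies(c, q) for c in options)


backend = ConstraintsSketch(
    "quint8", quant_min_lower_bound=0, quant_max_upper_bound=255, scale_min_lower_bound=2 ** -12
)
print(satisfies(backend, ActivationRequest("quint8", 0, 255, eps=2 ** -12)))  # True
print(satisfies(backend, ActivationRequest("quint8", 0, 255, eps=2 ** -14)))  # False: scale too small
print(satisfies(backend, ActivationRequest("quint8", 0, 511, eps=2 ** -12)))  # False: quant_max too large

fixed = ConstraintsSketch("quint8", scale_exact_match=1.0 / 256.0, zero_point_exact_match=0)
print(satisfies(fixed, ActivationRequest("quint8", 0, 255, eps=2 ** -12)))  # False: no fixed qparams
```

The fixed-qparams values here (scale 1/256, zero point 0) are illustrative of an output range of roughly [0, 1], as for a sigmoid; a regular observer that does not pin its scale and zero point is rejected by such a constraint.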