xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/backend_config/qnnpack.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2
3from ._common_operator_config_utils import (
4    _get_binary_op_configs,
5    _get_bn_configs,
6    _get_cat_config,
7    _get_conv_configs,
8    _get_default_op_configs,
9    _get_embedding_op_configs,
10    _get_fixed_qparams_op_configs,
11    _get_linear_configs,
12    _get_rnn_op_configs,
13    _get_share_qparams_op_configs,
14)
15from .backend_config import BackendConfig, DTypeConfig, DTypeWithConstraints
16
17
__all__ = [
    "get_qnnpack_backend_config",
]

# ===================
# |  DTYPE CONFIGS  |
# ===================

# Static quantization for weighted ops (conv/linear): quint8 activations,
# qint8 weights, float bias.
qnnpack_weighted_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

# Static quantization for ops without weights: quint8 in/out only.
qnnpack_default_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

# All-fp16 static config (activations, weights, and bias in half precision).
qnnpack_default_op_fp16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float16,
    weight_dtype=torch.float16,
    bias_dtype=torch.float16,
)

# Dynamic quantization: activations quantized to quint8 at runtime,
# float output, statically quantized qint8 weights.
qnnpack_default_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Dynamic fp16 variant: fp16 activations/weights, float output.
qnnpack_default_dynamic_float16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float,
    weight_dtype=torch.float16,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Weight-only quantization: computation stays in float, only the weight
# is stored quantized (quint8 here, quint4x2 below) — used for embeddings.
qnnpack_weight_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint8,
)

qnnpack_weight_only_quint4x2_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint4x2,
)

# xnnpack compatible dtype configs

# We restrict scale values to be 2 ** -12 to ensure the
# requantization scale never falls below the xnnpack lower
# threshold. Additionally, for qint8 weight, we restrict
# the quantization values to [-127, +127], excluding -128.
# For more detail, refer to the description of
# `default_symmetric_qnnpack_qconfig`.

# TODO: add additional restriction on qscheme to ensure it
# is either per_tensor_symmetric or per_channel_symmetric

# Activation dtype constraint: qint8 with a lower bound on the scale.
qnnpack_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints(
    dtype=torch.qint8,
    scale_min_lower_bound=2**-12,
)

# Weight dtype constraint: qint8 restricted to [-127, 127] (excludes -128)
# with the same scale lower bound.
qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeConfig(
    dtype=torch.qint8,
    quant_min_lower_bound=-127,
    quant_max_upper_bound=127,
    scale_min_lower_bound=2**-12,
) if False else DTypeWithConstraints(
    dtype=torch.qint8,
    quant_min_lower_bound=-127,
    quant_max_upper_bound=127,
    scale_min_lower_bound=2**-12,
)

# Symmetric qint8 config for weighted ops, with the xnnpack constraints above.
qnnpack_weighted_op_qint8_symmetric_dtype_config = DTypeConfig(
    input_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
    output_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
    weight_dtype=qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12,
    bias_dtype=torch.float,
)

# Symmetric qint8 config for ops without weights.
qnnpack_default_op_qint8_symmetric_dtype_config = DTypeConfig(
    input_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
    output_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
)
108
109
110# =====================
111# |  BACKEND CONFIGS  |
112# =====================
113
114
def get_qnnpack_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch's native QNNPACK backend.
    """
    # Weighted ops (conv, linear): symmetric qint8 plus asymmetric quint8.
    weighted_dtypes = [
        qnnpack_weighted_op_qint8_symmetric_dtype_config,
        qnnpack_weighted_op_quint8_dtype_config,
    ]
    # Dynamically quantized configs (linear, RNNs).
    dynamic_dtypes = [
        qnnpack_default_dynamic_int8_dtype_config,
        qnnpack_default_dynamic_float16_dtype_config,
    ]
    # Ops without weights: symmetric qint8 plus asymmetric quint8.
    # Shared by cat, the default ops, and batch norm; the binary,
    # fixed-qparams, and share-qparams ops each get their own copy.
    default_dtypes = [
        qnnpack_default_op_qint8_symmetric_dtype_config,
        qnnpack_default_op_quint8_dtype_config,
    ]
    binary_dtypes = list(default_dtypes)
    fixed_qparams_dtypes = list(default_dtypes)
    share_qparams_dtypes = list(default_dtypes)
    # Embeddings use weight-only quantization.
    weight_only_dtypes = [
        qnnpack_weight_only_quint8_dtype_config,
        qnnpack_weight_only_quint4x2_dtype_config,
    ]

    backend_config = BackendConfig("qnnpack")
    backend_config.set_backend_pattern_configs(_get_conv_configs(weighted_dtypes))
    backend_config.set_backend_pattern_configs(
        _get_linear_configs(weighted_dtypes + dynamic_dtypes)
    )
    backend_config.set_backend_pattern_configs(_get_binary_op_configs(binary_dtypes))
    # Note: cat registers a single pattern config, not a list.
    backend_config.set_backend_pattern_config(_get_cat_config(default_dtypes))
    backend_config.set_backend_pattern_configs(_get_default_op_configs(default_dtypes))
    backend_config.set_backend_pattern_configs(
        _get_fixed_qparams_op_configs(fixed_qparams_dtypes)
    )
    backend_config.set_backend_pattern_configs(
        _get_share_qparams_op_configs(share_qparams_dtypes)
    )
    backend_config.set_backend_pattern_configs(_get_bn_configs(default_dtypes))
    backend_config.set_backend_pattern_configs(_get_rnn_op_configs(dynamic_dtypes))
    backend_config.set_backend_pattern_configs(
        _get_embedding_op_configs(weight_only_dtypes)
    )
    return backend_config
172