# xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/backend_config/native.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import torch

from ._common_operator_config_utils import (
    _get_binary_op_configs,
    _get_bn_configs,
    _get_cat_config,
    _get_conv_configs,
    _get_default_op_configs,
    _get_embedding_op_configs,
    _get_fixed_qparams_op_configs,
    _get_linear_configs,
    _get_ln_configs,
    _get_rnn_op_configs,
    _get_share_qparams_op_configs,
    _get_tensor_info_op_configs,
)
from .backend_config import BackendConfig, DTypeConfig
19
20
# Public API of this module. NOTE(review): `weighted_op_quint8_dtype_config`
# is defined below but not listed here — presumably intentional; confirm
# before exporting it.
__all__ = [
    "get_test_only_legacy_native_backend_config",
    "default_op_quint8_dtype_config",
    "default_op_fp16_dtype_config",
    "default_dynamic_int8_dtype_config",
    "default_dynamic_float16_dtype_config",
    "input_output_only_quint8_dtype_config",
    "weight_only_quint8_dtype_config",
    "weight_only_quint4x2_dtype_config",
    "get_native_backend_config",
    "get_native_backend_config_dict",
    "get_test_only_legacy_native_backend_config_dict",
]
34
# ===================
# |  DTYPE CONFIGS  |
# ===================

# Weighted op int8 dtype config:
# this is the config for ops that have quantized weights, like linear and conv.
weighted_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

# Activation-only quint8 config (no weights), used for most default ops.
default_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

# Full fp16 config: activations, weights, and bias all in float16.
default_op_fp16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float16,
    weight_dtype=torch.float16,
    bias_dtype=torch.float16,
)

# Dynamic quantization with int8 weights: quantized input/weight, float output.
default_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    # currently the dtype check is not yet enabled, so we provided the dtype_configs but
    # it is not really used yet,
    # we will enable it a bit later after we moved everything to backend_config_dict
    is_dynamic=True,
)

# Dynamic quantization with fp16 weights: fp16 input/weight, float output.
default_dynamic_float16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float,
    weight_dtype=torch.float16,
    bias_dtype=torch.float,
    # currently the dtype check is not yet enabled, so we provided the dtype_configs but
    # it is not really used yet,
    # we will enable it a bit later after we moved everything to backend_config_dict
    is_dynamic=True,
)

# Needed for LayerNorm and f.layer_norm, since currently the kernel only supports float weights
input_output_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.float,
    bias_dtype=torch.float,
)

# Weight-only quantization (quint8 weights, float activations), e.g. for embeddings.
weight_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint8,
)

# Weight-only quantization with 4-bit packed weights (quint4x2).
weight_only_quint4x2_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint4x2,
)
101
102
103# =====================
104# |  BACKEND CONFIGS  |
105# =====================
106
107
def get_test_only_legacy_native_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional fp16 ops.

    Compared to `get_native_backend_config`, this additionally allows
    `default_op_fp16_dtype_config` for linear, binary, fixed-qparams, and
    share-qparams ops — hence the "_and_fp16" backend name.
    """
    # Per-op-family dtype configs; each list is passed to the corresponding
    # config builder from _common_operator_config_utils.
    conv_dtype_configs = [weighted_op_quint8_dtype_config]
    linear_dtype_configs = [
        weighted_op_quint8_dtype_config,
        default_dynamic_int8_dtype_config,
        default_dynamic_float16_dtype_config,
        default_op_fp16_dtype_config,
    ]
    binary_op_dtype_configs = [
        default_op_quint8_dtype_config,
        default_op_fp16_dtype_config,
    ]
    default_op_dtype_configs = [default_op_quint8_dtype_config]
    fixed_qparams_op_dtype_configs = [
        default_op_quint8_dtype_config,
        default_op_fp16_dtype_config,
    ]
    share_qparams_op_dtype_configs = [
        default_op_quint8_dtype_config,
        default_op_fp16_dtype_config,
    ]
    tensor_info_op_dtype_configs = [
        default_op_quint8_dtype_config,
    ]
    rnn_op_dtype_configs = [
        default_dynamic_int8_dtype_config,
        default_dynamic_float16_dtype_config,
    ]
    embedding_op_dtype_configs = [
        weight_only_quint8_dtype_config,
        weight_only_quint4x2_dtype_config,
    ]
    layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config]
    return (
        BackendConfig("_native_and_fp16")
        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs))
        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs))
        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs))
        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs))
        .set_backend_pattern_configs(
            _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)
        )
        .set_backend_pattern_configs(
            _get_share_qparams_op_configs(share_qparams_op_dtype_configs)
        )
        .set_backend_pattern_configs(
            _get_tensor_info_op_configs(tensor_info_op_dtype_configs)
        )
        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs))
        .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs))
        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs))
        .set_backend_pattern_configs(
            _get_embedding_op_configs(embedding_op_dtype_configs)
        )
    )
167
168
def get_native_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack).
    """
    # TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK BackendConfigs
    # Per-op-family dtype configs; each list is passed to the corresponding
    # config builder from _common_operator_config_utils.
    conv_dtype_configs = [weighted_op_quint8_dtype_config]
    linear_dtype_configs = [
        weighted_op_quint8_dtype_config,
        default_dynamic_int8_dtype_config,
        default_dynamic_float16_dtype_config,
    ]
    binary_op_dtype_configs = [default_op_quint8_dtype_config]
    default_op_dtype_configs = [default_op_quint8_dtype_config]
    fixed_qparams_op_dtype_configs = [default_op_quint8_dtype_config]
    share_qparams_op_dtype_configs = [default_op_quint8_dtype_config]
    tensor_info_op_dtype_configs = [default_op_quint8_dtype_config]
    rnn_op_dtype_configs = [
        default_dynamic_int8_dtype_config,
        default_dynamic_float16_dtype_config,
    ]
    embedding_op_dtype_configs = [
        weight_only_quint8_dtype_config,
        weight_only_quint4x2_dtype_config,
    ]
    layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config]
    return (
        BackendConfig("native")
        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs))
        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs))
        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs))
        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs))
        .set_backend_pattern_configs(
            _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)
        )
        .set_backend_pattern_configs(
            _get_share_qparams_op_configs(share_qparams_op_dtype_configs)
        )
        .set_backend_pattern_configs(
            _get_tensor_info_op_configs(tensor_info_op_dtype_configs)
        )
        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs))
        .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs))
        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs))
        .set_backend_pattern_configs(
            _get_embedding_op_configs(embedding_op_dtype_configs)
        )
    )
217
218
def get_native_backend_config_dict():
    """
    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) in dictionary form.
    """
    return get_native_backend_config().to_dict()
224
225
def get_test_only_legacy_native_backend_config_dict():
    """
    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional
    fp16 ops in dictionary form.
    """
    return get_test_only_legacy_native_backend_config().to_dict()
232