import torch

from ._common_operator_config_utils import (
    _get_binary_op_configs,
    _get_bn_configs,
    _get_cat_config,
    _get_conv_configs,
    _get_default_op_configs,
    _get_embedding_op_configs,
    _get_fixed_qparams_op_configs,
    _get_linear_configs,
    _get_rnn_op_configs,
    _get_share_qparams_op_configs,
)
from .backend_config import BackendConfig, DTypeConfig, DTypeWithConstraints


__all__ = [
    "get_qnnpack_backend_config",
]

# ===================
# |  DTYPE CONFIGS  |
# ===================

qnnpack_weighted_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

qnnpack_default_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

qnnpack_default_op_fp16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float16,
    weight_dtype=torch.float16,
    bias_dtype=torch.float16,
)

qnnpack_default_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)

qnnpack_default_dynamic_float16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float,
    weight_dtype=torch.float16,
    bias_dtype=torch.float,
    is_dynamic=True,
)

qnnpack_weight_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint8,
)

qnnpack_weight_only_quint4x2_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint4x2,
)

# xnnpack compatible dtype configs

# We restrict scale values to be at least 2 ** -12 to ensure the
# requantization scale never falls below the xnnpack lower
# threshold. Additionally, for qint8 weight, we restrict
# the quantization values to [-127, +127], excluding -128.
# For more detail, refer to the description of
# `default_symmetric_qnnpack_qconfig`.

# TODO: add additional restriction on qscheme to ensure it
# is either per_tensor_symmetric or per_channel_symmetric

qnnpack_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints(
    dtype=torch.qint8,
    scale_min_lower_bound=2**-12,
)

qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeWithConstraints(
    dtype=torch.qint8,
    quant_min_lower_bound=-127,
    quant_max_upper_bound=127,
    scale_min_lower_bound=2**-12,
)

qnnpack_weighted_op_qint8_symmetric_dtype_config = DTypeConfig(
    input_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
    output_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
    weight_dtype=qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12,
    bias_dtype=torch.float,
)

qnnpack_default_op_qint8_symmetric_dtype_config = DTypeConfig(
    input_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
    output_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
)


# =====================
# |  BACKEND CONFIGS  |
# =====================


def get_qnnpack_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch's native QNNPACK backend.
118 """ 119 conv_dtype_configs = [ 120 qnnpack_weighted_op_qint8_symmetric_dtype_config, 121 qnnpack_weighted_op_quint8_dtype_config, 122 ] 123 linear_dtype_configs = [ 124 qnnpack_weighted_op_qint8_symmetric_dtype_config, 125 qnnpack_weighted_op_quint8_dtype_config, 126 qnnpack_default_dynamic_int8_dtype_config, 127 qnnpack_default_dynamic_float16_dtype_config, 128 ] 129 binary_op_dtype_configs = [ 130 qnnpack_default_op_qint8_symmetric_dtype_config, 131 qnnpack_default_op_quint8_dtype_config, 132 ] 133 default_op_dtype_configs = [ 134 qnnpack_default_op_qint8_symmetric_dtype_config, 135 qnnpack_default_op_quint8_dtype_config, 136 ] 137 fixed_qparams_op_dtype_configs = [ 138 qnnpack_default_op_qint8_symmetric_dtype_config, 139 qnnpack_default_op_quint8_dtype_config, 140 ] 141 share_qparams_op_dtype_configs = [ 142 qnnpack_default_op_qint8_symmetric_dtype_config, 143 qnnpack_default_op_quint8_dtype_config, 144 ] 145 rnn_op_dtype_configs = [ 146 qnnpack_default_dynamic_int8_dtype_config, 147 qnnpack_default_dynamic_float16_dtype_config, 148 ] 149 embedding_op_dtype_configs = [ 150 qnnpack_weight_only_quint8_dtype_config, 151 qnnpack_weight_only_quint4x2_dtype_config, 152 ] 153 return ( 154 BackendConfig("qnnpack") 155 .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) 156 .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) 157 .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) 158 .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) 159 .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) 160 .set_backend_pattern_configs( 161 _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) 162 ) 163 .set_backend_pattern_configs( 164 _get_share_qparams_op_configs(share_qparams_op_dtype_configs) 165 ) 166 .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) 167 .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) 168 .set_backend_pattern_configs( 169 _get_embedding_op_configs(embedding_op_dtype_configs) 170 ) 171 ) 172