1# mypy: allow-untyped-defs 2import torch 3 4from ._common_operator_config_utils import ( 5 _get_binary_op_configs, 6 _get_bn_configs, 7 _get_cat_config, 8 _get_conv_configs, 9 _get_default_op_configs, 10 _get_embedding_op_configs, 11 _get_fixed_qparams_op_configs, 12 _get_linear_configs, 13 _get_ln_configs, 14 _get_rnn_op_configs, 15 _get_share_qparams_op_configs, 16 _get_tensor_info_op_configs, 17) 18from .backend_config import BackendConfig, DTypeConfig 19 20 21__all__ = [ 22 "get_test_only_legacy_native_backend_config", 23 "default_op_quint8_dtype_config", 24 "default_op_fp16_dtype_config", 25 "default_dynamic_int8_dtype_config", 26 "default_dynamic_float16_dtype_config", 27 "input_output_only_quint8_dtype_config", 28 "weight_only_quint8_dtype_config", 29 "weight_only_quint4x2_dtype_config", 30 "get_native_backend_config", 31 "get_native_backend_config_dict", 32 "get_test_only_legacy_native_backend_config_dict", 33] 34 35# =================== 36# | DTYPE CONFIGS | 37# =================== 38 39# weighted op int8 dtype config 40# this is config for ops that has quantized weights, like linear, conv 41weighted_op_quint8_dtype_config = DTypeConfig( 42 input_dtype=torch.quint8, 43 output_dtype=torch.quint8, 44 weight_dtype=torch.qint8, 45 bias_dtype=torch.float, 46) 47 48default_op_quint8_dtype_config = DTypeConfig( 49 input_dtype=torch.quint8, 50 output_dtype=torch.quint8, 51) 52 53default_op_fp16_dtype_config = DTypeConfig( 54 input_dtype=torch.float16, 55 output_dtype=torch.float16, 56 weight_dtype=torch.float16, 57 bias_dtype=torch.float16, 58) 59 60default_dynamic_int8_dtype_config = DTypeConfig( 61 input_dtype=torch.quint8, 62 output_dtype=torch.float, 63 weight_dtype=torch.qint8, 64 bias_dtype=torch.float, 65 # currently the dtype check is not yet enabled, so we provided the dtype_configs but 66 # it is not really used yet, 67 # we will enable it a bit later after we moved everything to backend_config_dict 68 is_dynamic=True, 69) 70 
# dynamic fp16 quantization: fp16 activations/weights, fp32 output
default_dynamic_float16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float,
    weight_dtype=torch.float16,
    bias_dtype=torch.float,
    # currently the dtype check is not yet enabled, so we provided the dtype_configs but
    # it is not really used yet,
    # we will enable it a bit later after we moved everything to backend_config_dict
    is_dynamic=True,
)

# Needed for LayerNorm and f.layer_norm, since currently the kernel only supports float weights
input_output_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.float,
    bias_dtype=torch.float,
)

# weight-only int8 quantization (fp32 activations), used for embedding ops
weight_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint8,
)

# weight-only 4-bit quantization (two values packed per byte), used for embedding ops
weight_only_quint4x2_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint4x2,
)


# =====================
# |  BACKEND CONFIGS  |
# =====================


def _build_native_backend_config(name: str, *, include_fp16_ops: bool) -> BackendConfig:
    """
    Shared builder for the PyTorch Native backend (fbgemm/qnnpack) `BackendConfig`.

    The "native" and the test-only "_native_and_fp16" configs are identical except
    that the latter additionally registers `default_op_fp16_dtype_config` for
    linear, binary, fixed-qparams, and share-qparams ops, so both public entry
    points delegate here instead of duplicating the whole builder chain.

    Args:
        name: backend name stored on the returned `BackendConfig`.
        include_fp16_ops: when True, append the fp16 dtype config to the op
            categories listed above (legacy/test-only behavior).
    """
    # Extra fp16 dtype config, appended only for the legacy/test-only variant.
    fp16_extras = [default_op_fp16_dtype_config] if include_fp16_ops else []
    conv_dtype_configs = [weighted_op_quint8_dtype_config]
    linear_dtype_configs = [
        weighted_op_quint8_dtype_config,
        default_dynamic_int8_dtype_config,
        default_dynamic_float16_dtype_config,
        *fp16_extras,
    ]
    binary_op_dtype_configs = [default_op_quint8_dtype_config, *fp16_extras]
    default_op_dtype_configs = [default_op_quint8_dtype_config]
    fixed_qparams_op_dtype_configs = [default_op_quint8_dtype_config, *fp16_extras]
    share_qparams_op_dtype_configs = [default_op_quint8_dtype_config, *fp16_extras]
    tensor_info_op_dtype_configs = [default_op_quint8_dtype_config]
    rnn_op_dtype_configs = [
        default_dynamic_int8_dtype_config,
        default_dynamic_float16_dtype_config,
    ]
    embedding_op_dtype_configs = [
        weight_only_quint8_dtype_config,
        weight_only_quint4x2_dtype_config,
    ]
    layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config]
    return (
        BackendConfig(name)
        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs))
        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs))
        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
        # cat contributes a single pattern config, hence the singular setter
        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs))
        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs))
        .set_backend_pattern_configs(
            _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)
        )
        .set_backend_pattern_configs(
            _get_share_qparams_op_configs(share_qparams_op_dtype_configs)
        )
        .set_backend_pattern_configs(
            _get_tensor_info_op_configs(tensor_info_op_dtype_configs)
        )
        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs))
        .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs))
        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs))
        .set_backend_pattern_configs(
            _get_embedding_op_configs(embedding_op_dtype_configs)
        )
    )


def get_test_only_legacy_native_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional fp16 ops.
    """
    return _build_native_backend_config("_native_and_fp16", include_fp16_ops=True)


def get_native_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack).
    """
    # TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK BackendConfigs
    return _build_native_backend_config("native", include_fp16_ops=False)


def get_native_backend_config_dict():
    """
    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) in dictionary form.
    """
    return get_native_backend_config().to_dict()


def get_test_only_legacy_native_backend_config_dict():
    """
    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional
    fp16 ops in dictionary form.
    """
    return get_test_only_legacy_native_backend_config().to_dict()