1# TODO: rename executorch to qnnpack_executorch since executorch is a general runtime 2# not a specific backend 3 4import operator 5from typing import List 6 7import torch 8import torch.ao.nn.qat as nnqat 9import torch.ao.nn.quantized.reference as nnqr 10import torch.nn as nn 11import torch.nn.functional as F 12from torch.ao.quantization.fuser_method_mappings import ( 13 _sequential_wrapper2, 14 fuse_conv_bn, 15 fuse_conv_bn_relu, 16) 17 18from ._common_operator_config_utils import _Conv2dMetadata 19from .backend_config import ( 20 BackendConfig, 21 BackendPatternConfig, 22 DTypeConfig, 23 DTypeWithConstraints, 24 ObservationType, 25) 26from .qnnpack import ( 27 qnnpack_default_op_qint8_symmetric_dtype_config, 28 qnnpack_weighted_op_qint8_symmetric_dtype_config, 29) 30 31 32__all__ = [ 33 "get_executorch_backend_config", 34] 35 36 37# =================== 38# | DTYPE CONFIGS | 39# =================== 40 41executorch_weighted_op_int8_dtype_config = DTypeConfig( 42 input_dtype=torch.quint8, 43 output_dtype=torch.quint8, 44 weight_dtype=torch.qint8, 45 bias_dtype=torch.float, 46) 47 48executorch_default_op_quint8_dtype_config = DTypeConfig( 49 input_dtype=torch.quint8, 50 output_dtype=torch.quint8, 51) 52 53executorch_default_dynamic_quint8_dtype_config = DTypeConfig( 54 input_dtype=torch.quint8, 55 output_dtype=torch.float, 56 weight_dtype=torch.qint8, 57 bias_dtype=torch.float, 58 is_dynamic=True, 59) 60 61executorch_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints( 62 dtype=torch.qint8, 63 scale_min_lower_bound=2**-12, 64) 65 66executorch_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeWithConstraints( 67 dtype=torch.qint8, 68 quant_min_lower_bound=-127, 69 quant_max_upper_bound=127, 70 scale_min_lower_bound=2**-12, 71) 72 73executorch_default_dynamic_qint8_dtype_config = DTypeConfig( 74 input_dtype=executorch_act_qint8_scale_min_2_neg_12, 75 output_dtype=torch.float, 76 weight_dtype=executorch_weight_qint8_neg_127_to_127_scale_min_2_neg_12, 77 bias_dtype=torch.float, 78 is_dynamic=True, 79) 80 81executorch_default_dynamic_float16_dtype_config = DTypeConfig( 82 input_dtype=torch.float16, 83 output_dtype=torch.float, 84 weight_dtype=torch.float16, 85 bias_dtype=torch.float, 86 is_dynamic=True, 87) 88 89executorch_weight_only_quint8_dtype_config = DTypeConfig( 90 input_dtype=torch.float, 91 output_dtype=torch.float, 92 weight_dtype=torch.quint8, 93) 94 95 96# ============================= 97# | BACKEND PATTERN CONFIGS | 98# ============================= 99 100 101def _get_linear_configs() -> List[BackendPatternConfig]: 102 """ 103 Return all configs related to linear modules and ops. 104 """ 105 observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT 106 dtype_configs = [ 107 qnnpack_weighted_op_qint8_symmetric_dtype_config, 108 executorch_weighted_op_int8_dtype_config, 109 executorch_default_dynamic_quint8_dtype_config, 110 executorch_default_dynamic_qint8_dtype_config, 111 executorch_default_dynamic_float16_dtype_config, 112 ] 113 linear_configs: List[BackendPatternConfig] = [] 114 # linear module 115 linear_configs.append( 116 BackendPatternConfig(torch.nn.Linear) 117 .set_observation_type(observation_type) # noqa: E131 118 .set_dtype_configs(dtype_configs) 119 .set_root_module(torch.nn.Linear) 120 .set_reference_quantized_module(nnqr.Linear) 121 .set_qat_module(nnqat.Linear) 122 ) 123 # linear qat module 124 linear_configs.append( 125 BackendPatternConfig(nnqat.Linear) 126 .set_observation_type(observation_type) # noqa: E131 127 .set_dtype_configs(dtype_configs) 128 .set_root_module(torch.nn.Linear) 129 .set_reference_quantized_module(nnqr.Linear) 130 ) 131 # functional linear 132 linear_configs.append( 133 BackendPatternConfig(torch.nn.functional.linear) 134 .set_observation_type(observation_type) # noqa: E131 135 .set_dtype_configs(dtype_configs) 136 ._set_input_type_to_index({"weight": 1, "bias": 2}) 137 ) 138 return linear_configs 139 140 141def _get_conv_configs() -> List[BackendPatternConfig]: 142 """ 143 Return all configs related to conv modules and ops. 144 """ 145 observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT 146 dtype_configs = [ 147 qnnpack_weighted_op_qint8_symmetric_dtype_config, 148 executorch_weighted_op_int8_dtype_config, 149 ] 150 conv_configs = [] 151 for convs in [_Conv2dMetadata]: 152 # (1) Single conv modules/functions 153 # ----------------------------------- 154 # conv module 155 conv_configs.append( 156 BackendPatternConfig(convs.root) 157 .set_observation_type(observation_type) # noqa: E131 158 .set_dtype_configs(dtype_configs) 159 .set_root_module(convs.root) 160 .set_reference_quantized_module(convs.reference) 161 .set_qat_module(convs.qat) 162 ) 163 # conv qat module 164 conv_configs.append( 165 BackendPatternConfig(convs.qat) 166 .set_observation_type(observation_type) # noqa: E131 167 .set_dtype_configs(dtype_configs) 168 .set_root_module(convs.root) 169 .set_reference_quantized_module(convs.reference) 170 ) 171 # functional conv 172 conv_configs.append( 173 BackendPatternConfig(convs.func) 174 .set_observation_type(observation_type) # noqa: E131 175 .set_dtype_configs(dtype_configs) 176 ._set_input_type_to_index({"weight": 1, "bias": 2}) 177 ) 178 179 # (2) Conv + relu 180 # ----------------------------------- 181 # conv module + relu module 182 conv_configs.append( 183 BackendPatternConfig((convs.root, nn.ReLU)) 184 .set_dtype_configs(dtype_configs) # noqa: E131 185 .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) 186 .set_fused_module(convs.fused_conv_relu) 187 ) 188 # conv module + functional relu 189 conv_configs.append( 190 BackendPatternConfig((convs.root, F.relu)) 191 .set_dtype_configs(dtype_configs) # noqa: E131 192 .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) 193 .set_fused_module(convs.fused_conv_relu) 194 ) 195 # fused conv relu module 196 conv_configs.append( 197 BackendPatternConfig(convs.fused_conv_relu) 198 .set_observation_type(observation_type) # noqa: E131 199 .set_dtype_configs(dtype_configs) 200 .set_root_module(convs.root) 201 .set_reference_quantized_module(convs.reference) 202 .set_qat_module(convs.relu_qat) 203 ) 204 # conv relu, qat fused module 205 conv_configs.append( 206 BackendPatternConfig(convs.relu_qat) 207 .set_observation_type(observation_type) # noqa: E131 208 .set_dtype_configs(dtype_configs) 209 .set_root_module(convs.root) 210 .set_reference_quantized_module(convs.reference) 211 ) 212 # functional conv + relu module 213 conv_configs.append( 214 BackendPatternConfig((convs.func, nn.ReLU)) 215 .set_observation_type(observation_type) # noqa: E131 216 .set_dtype_configs(dtype_configs) 217 ) 218 # functional conv + functional relu 219 conv_configs.append( 220 BackendPatternConfig((convs.func, F.relu)) 221 .set_observation_type(observation_type) # noqa: E131 222 .set_dtype_configs(dtype_configs) 223 ) 224 # fused conv relu 225 conv_configs.append( 226 BackendPatternConfig(convs.fused_conv_relu) 227 .set_dtype_configs(dtype_configs) # noqa: E131 228 .set_qat_module(convs.relu_qat) 229 ) 230 231 conv_configs.append( 232 BackendPatternConfig(convs.relu_qat) 233 .set_dtype_configs(dtype_configs) # noqa: E131 234 .set_root_module(convs.root) 235 .set_reference_quantized_module(convs.reference) 236 ) 237 238 # (3) Conv + batchnorm (+ relu) 239 # ------------------------------- 240 # conv + batchnorm (+ relu) 241 conv_configs.append( 242 BackendPatternConfig((convs.root, convs.bn)) 243 .set_dtype_configs(dtype_configs) # noqa: E131 244 .set_fuser_method(fuse_conv_bn) 245 .set_fused_module(convs.fused_conv_bn) 246 ) 247 # conv + bn + relu module fusion 248 conv_configs.append( 249 BackendPatternConfig((convs.root, convs.bn, nn.ReLU)) 250 .set_dtype_configs(dtype_configs) # noqa: E131 251 .set_fuser_method(fuse_conv_bn_relu) 252 .set_fused_module(convs.fused_conv_bn_relu) 253 ) 254 # conv + bn + relu functional fusion 255 conv_configs.append( 256 BackendPatternConfig((convs.root, convs.bn, F.relu)) 257 .set_dtype_configs(dtype_configs) # noqa: E131 258 .set_root_module(convs.root) 259 .set_fuser_method(fuse_conv_bn_relu) 260 .set_fused_module(convs.fused_conv_bn_relu) 261 ) 262 # TODO: we can add fusion for torch.relu as well 263 # 3.2 conv + bn (+ relu) fused module configs 264 # fused conv bn 265 conv_configs.append( 266 BackendPatternConfig(convs.fused_conv_bn) 267 .set_dtype_configs(dtype_configs) # noqa: E131 268 .set_qat_module(convs.bn_qat) 269 ) 270 271 # fused conv bn relu 272 conv_configs.append( 273 BackendPatternConfig(convs.fused_conv_bn_relu) 274 .set_dtype_configs(dtype_configs) # noqa: E131 275 .set_qat_module(convs.bn_relu_qat) 276 ) 277 278 # conv bn, qat fused module 279 conv_configs.append( 280 BackendPatternConfig(convs.bn_qat) 281 .set_observation_type(observation_type) # noqa: E131 282 .set_dtype_configs(dtype_configs) 283 .set_root_module(convs.root) 284 .set_reference_quantized_module(convs.reference) 285 ) 286 # conv bn relu, qat fused module 287 conv_configs.append( 288 BackendPatternConfig(convs.bn_relu_qat) 289 .set_observation_type(observation_type) # noqa: E131 290 .set_dtype_configs(dtype_configs) 291 .set_root_module(convs.root) 292 .set_reference_quantized_module(convs.reference) 293 ) 294 return conv_configs 295 296 297def _get_binary_ops_configs() -> List[BackendPatternConfig]: 298 """ 299 Return all configs related to binary ops. 300 """ 301 dtype_configs = [ 302 qnnpack_default_op_qint8_symmetric_dtype_config, 303 executorch_weighted_op_int8_dtype_config, 304 ] 305 num_tensor_args_to_observation_type_mapping = { 306 # TODO: this is not used right now since we have extra check in prepare 307 # will need to change this to NO_OBSERVER later after we implemented 308 # Tensor dtype inference properly 309 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, 310 1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT, 311 2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, 312 } 313 binary_op_configs: List[BackendPatternConfig] = [] 314 for op in [ 315 operator.add, 316 torch.add, 317 operator.sub, 318 torch.sub, 319 operator.mul, 320 torch.mul, 321 ]: 322 bop_patterns = [ 323 (op, torch.nn.ReLU), 324 (op, torch.nn.functional.relu), 325 (op, torch.relu), 326 op, 327 ] 328 for bop_pattern in bop_patterns: 329 binary_op_configs.append( 330 BackendPatternConfig(bop_pattern) 331 .set_dtype_configs(dtype_configs) # noqa: E131 332 ._set_num_tensor_args_to_observation_type( 333 num_tensor_args_to_observation_type_mapping 334 ) 335 ) 336 return binary_op_configs 337 338 339def _get_share_qparams_ops_configs() -> List[BackendPatternConfig]: 340 """ 341 Return the operator configs for the operators that works for both float and quantized 342 input if input is quantized, the output Tensor shares the same quantization parameter 343 with input. 344 345 Example operator: avgpool2d, reshape, transpose, maxpool2d 346 Example observed operator: 347 observer_0 - avgpool2d - observer_0 (same observer instance as input) 348 """ 349 observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT 350 dtype_configs = [ 351 qnnpack_default_op_qint8_symmetric_dtype_config, 352 executorch_default_op_quint8_dtype_config, 353 ] 354 share_qparams_ops = [ 355 torch.nn.Flatten, 356 F.adaptive_avg_pool2d, 357 F.elu, 358 F.hardtanh, 359 F.max_pool2d, 360 F.pad, 361 F.relu, 362 F.relu6, 363 F.leaky_relu, 364 F.leaky_relu_, 365 torch.nn.AdaptiveAvgPool2d, 366 torch.nn.ConstantPad2d, 367 torch.nn.ELU, 368 torch.nn.MaxPool2d, 369 torch.nn.ReLU6, 370 torch.nn.Hardtanh, 371 torch.nn.LeakyReLU, 372 torch.clamp, 373 torch.flatten, 374 torch.mean, 375 torch.permute, 376 torch.permute_copy, 377 torch.squeeze, 378 "clamp", 379 "mean", 380 "permute", 381 "reshape", 382 "relu", 383 "relu_", 384 "squeeze", 385 "squeeze_", 386 "leaky_relu", 387 ] 388 share_qparams_op_configs: List[BackendPatternConfig] = [] 389 for op in share_qparams_ops: 390 share_qparams_op_configs.append( 391 BackendPatternConfig(op) 392 .set_observation_type(observation_type) # noqa: E131 393 .set_dtype_configs(dtype_configs) 394 ) 395 return share_qparams_op_configs 396 397 398def _get_bn_configs() -> List[BackendPatternConfig]: 399 """ 400 Return all configs related to batchnorm. 401 """ 402 observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT 403 dtype_configs = [ 404 qnnpack_default_op_qint8_symmetric_dtype_config, 405 executorch_default_op_quint8_dtype_config, 406 ] 407 bn_configs = [] 408 bn_configs.append( 409 BackendPatternConfig(nn.BatchNorm2d) 410 .set_observation_type(observation_type) # noqa: E131 411 .set_dtype_configs(dtype_configs) 412 ) 413 return bn_configs 414 415 416def _get_cat_configs() -> List[BackendPatternConfig]: 417 dtype_configs = [ 418 qnnpack_default_op_qint8_symmetric_dtype_config, 419 executorch_default_op_quint8_dtype_config, 420 ] 421 cat_configs = [] 422 cat_configs.append( 423 BackendPatternConfig(torch.cat) 424 .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) 425 .set_dtype_configs(dtype_configs) 426 ) 427 cat_configs.append( 428 BackendPatternConfig(torch.concat) 429 .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) 430 .set_dtype_configs(dtype_configs) 431 ) 432 cat_configs.append( 433 BackendPatternConfig(torch.concatenate) 434 .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) 435 .set_dtype_configs(dtype_configs) 436 ) 437 return cat_configs 438 439 440def _get_embedding_op_configs() -> List[BackendPatternConfig]: 441 dtype_configs = [ 442 executorch_weight_only_quint8_dtype_config, 443 ] 444 embedding_op_configs = [] 445 for embedding_op, qat_embedding_op, ref_embedding_op in [ 446 (nn.Embedding, nnqat.Embedding, nnqr.Embedding), 447 (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag), 448 ]: 449 embedding_op_configs.append( 450 BackendPatternConfig(embedding_op) 451 .set_observation_type( 452 ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT 453 ) # noqa: E131 454 .set_dtype_configs(dtype_configs) 455 .set_qat_module(qat_embedding_op) 456 .set_root_module(embedding_op) 457 .set_reference_quantized_module(ref_embedding_op) 458 ) 459 # config for qat op 460 embedding_op_configs.append( 461 BackendPatternConfig(qat_embedding_op) 462 .set_observation_type( 463 ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT 464 ) # noqa: E131 465 .set_dtype_configs(dtype_configs) 466 .set_root_module(embedding_op) 467 .set_reference_quantized_module(ref_embedding_op) 468 ) 469 470 # config for functional embedding 471 embedding_op_configs.append( 472 BackendPatternConfig(torch.nn.functional.embedding) 473 .set_observation_type( 474 ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT 475 ) # noqa: E131 476 .set_dtype_configs(dtype_configs) 477 ._set_input_type_to_index({"weight": 1}) 478 ) 479 return embedding_op_configs 480 481 482# ===================== 483# | BACKEND CONFIGS | 484# ===================== 485 486 487def get_executorch_backend_config() -> BackendConfig: 488 """ 489 Return the `BackendConfig` for backends PyTorch lowers to through the Executorch stack. 490 """ 491 return ( 492 BackendConfig("executorch") 493 .set_backend_pattern_configs(_get_linear_configs()) 494 .set_backend_pattern_configs(_get_conv_configs()) 495 .set_backend_pattern_configs(_get_binary_ops_configs()) 496 .set_backend_pattern_configs(_get_share_qparams_ops_configs()) 497 .set_backend_pattern_configs(_get_bn_configs()) 498 .set_backend_pattern_configs(_get_cat_configs()) 499 .set_backend_pattern_configs(_get_embedding_op_configs()) 500 ) 501