# mypy: allow-untyped-defs
import operator
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.nn.quantized.reference as nnqr
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization.quantization_mappings import get_quantized_operator
from torch.ao.quantization.utils import _parent_name
from torch.fx import GraphModule, map_arg, Node
from torch.fx.graph import Graph

from .utils import (
    collect_producer_nodes,
    create_node_from_old_node_preserve_meta,
    get_linear_prepack_op_for_dtype,
    get_new_attr_name_with_prefix,
    get_qconv_prepack_op,
    graph_module_from_producer_nodes,
)


QOP_TO_ARG_NAMES_TO_SKIP = {
    torch._ops.ops.quantized.hardswish: ["inplace"],
    torch._ops.ops.quantized.elu: ["inplace"],
    torch._ops.ops.quantized.dropout: ["inplace"],
    torch._ops.ops.quantized.instance_norm: [
        "running_mean",
        "running_var",
        "use_input_stats",
        "momentum",
    ],
}


def _is_node_in_list(node, modules, func_list, method_list, module_type_list):
    is_call_function = node.op == "call_function" and node.target in func_list
    is_call_method = node.op == "call_method" and node.target in method_list
    is_call_module = (
        node.op == "call_module" and type(modules[str(node.target)]) in module_type_list
    )
    return is_call_function, is_call_method, is_call_module


def is_fixed_qparams_node(node, modules):
    func_list = [
        torch.nn.functional.hardsigmoid,
        torch.nn.functional.sigmoid,
        torch.sigmoid,
        torch.tanh,
    ]
    method_list = [
        "hardsigmoid",
        "hardsigmoid_",
        "sigmoid",
        "sigmoid_",
        "tanh",
        "tanh_",
    ]
    module_type_list = [
        torch.nn.Hardsigmoid,
        torch.nn.Sigmoid,
        torch.nn.Tanh,
        torch.nn.Softmax,
    ]
    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)


def is_default_node(node, modules):
    func_list = [
        torch.nn.functional.elu,
        torch.nn.functional.hardswish,
        torch.nn.functional.instance_norm,
        torch.nn.functional.layer_norm,
        torch.nn.functional.leaky_relu,
        torch.nn.functional.dropout,
    ]
    method_list: List[Any] = []
    module_type_list = [
        nnqr.ConvTranspose1d,
        nnqr.ConvTranspose2d,
        nnqr.ConvTranspose3d,
        torch.nn.ELU,
        torch.nn.LeakyReLU,
        torch.nn.Hardswish,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.Dropout,
        torch.nn.PReLU,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.ao.nn.intrinsic.BNReLU2d,
        torch.ao.nn.intrinsic.BNReLU3d,
    ]
    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)


def is_copy_node(node, modules):
    func_list = [
        torch.adaptive_avg_pool1d,
        torch.nn.functional.adaptive_avg_pool2d,
        torch.nn.functional.adaptive_avg_pool3d,
        torch.nn.functional.hardtanh,
        torch.nn.functional.hardtanh_,
        torch.nn.functional.interpolate,
        torch.nn.functional.max_pool1d,
        torch.nn.functional.max_pool2d,
        torch.nn.functional.max_pool3d,
        torch.nn.functional.relu,
        torch.nn.functional.relu6,
        torch.avg_pool1d,
        torch._C._nn.avg_pool2d,
        torch._C._nn.avg_pool3d,
        torch.clamp,
        torch.flatten,
        torch.mean,
        operator.floordiv,
        # F.channel_shuffle and torch.channel_shuffle are essentially the same thing
        # so we only need to put one of them here
        torch.channel_shuffle,
    ]
    method_list = [
        "clamp",
        "mean",
        "relu",
        "relu_",
    ]
    module_type_list = [
        torch.nn.AdaptiveAvgPool1d,
        torch.nn.AdaptiveAvgPool2d,
        torch.nn.AdaptiveAvgPool3d,
        torch.nn.AvgPool1d,
        torch.nn.AvgPool2d,
        torch.nn.AvgPool3d,
        torch.nn.Hardtanh,
        torch.nn.MaxPool1d,
        torch.nn.MaxPool2d,
        torch.nn.MaxPool3d,
        torch.nn.ReLU,
        torch.nn.ReLU6,
        torch.nn.ChannelShuffle,
    ]
    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)


def is_general_tensor_shape_node(node, modules):
    func_list = [
        torch.narrow,
        torch.transpose,
        torch.repeat_interleave,
        torch.squeeze,
        torch.stack,
        torch.unsqueeze,
        torch.nn.functional.pixel_shuffle,
        torch.nn.functional.pixel_unshuffle,
    ]
    method_list = [
        "contiguous",
        "detach",
        "detach_",
        "permute",
        "repeat",
        "repeat_interleave",
        "reshape",
        "resize_",
        "shape",
        "size",
        "squeeze",
        "squeeze_",
        "transpose",
        "unsqueeze",
        "unsqueeze_",
        "view",
    ]
    module_type_list = [
        torch.nn.Identity,
        torch.nn.PixelShuffle,
        torch.nn.PixelUnshuffle,
    ]
    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)


def is_other_node(node, modules):
    func_list = [
        torch.cat,
    ]
    method_list: List[Any] = []
    module_type_list: List[Any] = []
    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)


def is_special_pattern_node(node, modules):
    res_function, res_method, res_module = False, False, False
    for checker in [
        is_fixed_qparams_node,
        is_default_node,
        is_copy_node,
        is_general_tensor_shape_node,
        is_other_node,
    ]:
        is_call_function, is_call_method, is_call_module = checker(node, modules)
        res_function = res_function or is_call_function
        res_method = res_method or is_call_method
        res_module = res_module or is_call_module
    return res_function, res_method, res_module
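
# Illustrative usage sketch of the checkers above (not part of the lowering pass
# itself): each checker returns an
# (is_call_function, is_call_method, is_call_module) triple, e.g.
#
#     modules = dict(model.named_modules(remove_duplicate=False))
#     for node in model.graph.nodes:
#         is_func, is_method, is_module = is_special_pattern_node(node, modules)
#         if is_func or is_method or is_module:
#             ...  # node participates in one of the special lowering patterns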


def is_dequantize_node(node):
    return (
        isinstance(node, Node)
        and node.op == "call_method"
        and node.target == "dequantize"
    )


def is_getattr_tensor_metadata_node(node):
    return (
        node.op == "call_function"
        and node.target == getattr
        and node.args[1] in ["shape"]
    )


def is_get_tensor_info_node(node):
    return node.op == "call_method" and node.target in ["shape", "size"]


def should_skip_lowering(op: torch.fx.node.Node, qconfig_map: Dict[str, QConfigAny]):
    """
    Return True if the op is configured with a None qconfig, False otherwise.
    Note: we may need to generalize this to also check the dtype and only lower
    when the dtype matches, but right now fbgemm/qnnpack only support a single
    dtype, so it is OK for now.
    """
    return op.name in qconfig_map and qconfig_map[op.name] is None


# Mapping from reference module class to the replacement static quantized module class for lowering
STATIC_LOWER_MODULE_MAP: Dict[Type[nn.Module], Type[WeightedQuantizedModule]] = {
    nnqr.Linear: nnq.Linear,
    nnqr.Conv1d: nnq.Conv1d,
    nnqr.Conv2d: nnq.Conv2d,
    nnqr.Conv3d: nnq.Conv3d,
}

# Mapping from reference module class to the replacement dynamic quantized module class for lowering
DYNAMIC_LOWER_MODULE_MAP: Dict[Type[nn.Module], Type[nn.Module]] = {
    nnqr.Linear: nnqd.Linear,
    nnqr.GRUCell: nnqd.GRUCell,
    nnqr.LSTMCell: nnqd.LSTMCell,
    nnqr.RNNCell: nnqd.RNNCell,
    nnqr.LSTM: nnqd.LSTM,
    nnqr.GRU: nnqd.GRU,
}

# Mapping from reference module class to the replacement weight only quantized module class for lowering
# TODO: correct the namespace for these modules
WEIGHT_ONLY_LOWER_MODULE_MAP: Dict[Type[nn.Module], Type[nn.Module]] = {
    nnqr.Embedding: nnq.Embedding,
    nnqr.EmbeddingBag: nnq.EmbeddingBag,
}

# TODO: merge with STATIC_LOWER_MODULE_MAP after we merge
# _lower_static_weighted_ref_module and special_pattern_replacement
SPECIAL_PATTERN_LOWER_MODULE_MAP = {
    nn.BatchNorm2d: nnq.BatchNorm2d,
    nn.BatchNorm3d: nnq.BatchNorm3d,
    nnqr.ConvTranspose1d: nnq.ConvTranspose1d,
    nnqr.ConvTranspose2d: nnq.ConvTranspose2d,
    nnqr.ConvTranspose3d: nnq.ConvTranspose3d,
    nn.ELU: nnq.ELU,
    nn.LeakyReLU: nnq.LeakyReLU,
    nn.Hardswish: nnq.Hardswish,
    nn.InstanceNorm1d: nnq.InstanceNorm1d,
    nn.InstanceNorm2d: nnq.InstanceNorm2d,
    nn.InstanceNorm3d: nnq.InstanceNorm3d,
    nn.LayerNorm: nnq.LayerNorm,
    nn.Dropout: nnq.Dropout,
    nn.Softmax: nnq.Softmax,
    nn.PReLU: nnq.PReLU,
    nni.BNReLU2d: nniq.BNReLU2d,
    nni.BNReLU3d: nniq.BNReLU3d,
}

# Mapping from fused module class to a 2-tuple of:
#   1) The inner reference module class
#   2) The replacement static quantized module class for lowering
STATIC_LOWER_FUSED_MODULE_MAP: Dict[
    Type[nn.Module], Tuple[Type[nn.Module], Type[WeightedQuantizedModule]]
] = {
    nni.LinearReLU: (nnqr.Linear, nniq.LinearReLU),
    # TODO: LinearLeakyReLU is registered as global but it is only fused and
    # lowered when onednn's backend config is used. Maybe need to separate
    # registration and lowering functions for different backends in the future.
    nni.LinearLeakyReLU: (nnqr.Linear, nniq.LinearLeakyReLU),
    nni.LinearTanh: (nnqr.Linear, nniq.LinearTanh),
    nni.ConvReLU1d: (nnqr.Conv1d, nniq.ConvReLU1d),
    nni.ConvReLU2d: (nnqr.Conv2d, nniq.ConvReLU2d),
    nni.ConvReLU3d: (nnqr.Conv3d, nniq.ConvReLU3d),
}


# The difference between STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP and STATIC_LOWER_FUSED_MODULE_MAP:
# The reference node inside STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP has 2 inputs.
# Mapping from fused module class to a 2-tuple of:
#   1) The inner reference module class
#   2) The replacement static quantized module class for lowering
STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP: Dict[
    Type[nn.Module], Tuple[Type[nn.Module], Type[WeightedQuantizedModule]]
] = {
    nni.ConvAdd2d: (nnqr.Conv2d, nniq.ConvAdd2d),
    nni.ConvAddReLU2d: (nnqr.Conv2d, nniq.ConvAddReLU2d),
}

# Mapping from fused module class to a 2-tuple of:
#   1) The inner reference module class
#   2) The replacement dynamic quantized module class for lowering
DYNAMIC_LOWER_FUSED_MODULE_MAP: Dict[
    Type[nn.Module], Tuple[Type[nn.Module], Type[nn.Module]]
] = {
    nni.LinearReLU: (nnqr.Linear, nniqd.LinearReLU),
}

# Mapping from a functional to lower to a 2-tuple of
#   1) The quantized version of the op
#   2) The quantized version of the op fused with relu, if it exists, else None
STATIC_LOWER_FUNCTIONAL_MAP: Dict[Callable, Tuple[Callable, Optional[Callable]]] = {
    F.linear: (torch.ops.quantized.linear, torch.ops.quantized.linear_relu),
    F.conv1d: (torch.ops.quantized.conv1d, torch.ops.quantized.conv1d_relu),
    F.conv2d: (torch.ops.quantized.conv2d, torch.ops.quantized.conv2d_relu),
    F.conv3d: (torch.ops.quantized.conv3d, torch.ops.quantized.conv3d_relu),
    F.conv_transpose1d: (torch.ops.quantized.conv_transpose1d, None),
    F.conv_transpose2d: (torch.ops.quantized.conv_transpose2d, None),
    F.conv_transpose3d: (torch.ops.quantized.conv_transpose3d, None),
}

WEIGHT_PREPACK_OPS: Set[Callable] = {
    torch._ops.ops.quantized.linear_prepack,
    torch._ops.ops.quantized.linear_prepack_fp16,
    torch._ops.ops.quantized.conv1d_prepack,
    torch._ops.ops.quantized.conv2d_prepack,
    torch._ops.ops.quantized.conv3d_prepack,
    torch.ops.quantized.conv_transpose1d_prepack,
    torch.ops.quantized.conv_transpose2d_prepack,
    torch.ops.quantized.conv_transpose3d_prepack,
}

# Mapping from a functional to a dictionary, where the key is a 2-tuple of
# (input_activation_dtype, weight_dtype) and the value is a 2-tuple of
#   1) The dynamically quantized version of the op
#   2) The dynamically quantized version of the op fused with relu, if it exists, else None
DYNAMIC_LOWER_FUNCTIONAL_MAP: Dict[
    Callable, Dict[Tuple[torch.dtype, torch.dtype], Tuple[Callable, Optional[Callable]]]
] = {
    F.linear: {
        (torch.quint8, torch.qint8): (
            torch.ops.quantized.linear_dynamic,
            torch.ops.quantized.linear_relu_dynamic,
        ),
        (torch.float16, torch.float16): (
            torch.ops.quantized.linear_dynamic_fp16,
            torch.ops.quantized.linear_relu_dynamic_fp16,
        ),
    },
    # dynamic conv + relu is not available yet
    F.conv1d: {
        (torch.quint8, torch.qint8): (torch.ops.quantized.conv1d_dynamic, None),
    },
    F.conv2d: {
        (torch.quint8, torch.qint8): (torch.ops.quantized.conv2d_dynamic, None),
    },
    F.conv3d: {
        (torch.quint8, torch.qint8): (torch.ops.quantized.conv3d_dynamic, None),
    },
}

CONV_FUNCTIONAL_OPS: Set[Callable] = {
    F.conv1d,
    F.conv2d,
    F.conv3d,
}

CONV_TRANSPOSE_FUNCTIONAL_OPS: Set[Callable] = {
    F.conv_transpose1d,
    F.conv_transpose2d,
    F.conv_transpose3d,
}

# TODO: add tests for lowering these ops
QBIN_OP_MAPPING: Dict[Union[Callable, str], Callable] = {
    operator.add: torch.ops.quantized.add,
    torch.add: torch.ops.quantized.add,
    operator.mul: torch.ops.quantized.mul,
    operator.matmul: torch.ops.quantized.matmul,
    torch.mul: torch.ops.quantized.mul,
    torch.matmul: torch.ops.quantized.matmul,
}
QBIN_RELU_OP_MAPPING: Dict[Union[Callable, str], Callable] = {
    operator.add: torch.ops.quantized.add_relu,
    torch.add: torch.ops.quantized.add_relu,
    operator.mul: torch.ops.quantized.mul_relu,
    torch.mul: torch.ops.quantized.mul_relu,
}
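
# Note: the two hooks below are registered on the lowered model by fold_weight()
# so that the prepacked weights created during lowering (torch._C.ScriptObject
# attributes whose names contain "_packed_weight") survive a state_dict
# save/load round trip.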


def _save_packed_weight(self, destination, prefix, keep_vars):
    for attr_name in dir(self):
        if "_packed_weight" in attr_name and isinstance(
            getattr(self, attr_name), torch._C.ScriptObject
        ):  # type: ignore[attr-defined]
            packed_weight = getattr(self, attr_name)
            destination[prefix + attr_name] = packed_weight


def _load_packed_weight(
    self,
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    attrs_to_pop = []
    for attr_name in state_dict:
        if attr_name.startswith("_packed_weight") and isinstance(state_dict[attr_name], torch._C.ScriptObject):  # type: ignore[attr-defined] # noqa: B950
            setattr(self, attr_name, state_dict[attr_name])
            attrs_to_pop.append(attr_name)

    # pop the packed param attributes
    for attr_name in attrs_to_pop:
        state_dict.pop(attr_name)
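
# Illustrative sketch of the transformation performed by fold_weight() below
# (attribute and node names are example-only):
#
#   before:  w = get_attr(linear_w); b = get_attr(linear_b)
#            qw = quantize_per_tensor(w, w_scale, w_zp, torch.qint8)
#            packed = quantized.linear_prepack(qw, b)
#   after:   packed = get_attr(<module_path>_packed_weight_0)
#
# The prepacking subgraph is executed once at lowering time and its result is
# stored on the module, so it no longer runs on every forward call.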


def fold_weight(
    quantized_model: GraphModule, node_name_to_scope: Dict[str, Tuple[str, type]]
) -> GraphModule:
    """
    Trace back from the weight node until we hit getattr, reconstruct the
    graph module with the traced nodes and run the graph module to pack the
    weight, then replace the original chain of ops with the packed weight.
    """
    packed_weights = {}
    # map from folded node name to the prepacked weight name
    folded_nodes = {}
    # get packed weights
    for node in quantized_model.graph.nodes:
        if node.op == "call_function" and node.target in WEIGHT_PREPACK_OPS:
            nodes_to_fold = collect_producer_nodes(node)
            if nodes_to_fold is not None:
                for node_to_fold in nodes_to_fold:
                    folded_nodes[node_to_fold.name] = node

                prepacking_module = graph_module_from_producer_nodes(
                    quantized_model, nodes_to_fold
                )
                packed_weight = prepacking_module()
                packed_weights[node.name] = packed_weight

    # remove folded nodes and replace the prepacking node with getattr
    folded_graph = Graph()
    env: Dict[Any, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in quantized_model.graph.nodes:
        prepack_node = folded_nodes.get(node.name, None)
        if prepack_node is node:
            packed_weight = packed_weights[node.name]
            # add a prepacked attribute to root
            op_node = next(iter(prepack_node.users))
            module_path, _ = node_name_to_scope[op_node.name]
            get_new_packed_weight_name = get_new_attr_name_with_prefix(
                module_path + "_packed_weight_"
            )
            packed_weight_name = get_new_packed_weight_name(quantized_model)
            setattr(quantized_model, packed_weight_name, packed_weight)
            # replace prepack node with a getattr node
            env[node.name] = folded_graph.create_node(
                "get_attr", packed_weight_name, (), {}
            )
        elif prepack_node is not None:
            # remove the folded node
            continue
        else:
            # copy other nodes
            env[node.name] = folded_graph.node_copy(node, load_arg)

    quantized_model = GraphModule(quantized_model, folded_graph)
    quantized_model._register_state_dict_hook(_save_packed_weight)
    quantized_model.register_load_state_dict_pre_hook(_load_packed_weight)
    return quantized_model


def _get_module(node: Node, modules: Dict[str, nn.Module]) -> Optional[nn.Module]:
    """
    Return the `torch.nn.Module` that corresponds to the specified node's target.
    If no such module exists, return None.
    """
    if node.op == "call_module" and str(node.target) in modules:
        return modules[str(node.target)]
    else:
        return None
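
# Illustrative example of the pattern matched by _match_static_pattern() below
# (the relu is optional); the returned 3-tuple is (q_node, relu_node, ref_node):
#
#     x -> dequantize -> ref_node (a reference module or functional op)
#            -> [relu] -> quantize_per_tensor -> ...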


def _match_static_pattern(
    node: Node,
    modules: Dict[str, nn.Module],
    qconfig_map: Dict[str, QConfigAny],
    matching_modules_or_ops: List[Callable],
    dequantize_node_arg_indices: List[int],
) -> Union[Tuple[Node, Node, Node], Tuple[None, None, None]]:
    """
    Match the pattern (dequantize - ref node - quantize) against the node provided.

    If there is a match, return a 3-tuple of:
      1) q_node: the quantize node,
      2) relu_node: a relu node wrapping the ref_node, and
      3) ref_node: a reference module or functional node to replace with its quantized counterpart
    Otherwise, if there is no match, return a 3-tuple of (None, None, None).

    Parameters:
      node: The `torch.fx.Node` to match against.
      modules: A mapping from node names to modules in the model graph, used for module lookup.
      qconfig_map: A mapping from node names to the qconfigs associated with the nodes.
          If the corresponding qconfig for the reference node is None, then return no match.
      matching_modules_or_ops: Either a list of functions or a list of `torch.nn.Module`s.
          If the reference node is not in this list, then return no match.
      dequantize_node_arg_indices: A list of indices in the reference node args where dequantize
          nodes may be present. An empty list means skipping the check for dequantize nodes.
    """
    SKIP_LOWERING_VALUE = (None, None, None)

    # Match quantize node
    if node.op != "call_function" or node.target != torch.quantize_per_tensor:
        return SKIP_LOWERING_VALUE
    q_node = node
    ref_node = q_node.args[0]
    assert isinstance(ref_node, Node)

    # Handle cases where the node is wrapped in a ReLU
    if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or (
        ref_node.op == "call_module" and type(_get_module(ref_node, modules)) == nn.ReLU
    ):
        relu_node = ref_node
        ref_node = relu_node.args[0]
        assert isinstance(ref_node, Node)
    else:
        relu_node = None
    if should_skip_lowering(ref_node, qconfig_map):
        return SKIP_LOWERING_VALUE

    # Match reference module or functional
    if isinstance(matching_modules_or_ops[0], type) and issubclass(
        matching_modules_or_ops[0], nn.Module
    ):
        expected_op = "call_module"
        match_key = type(_get_module(ref_node, modules))
    else:
        expected_op = "call_function"
        match_key = ref_node.target  # type: ignore[assignment]
    if ref_node.op != expected_op or match_key not in matching_modules_or_ops:
        return SKIP_LOWERING_VALUE

    # Match dequantize node(s). Both of the following conditions must pass:
    # (1) All `torch.fx.Node`s at the matching indices must be a dequantize node
    # (2) There must be at least one dequantize node
    matched_dequantize = False
    for i in dequantize_node_arg_indices:
        assert i < len(
            ref_node.args
        ), f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}"
        arg = ref_node.args[i]
        if is_dequantize_node(arg):
            matched_dequantize = True
        elif isinstance(arg, Node):
            return SKIP_LOWERING_VALUE
    if not matched_dequantize:
        return SKIP_LOWERING_VALUE

    return (q_node, relu_node, ref_node)  # type: ignore[return-value]


def _match_static_pattern_with_two_inputs(
    node: Node,
    modules: Dict[str, nn.Module],
    qconfig_map: Dict[str, QConfigAny],
    matching_modules_or_ops: List[Callable],
) -> Union[Tuple[Node, Node], Tuple[None, None]]:
    """
    Match the pattern below against the node provided:

        (dequantize, dequantize) - ref node - quantize

    i.e. the reference node takes two dequantize inputs.

    If there is a match, return a 2-tuple of:
      1) q_node: the quantize node,
      2) ref_node: a reference module or functional node to replace with its quantized counterpart
    Otherwise, if there is no match, return a 2-tuple of (None, None).

    Parameters:
      node: The `torch.fx.Node` to match against.
      modules: A mapping from node names to modules in the model graph, used for module lookup.
      qconfig_map: A mapping from node names to the qconfigs associated with the nodes.
          If the corresponding qconfig for the reference node is None, then return no match.
      matching_modules_or_ops: Either a list of functions or a list of `torch.nn.Module`s.
          If the reference node is not in this list, then return no match.
    """
    SKIP_LOWERING_VALUE = (None, None)

    # Match quantize node
    if node.op != "call_function" or node.target != torch.quantize_per_tensor:
        return SKIP_LOWERING_VALUE
    q_node = node
    ref_node = q_node.args[0]
    assert isinstance(ref_node, Node)

    if should_skip_lowering(ref_node, qconfig_map):
        return SKIP_LOWERING_VALUE

    # Match reference module or functional
    if isinstance(matching_modules_or_ops[0], type) and issubclass(
        matching_modules_or_ops[0], nn.Module
    ):
        expected_op = "call_module"
        match_key = type(_get_module(ref_node, modules))
    else:
        # This pass only supports call_module ops
        return SKIP_LOWERING_VALUE

    if ref_node.op != expected_op or match_key not in matching_modules_or_ops:
        return SKIP_LOWERING_VALUE

    # Check that ref_node has 2 input nodes, both of which are dequantize nodes
    if len(ref_node.args) != 2:
        return SKIP_LOWERING_VALUE
    for i in range(len(ref_node.args)):
        arg = ref_node.args[i]
        if not is_dequantize_node(arg):
            return SKIP_LOWERING_VALUE

    return (q_node, ref_node)
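
# Illustrative before/after for the static module lowering below (module names
# are example-only):
#
#   before:  x -> dequantize -> nnqr.Linear -> quantize_per_tensor(scale, zp) -> ...
#   after:   x -> nnq.Linear (built via from_reference with scale/zp) -> ...
#
# The surrounding dequantize/quantize nodes are removed and the reference module
# is swapped in place for its quantized counterpart.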


def _lower_static_weighted_ref_module(
    model: GraphModule, qconfig_map: Dict[str, QConfigAny]
):
    """
    Traverse the graph and find dequantize - ref module - quantize patterns
    and replace them with the quantized version of the ref module.
    """
    modules = dict(model.named_modules(remove_duplicate=False))
    nodes = list(model.graph.nodes)
    for n in model.graph.nodes:
        # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize)
        matching_modules = list(STATIC_LOWER_MODULE_MAP.keys()) + list(
            STATIC_LOWER_FUSED_MODULE_MAP.keys()
        )
        (q_node, relu_node, ref_node) = _match_static_pattern(
            n, modules, qconfig_map, matching_modules, dequantize_node_arg_indices=[0]  # type: ignore[arg-type]
        )
        if q_node is None:
            continue
        assert ref_node is not None
        (_, scale_node, zero_point_node, _) = q_node.args
        ref_module = _get_module(ref_node, modules)
        ref_class = type(ref_module)
        assert isinstance(scale_node, Node)
        assert isinstance(zero_point_node, Node)
        assert issubclass(ref_class, nn.Module)

        # Step 1: Change this pattern to use the corresponding quantized module
        # For fused modules, we also check whether the inner module is a reference module
        # If so, we replace the entire fused module with the corresponding quantized module
        if ref_class in STATIC_LOWER_FUSED_MODULE_MAP:
            inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_MAP[ref_class]
            if type(ref_module[0]) != inner_ref_class:  # type: ignore[index]
                continue
        else:
            q_class = STATIC_LOWER_MODULE_MAP[ref_class]
        output_scale = getattr(model, scale_node.target)  # type: ignore[arg-type]
        output_zero_point = getattr(model, zero_point_node.target)  # type: ignore[arg-type]
        q_module = q_class.from_reference(ref_module, output_scale, output_zero_point)
        # replace reference module with quantized module
        parent_name, module_name = _parent_name(ref_node.target)
        setattr(modules[parent_name], module_name, q_module)

        # Step 2: Reroute around dq_node, and remove q_node and its args
        assert len(ref_node.args) == 1
        dq_node = ref_node.args[0]
        assert isinstance(dq_node, Node)
        ref_node.replace_input_with(dq_node, dq_node.args[0])  # type: ignore[arg-type]
        q_node.replace_all_uses_with(ref_node)
        model.graph.erase_node(q_node)
        model.graph.erase_node(scale_node)
        model.graph.erase_node(zero_point_node)


def _lower_static_weighted_ref_module_with_two_inputs(
    model: GraphModule, qconfig_map: Dict[str, QConfigAny]
):
    """
    Traverse the graph and find patterns
        dequantize   dequantize
            \\         //
             ref module
                  \\
               quantize
    and replace them with the quantized version of the ref module.
    """
    modules = dict(model.named_modules(remove_duplicate=False))
    nodes = list(model.graph.nodes)
    for n in model.graph.nodes:
        # Step 0: Find nodes that match the two-input pattern
        # (dequantize, dequantize) - ref module - quantize
        matching_modules = list(STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP.keys())
        (q_node, ref_node) = _match_static_pattern_with_two_inputs(
            n, modules, qconfig_map, matching_modules  # type: ignore[arg-type]
        )
        if q_node is None:
            continue
        assert ref_node is not None
        (_, scale_node, zero_point_node, _) = q_node.args
        ref_module = _get_module(ref_node, modules)
        ref_class = type(ref_module)
        assert isinstance(scale_node, Node)
        assert isinstance(zero_point_node, Node)
        assert issubclass(ref_class, nn.Module)

        # Step 1: Change this pattern to use the corresponding quantized module
        # For fused modules, we also check whether the inner module is a reference module
        # If so, we replace the entire fused module with the corresponding quantized module
        if ref_class in STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP:
            inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP[
                ref_class
            ]
            if type(ref_module[0]) != inner_ref_class:  # type: ignore[index]
                continue
        else:
            continue
        output_scale = getattr(model, scale_node.target)  # type: ignore[arg-type]
        output_zero_point = getattr(model, zero_point_node.target)  # type: ignore[arg-type]
        q_module = q_class.from_reference(ref_module, output_scale, output_zero_point)
        # replace reference module with quantized module
        parent_name, module_name = _parent_name(ref_node.target)
        setattr(modules[parent_name], module_name, q_module)

        # Step 2: Reroute around dq_node, and remove q_node and its args
        assert len(ref_node.args) == 2
        for arg in ref_node.args:
            if not is_dequantize_node(arg):
                continue
            dq_node = arg
            assert isinstance(dq_node, Node)
            ref_node.replace_input_with(dq_node, dq_node.args[0])  # type: ignore[arg-type]

        q_node.replace_all_uses_with(ref_node)
        model.graph.erase_node(q_node)
        model.graph.erase_node(scale_node)
        model.graph.erase_node(zero_point_node)
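
# Illustrative before/after for the dynamic module lowering below (module names
# are example-only):
#
#   before:  x -> quantize_per_tensor_dynamic -> dequantize -> nnqr.Linear -> ...
#   after:   x -> nnqd.Linear -> ...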


def _lower_dynamic_weighted_ref_module(model: GraphModule):
    """
    Traverse the graph and find quantize_per_tensor_dynamic - dequantize - ref_module patterns
    and replace them with the dynamically quantized version of the ref module.
    """
    named_modules = dict(model.named_modules(remove_duplicate=False))
    for n in model.graph.nodes:
        if n.op != "call_module" or type(named_modules[str(n.target)]) not in set(
            DYNAMIC_LOWER_MODULE_MAP.keys()
        ).union(set(DYNAMIC_LOWER_FUSED_MODULE_MAP.keys())):
            continue
        ref_node = n
        dq_node = ref_node.args[0]
        if dq_node.op != "call_method" or dq_node.target != "dequantize":
            continue

        input_dynamic_q_node = dq_node.args[0]

        if (
            input_dynamic_q_node.op != "call_function"
            or input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic
        ):
            continue

        activation_dtype = input_dynamic_q_node.args[1]
        is_fp16 = activation_dtype == torch.float16
        is_int8 = activation_dtype in [torch.quint8, torch.qint8]
        if not is_int8 and not is_fp16:
            continue

        ref_module = named_modules[str(ref_node.target)]
        ref_class = type(ref_module)
        if ref_class in DYNAMIC_LOWER_FUSED_MODULE_MAP:
            inner_ref_class, q_class = DYNAMIC_LOWER_FUSED_MODULE_MAP[ref_class]
            if type(ref_module[0]) != inner_ref_class:
                continue
        else:
            q_class = DYNAMIC_LOWER_MODULE_MAP.get(ref_class)  # type: ignore[assignment]
        # TODO: maybe define a WeightedDynamicallyQuantizedModule
        q_module = q_class.from_reference(ref_module)  # type: ignore[attr-defined]

        # replace reference module with dynamically quantized module
        parent_name, module_name = _parent_name(ref_node.target)
        setattr(named_modules[parent_name], module_name, q_module)
        ref_node.replace_input_with(dq_node, input_dynamic_q_node.args[0])


def _lower_weight_only_weighted_ref_module(model: GraphModule):
    """
    Traverse the graph and find ref_module patterns
    and replace them with the weight only quantized version of the ref module.
    """
    named_modules = dict(model.named_modules(remove_duplicate=False))
    for n in model.graph.nodes:
        if n.op != "call_module" or type(named_modules[str(n.target)]) not in set(
            WEIGHT_ONLY_LOWER_MODULE_MAP.keys()
        ):
            continue
        ref_node = n
        ref_module = named_modules[str(ref_node.target)]
        ref_class = type(ref_module)
        q_class = WEIGHT_ONLY_LOWER_MODULE_MAP.get(ref_class)
        # TODO: WeightedQuantizedModule is currently assuming static quant apis
        # with output_scale, output_zero_point in from_reference, we may want to
        # relax that, or rename this
        # TODO: maybe define a WeightedWeightOnlyQuantizedModule
        q_module = q_class.from_reference(ref_module)  # type: ignore[union-attr]

        # replace reference module with weight only quantized module
        parent_name, module_name = _parent_name(ref_node.target)
        setattr(named_modules[parent_name], module_name, q_module)
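
# Illustrative before/after for the static functional lowering below (tensor
# names are example-only):
#
#   before:  out = F.linear(x_dq, w_dq, bias); y = quantize_per_tensor(out, scale, zp)
#   after:   packed = quantized.linear_prepack(w_q, bias)   # folded later by fold_weight
#            y = quantized.linear(x_q, packed, scale, zp)
#
# where x_dq/w_dq are the dequantized inputs and x_q/w_q the quantized tensors
# that fed the corresponding dequantize nodes.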


def _lower_static_weighted_ref_functional(
    model: GraphModule, qconfig_map: Dict[str, QConfigAny]
):
    """
    Traverse the graph and replace functional reference patterns with their quantized versions.
    """
    modules = dict(model.named_modules(remove_duplicate=False))
    nodes = list(model.graph.nodes)
    for n in model.graph.nodes:
        # Step 0: Find nodes that match this pattern (dequantize - functional op - quantize)
        matching_ops = list(STATIC_LOWER_FUNCTIONAL_MAP.keys())
        (q_node, relu_node, func_node) = _match_static_pattern(
            n, modules, qconfig_map, matching_ops, dequantize_node_arg_indices=[0, 1]
        )
        if q_node is None:
            continue
        assert func_node is not None
        (_, output_scale_node, output_zp_node, _) = q_node.args
        (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args
        assert isinstance(output_zp_node, Node)
        assert isinstance(input_dq_node, Node)
        assert isinstance(weight_dq_node, Node)
        quantized_weight = weight_dq_node.args[0]
        assert isinstance(quantized_weight, Node)
        if quantized_weight.op != "call_function" or quantized_weight.target not in (
            torch.quantize_per_tensor,
            torch.quantize_per_channel,
        ):
            continue

        # Step 1: Replace quantized weights with packed weights, which will be folded later
        # Use the right prepack op and prepare the corresponding args
        # Linear prepack args: (quantized weights[, bias])
        # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups])
        prepack_args = [quantized_weight] + remaining_func_args
        if func_node.target == F.linear:
            weight_dtype = quantized_weight.args[-1]
            prepack_op = get_linear_prepack_op_for_dtype(weight_dtype)
        elif func_node.target in CONV_FUNCTIONAL_OPS:
            prepack_op = get_qconv_prepack_op(func_node.target)  # type: ignore[arg-type]
            # For conv1d, the stride, padding, and dilation args may be ints,
            # in which case we need to convert them to tuples
            if func_node.target == F.conv1d:
                for i in [2, 3, 4]:
                    if len(prepack_args) > i and isinstance(prepack_args[i], int):
                        prepack_args[i] = (prepack_args[i],)
        elif func_node.target in CONV_TRANSPOSE_FUNCTIONAL_OPS:
            prepack_op = get_qconv_prepack_op(func_node.target)  # type: ignore[arg-type]
            # For conv_transpose1d, the stride, padding, and dilation args may be ints,
            # in which case we need to convert them to tuples
            if func_node.target == F.conv_transpose1d:
                # Note: prepack_args[5] is groups.
                for i in [2, 3, 4, 6]:
                    if len(prepack_args) > i and isinstance(prepack_args[i], int):
                        prepack_args[i] = (prepack_args[i],)
            # swap dilation and groups
            # prepack op has arguments: {w, b, stride, padding, output_padding, dilation, groups}
            # transposed conv op has arguments: {x, w, b, stride, padding, output_padding, groups, dilation}
            if len(prepack_args) > 6:
                prepack_args[5], prepack_args[6] = prepack_args[6], prepack_args[5]
        else:
            raise ValueError(f"Lowering is not supported for op '{func_node.target}'")
        with model.graph.inserting_before(output_scale_node):  # type: ignore[arg-type]
            # kwargs of the func node are needed for prepack op (i.e., quantized::linear_prepack)
            # They are not needed for compute op (i.e., quantized::linear)
            kwargs = func_node.kwargs
            # F.linear uses 'bias' key for bias while qlinear_prepack uses 'B' for bias
            if func_node.target == F.linear and "bias" in kwargs:
                kwargs = kwargs.copy()
                kwargs["B"] = kwargs["bias"]
                del kwargs["bias"]
            packed_weight = model.graph.create_node(
                "call_function", prepack_op, tuple(prepack_args), kwargs
            )

        # Step 2: Replace reference pattern with the corresponding quantized op
        (q_func, q_relu_func) = STATIC_LOWER_FUNCTIONAL_MAP[func_node.target]  # type: ignore[index]
        # conv_transpose does not support fusion with relu yet. q_relu_func is None in such cases
        if q_relu_func is not None:
            func_node.target = q_relu_func if relu_node is not None else q_func
        else:
            func_node.target = q_func
        func_node.args = (
            input_dq_node.args[0],
            packed_weight,
            output_scale_node,
            output_zp_node,
        )
        # kwargs for func_node have been moved to kwargs for the prepack op
        func_node.kwargs = {}
        q_node.replace_all_uses_with(func_node)
        # Move func_node after output_zp_node in the graph
        output_zp_node.append(func_node)

        # Clean up: Remove quantize node, and the relu node if it exists
        model.graph.erase_node(q_node)
        if relu_node is not None and q_relu_func is not None:
            model.graph.erase_node(relu_node)


def _lower_dynamic_weighted_ref_functional(
    model: GraphModule, qconfig_map: Dict[str, QConfigAny]
):
    """
    Traverse the graph and replace functional reference patterns with their dynamically
    quantized versions.
    Examples:
    quantize_per_tensor_dynamic - dequantize - functional linear --> linear_dynamic
    to(torch.float16) - dequantize - functional linear --> linear_dynamic_fp16
    """
    modules = dict(model.named_modules(remove_duplicate=False))
    nodes = list(model.graph.nodes)
    # we want to search in reversed order so that we can match the larger patterns first
    # e.g. we want to match linear - relu before linear.
    for n in reversed(model.graph.nodes):
        # Step 0: Find nodes that match this pattern
        # (quantize_per_tensor_dynamic - dequantize - dynamically quantized op)
        # We search for the pattern backwards, starting with the quantize node
        # Quantize node args: (func, scale, zp, dtype)
        func_node = n
        # Handle cases where the functional op is wrapped in a ReLU
        if (
            func_node.op == "call_function"
            and func_node.target == F.relu
            or func_node.op == "call_module"
            and type(modules[str(func_node.target)]) == torch.nn.ReLU
        ):
            relu_node = func_node
            func_node = relu_node.args[0]
        else:
            relu_node = None
        if should_skip_lowering(func_node, qconfig_map):
            continue
        # Linear args: (dequantized inputs, dequantized weights[, bias])
        # Conv args: (dequantized inputs, dequantized weights[, bias, stride, padding, dilation, groups])
        if (
            func_node.op != "call_function"
            or func_node.target not in DYNAMIC_LOWER_FUNCTIONAL_MAP
        ):
            continue
        (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args
        if (
            input_dq_node.op != "call_method"
            or input_dq_node.target != "dequantize"
            or weight_dq_node.op != "call_method"
            or weight_dq_node.target != "dequantize"
        ):
            continue

        input_dynamic_q_node = input_dq_node.args[0]

        if (
            input_dynamic_q_node.op != "call_function"
            or input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic
        ):
            continue

        reduce_range_node = None
        (pattern_input, activation_dtype, reduce_range_node) = input_dynamic_q_node.args
        is_fp16 = activation_dtype == torch.float16
        is_int8 = activation_dtype in [torch.quint8, torch.qint8]
        if not is_int8 and not is_fp16:
            continue

        quantized_weight = weight_dq_node.args[0]
        weight_dtype = quantized_weight.args[-1]

        # Step 1: Try to select reference pattern with the corresponding quantized op
        dynamic_quant_dtype_key = (activation_dtype, weight_dtype)
        if (
            dynamic_quant_dtype_key
            not in DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target]
        ):
            print(
                f"Didn't find dtype combination {dynamic_quant_dtype_key} during "
                f"dynamic quantized op lowering for {func_node.target}"
            )
            continue
        (q_func, q_relu_func) = DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target][
            dynamic_quant_dtype_key
        ]

        if q_func is None or q_relu_func is None:
            print(
                "Didn't find corresponding quantized function or quantized relu function "
                f"for {func_node.target}, {dynamic_quant_dtype_key}"
            )
            continue

        # Step 2: Replace quantized weights with packed weights, which will be folded later
        # Use the right prepack op and prepare the corresponding args
        # Linear prepack args: (quantized weights[, bias])
        # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups])
        prepack_args = [quantized_weight] + remaining_func_args
        prepack_kwargs = {}
        if func_node.target == F.linear:
            prepack_op = get_linear_prepack_op_for_dtype(weight_dtype)
            kwargs = func_node.kwargs.copy()
            if "bias" in kwargs:
                prepack_kwargs["B"] = kwargs["bias"]
                del kwargs["bias"]
                func_node.kwargs = kwargs
        elif func_node.target in CONV_FUNCTIONAL_OPS:
            prepack_op = get_qconv_prepack_op(func_node.target)
            # For conv1d, the stride, padding, and dilation args may be ints,
            # in which case we need to convert them to tuples
            if func_node.target == F.conv1d:
                for i in [2, 3, 4]:
                    if len(prepack_args) > i and isinstance(prepack_args[i], int):
                        prepack_args[i] = (prepack_args[i],)
        else:
            raise ValueError(f"Lowering is not supported for op '{func_node.target}'")
        with model.graph.inserting_before(func_node):
            packed_weight = model.graph.create_node(
                "call_function", prepack_op, tuple(prepack_args), prepack_kwargs
            )

        # Step 3: Replace reference pattern with the corresponding quantized op
        func_node.target = q_relu_func if relu_node is not None else q_func
        if is_int8:
            func_node.args = (pattern_input, packed_weight, reduce_range_node)
        else:
            func_node.args = (pattern_input, packed_weight)

        if relu_node is not None:
            relu_node.replace_all_uses_with(func_node)

        # Step 4: Remove the relu node if it exists
        if relu_node is not None:
            model.graph.erase_node(relu_node)
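
# Illustrative before/after for the quantized binary op lowering below (tensor
# names are example-only):
#
#   before:  out = torch.add(a_dq, b_dq); y = quantize_per_tensor(out, scale, zp)
#   after:   y = torch.ops.quantized.add(a_q, b_q, scale, zp)
#
# When only one input comes from a dequantize node (Tensor - scalar ops), the
# scale/zero_point arguments are not appended.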


def _lower_quantized_binary_op(model: GraphModule, qconfig_map: Dict[str, QConfigAny]):
    binary_ops_to_lower: List[Callable] = [
        operator.add,
        torch.add,
        operator.mul,
        torch.mul,
        torch.matmul,
    ]
    modules = dict(model.named_modules(remove_duplicate=False))
    for n in model.graph.nodes:
        # Step 0: Find nodes that match this pattern (dequantize - binary op - quantize)
        (q_node, relu_node, bop_node) = _match_static_pattern(
            n,
            modules,
            qconfig_map,
            binary_ops_to_lower,
            dequantize_node_arg_indices=[0, 1],
        )
        if q_node is None:
            continue
        assert bop_node is not None
        (_, scale_node, zero_point_node, _) = q_node.args

        # Step 1: Remove dequant nodes
        num_dq_nodes = 0
        for arg in bop_node.args:
            if not is_dequantize_node(arg):
                continue
            dq_node = arg
            assert isinstance(dq_node, Node)
            dn_input = dq_node.args[0]
            bop_node.replace_input_with(dq_node, dn_input)  # type: ignore[arg-type]
            num_dq_nodes += 1
        assert num_dq_nodes > 0

        # Step 2: Swap binary op to quantized binary op
        assert bop_node.target in QBIN_OP_MAPPING
        binop_to_qbinop = QBIN_OP_MAPPING if relu_node is None else QBIN_RELU_OP_MAPPING
        qbin_op = binop_to_qbinop[bop_node.target]
        # prepare the args for quantized binary op
        # (x, y)
        qop_node_args = list(bop_node.args)
        # (x, y, scale, zero_point)
        # add scale and zero_point arguments for Tensor - Tensor operation
        if num_dq_nodes == 2:
            qop_node_args.extend([scale_node, zero_point_node])
        # insert a call to quantized binary op and remove the original binary op
        with model.graph.inserting_after(q_node):
            qop_node = create_node_from_old_node_preserve_meta(
                model.graph,
                ("call_function", qbin_op, tuple(qop_node_args), {}),
                bop_node,
            )
            q_node.replace_all_uses_with(qop_node)

        # Step 3: Remove quantize node, binary op node, and relu node if any
        model.graph.erase_node(q_node)
        if relu_node is not None:
            model.graph.erase_node(relu_node)
        model.graph.erase_node(bop_node)
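
# Illustrative example of the special pattern replacement below (names are
# example-only), covering the ops listed in the is_*_node checkers at the top of
# this file:
#
#   before:  dq(x_q) -> nn.LayerNorm -> quantize_per_tensor(scale, zp)
#   after:   nnq.LayerNorm (built via from_reference with scale/zp) applied to x_q
#
# Functional variants are swapped for their quantized operators (via
# get_quantized_operator) with output_scale/output_zero_point passed as kwargs.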


def special_pattern_replacement(model: GraphModule):
    modules = dict(model.named_modules(remove_duplicate=False))
    for n in model.graph.nodes:
        q_node = n
        is_quantize = q_node.target == torch.quantize_per_tensor
        is_to_fp16 = (
            q_node.op == "call_method"
            and q_node.target == "to"
            and len(q_node.args) == 2
            and q_node.args[1] == torch.float16
        )
        if not (is_quantize or is_to_fp16):
            continue
        ref_node = q_node.args[0]
        # get output scale/zero_point/dtype from the quantize node
        # ref_node, scale_node, zero_point_node, dtype = q_node.args
        # TODO: add safety checks that the ref_node and dq_node each have exactly one user
        is_call_function, is_call_method, is_call_module = is_fixed_qparams_node(
            ref_node, modules
        )
        if is_to_fp16 and (is_call_function or is_call_method or is_call_module):
            # TODO: add a warning or error out here? (bc-breaking if error out)
            # warnings.warn(
            #     "Only reference patterns are currently supported for {dtype} dtype with {op} op"
            #     "".format(dtype=dtypes, op=ref_node))
            continue

        is_call_function, is_call_method, is_call_module = is_default_node(
            ref_node, modules
        )
        if is_to_fp16 and (is_call_function or is_call_method or is_call_module):
            # TODO: add a warning or error out here? (bc-breaking if error out)
            continue

        # This check includes all supported ops
        is_call_function, is_call_method, is_call_module = is_special_pattern_node(
            ref_node, modules
        )
        if not (is_call_module or is_call_function or is_call_method):
            continue
        assert len(ref_node.args) > 0 or len(ref_node.kwargs) > 0
        dq_node_or_nodes = (
            ref_node.args[0]
            if len(ref_node.args) > 0
            else next(iter(ref_node.kwargs.values()))
        )
        assert isinstance(dq_node_or_nodes, (Node, tuple, list))
        is_dequantize = False
        if isinstance(dq_node_or_nodes, Node):
            is_dequantize = (
                dq_node_or_nodes.op == "call_method"
                and dq_node_or_nodes.target == "dequantize"
            )
        elif isinstance(dq_node_or_nodes, (tuple, list)):
            is_dequantize = all(
                x.op == "call_method" and x.target == "dequantize"
                for x in dq_node_or_nodes
            )

        if not is_dequantize:
            continue

        # TODO: enable this when we have patterns that need to swap the modules
        if is_call_module:
            ref_module = modules[ref_node.target]
            if type(ref_module) in SPECIAL_PATTERN_LOWER_MODULE_MAP and is_quantize:
                qmodule_cls = SPECIAL_PATTERN_LOWER_MODULE_MAP.get(type(ref_module))
                scale_node = q_node.args[1]
                zero_point_node = q_node.args[2]
                output_scale = getattr(model, scale_node.target)
                output_zero_point = getattr(model, zero_point_node.target)

                qmodule = qmodule_cls.from_reference(  # type:ignore[union-attr]
                    ref_module, output_scale, output_zero_point
                )
                # replace reference module with quantized module
                parent_name, module_name = _parent_name(ref_node.target)
                setattr(modules[parent_name], module_name, qmodule)

        # reroute around dq node:
        dq_nodes: List[Node] = []
        if isinstance(dq_node_or_nodes, Node):
            dq_nodes = [dq_node_or_nodes]
        elif isinstance(dq_node_or_nodes, (tuple, list)):
            dq_nodes = list(dq_node_or_nodes)

        for dq_node in dq_nodes:
            dn_input = dq_node.args[0]
            ref_node.replace_input_with(dq_node, dn_input)

        # store q node args
        qnode_qparams = list(q_node.args)[1:]
        # replace uses of q node with input and remove q node
        q_node_input = q_node.args[0]
        q_node.replace_all_uses_with(q_node_input)
        model.graph.erase_node(q_node)

        is_call_function, is_call_method, is_call_module = is_default_node(
            ref_node, modules
        )
        if is_call_function:
            # pass scale/zero_point arguments from quantize_per_tensor to the default node operator
            # insert an op after the zero_point node so that the scale/zero_point
            # nodes are available
            qop = get_quantized_operator(ref_node.target)
            args = list(ref_node.args)
            kwargs = dict(ref_node.kwargs)
            if qop in QOP_TO_ARG_NAMES_TO_SKIP:
                args_to_skip = QOP_TO_ARG_NAMES_TO_SKIP[qop]
                for arg in args_to_skip:
                    if arg in kwargs:
                        kwargs.pop(arg)
            kwargs["output_scale"] = qnode_qparams[0]
            kwargs["output_zero_point"] = qnode_qparams[1]
            with model.graph.inserting_after(qnode_qparams[1]):
                qop_node = create_node_from_old_node_preserve_meta(
                    model.graph, ("call_function", qop, tuple(args), kwargs), ref_node
                )
                ref_node.replace_all_uses_with(qop_node)
                model.graph.erase_node(ref_node)
        else:
            # remove scale/zero_point node for quantize node
            for n in qnode_qparams:
                if isinstance(n, Node):
                    model.graph.erase_node(n)

    return model
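
# Illustrative example for the two helpers below (names are example-only): a
# shape/size query such as
#
#     n = x.dequantize().size(0)
#
# does not need the dequantize, so the dequantize node is bypassed and size()
# (or the getattr on .shape) reads from the quantized tensor directly.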


def _lower_getattr_tensor_metadta_op(model: GraphModule):
    """Modify the graph of the model in place to skip the extra dequantize op before
    the general tensor shape ops when possible
    """
    for n in model.graph.nodes:
        if is_getattr_tensor_metadata_node(n):
            maybe_dq = n.args[0]
            if maybe_dq.op != "call_method" or maybe_dq.target != "dequantize":
                continue
            # skip the dequantize node
            args = list(n.args)
            args[0] = n.args[0].args[0]
            n.args = tuple(args)


def _lower_get_tensor_info_op(model: GraphModule):
    """Modify the graph of the model in place to skip the extra dequantize op before
    the general tensor shape ops when possible
    """
    for n in model.graph.nodes:
        if not is_get_tensor_info_node(n):
            continue
        maybe_dq = n.args[0]
        if maybe_dq.op != "call_method" or maybe_dq.target != "dequantize":
            continue
        # skip the dequantize node
        args = list(n.args)
        args[0] = n.args[0].args[0]
        n.args = tuple(args)


def _lower_to_native_backend(
    model: GraphModule,
    qconfig_map: Dict[str, QConfigAny],
    node_name_to_scope: Dict[str, Tuple[str, type]],
) -> GraphModule:
    """Lower a quantized reference model (with reference quantized operator patterns)
    to the native backend in PyTorch (fbgemm/qnnpack). Both backends share the same
    operator signatures, so they can be lowered with the same function.
    """
    _lower_static_weighted_ref_module(model, qconfig_map)
    _lower_static_weighted_ref_module_with_two_inputs(model, qconfig_map)
    _lower_dynamic_weighted_ref_module(model)
    _lower_weight_only_weighted_ref_module(model)
    _lower_static_weighted_ref_functional(model, qconfig_map)
    _lower_dynamic_weighted_ref_functional(model, qconfig_map)
    _lower_quantized_binary_op(model, qconfig_map)
    _lower_getattr_tensor_metadta_op(model)
    _lower_get_tensor_info_op(model)
    special_pattern_replacement(model)
    model.graph.eliminate_dead_code()
    model = fold_weight(model, node_name_to_scope)
    model.graph.eliminate_dead_code()
    model.recompile()
    model.graph.lint()
    return model
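
# Minimal usage sketch (illustrative; in practice this pass is invoked by the FX
# graph mode quantization convert flow, which supplies `qconfig_map` and
# `node_name_to_scope`):
#
#     lowered = _lower_to_native_backend(reference_gm, qconfig_map, node_name_to_scope)
#
# The pass mutates the reference GraphModule in place, folds the weight
# prepacking subgraphs, and returns the recompiled GraphModule.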