# mypy: allow-untyped-defs
import copy
import operator
import warnings
from collections import namedtuple
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

import torch
import torch.nn as nn
from torch.ao.quantization import QConfigAny, QuantType
from torch.ao.quantization.backend_config import DTypeWithConstraints
from torch.ao.quantization.fake_quantize import (
    FakeQuantizeBase,
    FixedQParamsFakeQuantize,
)
from torch.ao.quantization.observer import (
    _is_activation_post_process,
    FixedQParamsObserver,
    ObserverBase,
)
from torch.ao.quantization.qconfig import (
    float16_dynamic_qconfig,
    float16_static_qconfig,
    qconfig_equals,
)
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization.stubs import DeQuantStub
from torch.ao.quantization.utils import (
    _assert_and_get_unique_device,
    activation_is_statically_quantized,
)
from torch.fx import GraphModule, map_arg
from torch.fx.graph import Graph, Node

# importing the lib so that the quantized_decomposed ops are registered
from ._decomposed import quantized_decomposed_lib  # noqa: F401
from .custom_config import PrepareCustomConfig


# TODO: revisit this list. Many helper methods shouldn't be public
__all__ = [
    "all_node_args_except_first",
    "all_node_args_have_no_tensors",
    "assert_and_get_unique_device",
    "collect_producer_nodes",
    "create_getattr_from_value",
    "create_node_from_old_node_preserve_meta",
    "EMPTY_ARG_DICT",
    "get_custom_module_class_keys",
    "get_linear_prepack_op_for_dtype",
    "get_new_attr_name_with_prefix",
    "get_non_observable_arg_indexes_and_types",
    "get_qconv_prepack_op",
    "get_skipped_module_name_and_classes",
    "graph_module_from_producer_nodes",
    "maybe_get_next_module",
    "NodeInfo",
    "node_arg_is_bias",
    "node_arg_is_weight",
    "NON_OBSERVABLE_ARG_DICT",
    "NON_QUANTIZABLE_WEIGHT_OPS",
    "return_arg_list",
    "ObservedGraphModuleAttrs",
]

NON_QUANTIZABLE_WEIGHT_OPS = {
    torch.nn.functional.layer_norm,
    torch.nn.functional.group_norm,
    torch.nn.functional.instance_norm,
}


@dataclass
class ObservedGraphModuleAttrs:
    node_name_to_qconfig: Dict[str, QConfigAny]
    node_name_to_scope: Dict[str, Tuple[str, type]]
    prepare_custom_config: PrepareCustomConfig
    equalization_node_name_to_qconfig: Dict[str, Any]
    qconfig_mapping: QConfigMapping
    is_qat: bool
    observed_node_names: Set[str]
    is_observed_standalone_module: bool = False
    standalone_module_input_quantized_idxs: Optional[List[int]] = None
    standalone_module_output_quantized_idxs: Optional[List[int]] = None


def node_arg_is_weight(node: Node, arg: Any) -> bool:
    """Returns whether `arg` is the weight argument of `node`."""
    weight_index = None
    if "target_dtype_info" in node.meta:
        weight_index = node.meta["target_dtype_info"].get("weight_index", None)
    if (
        weight_index is not None
        and weight_index < len(node.args)
        and node.args[weight_index] is arg
    ):
        return True
    return node.kwargs.get("weight") is arg


def node_arg_is_bias(node: Node, arg: Any) -> bool:
    """Returns whether `arg` is the bias argument of `node`."""
    bias_index = None
    if "target_dtype_info" in node.meta:
        bias_index = node.meta["target_dtype_info"].get("bias_index", None)
    if (
        bias_index is not None
        and bias_index < len(node.args)
        and node.args[bias_index] is arg
    ):
        return True
    return node.kwargs.get("bias") is arg
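

# Illustrative sketch (hypothetical node, not part of this module's API): for a
# functional linear call traced as
#     linear_node = call_function(torch.nn.functional.linear, args=(x, w, b))
# whose meta contains {"target_dtype_info": {"weight_index": 1, "bias_index": 2}},
# `node_arg_is_weight(linear_node, w)` and `node_arg_is_bias(linear_node, b)` would
# both return True; keyword arguments named "weight"/"bias" are matched as a fallback.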


def get_custom_module_class_keys(
    custom_module_mapping: Dict[QuantType, Dict[Type, Type]]
) -> List[Any]:
    r"""Get all the unique custom module keys in the custom config dict
    e.g.
    Input:
    {
        QuantType.STATIC: {
            CustomModule1: ObservedCustomModule
        },
        QuantType.DYNAMIC: {
            CustomModule2: DynamicObservedCustomModule
        },
        QuantType.WEIGHT_ONLY: {
            CustomModule3: WeightOnlyObservedCustomModule
        },
    }

    Output:
    # extract the keys across all inner STATIC, DYNAMIC, and WEIGHT_ONLY dicts
    [CustomModule1, CustomModule2, CustomModule3]
    """
    # using set to dedup
    float_custom_module_classes: Set[Any] = set()
    for quant_mode in [QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY]:
        quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {})
        quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys())
        float_custom_module_classes |= quant_mode_custom_module_classes
    return list(float_custom_module_classes)


def get_linear_prepack_op_for_dtype(dtype):
    if dtype == torch.float16:
        return torch.ops.quantized.linear_prepack_fp16
    elif dtype == torch.qint8:
        return torch.ops.quantized.linear_prepack
    else:
        raise Exception("can't get linear prepack op for dtype:", dtype)  # noqa: TRY002


def get_qconv_prepack_op(conv_op: Callable) -> Callable:
    prepack_ops = {
        torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack,
        torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack,
        torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack,
        torch.nn.functional.conv_transpose1d: torch.ops.quantized.conv_transpose1d_prepack,
        torch.nn.functional.conv_transpose2d: torch.ops.quantized.conv_transpose2d_prepack,
        torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack,
    }
    prepack_op = prepack_ops.get(conv_op, None)
    assert prepack_op, f"Didn't find prepack op for {conv_op}"
    return prepack_op
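

# For example, `get_qconv_prepack_op(torch.nn.functional.conv2d)` resolves to
# `torch.ops.quantized.conv2d_prepack`, and `get_linear_prepack_op_for_dtype(torch.qint8)`
# resolves to `torch.ops.quantized.linear_prepack`.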


# Returns a function that can get a new attribute name for module with given
# prefix, for example,
# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
# >> new_name = get_new_observer_name(module)
# new_name will be an unused attribute name on module, e.g. `_observer_1`
def get_new_attr_name_with_prefix(prefix: str) -> Callable:
    prefix = prefix.replace(".", "_")

    def get_new_attr_name(module: torch.nn.Module):
        def get_attr_name(i: int):
            return prefix + str(i)

        i = 0
        attr_name = get_attr_name(i)
        while hasattr(module, attr_name):
            i += 1
            attr_name = get_attr_name(i)
        return attr_name

    return get_new_attr_name


def collect_producer_nodes(node: Node) -> Optional[List[Node]]:
    r"""Starting from a target node, trace back until we hit an input or
    getattr node. This is used to extract the chain of operators
    starting from getattr to the target node, for example
    def forward(self, x):
        observed = self.observer(self.weight)
        return F.linear(x, observed)
    collect_producer_nodes(observed) will either return a list of nodes that
    produces the observed node or None if we can't extract a self contained
    graph without free variables (inputs of the forward function).
    """
    nodes = [node]
    frontier = [node]
    while frontier:
        node = frontier.pop()
        all_args = list(node.args) + list(node.kwargs.values())
        for arg in all_args:
            if not isinstance(arg, Node):
                continue
            if arg.op == "placeholder":
                # hit input, can't fold in this case
                return None
            nodes.append(arg)
            if not (arg.op == "call_function" and arg.target == getattr):
                frontier.append(arg)
    return nodes


def graph_module_from_producer_nodes(
    root: GraphModule, producer_nodes: List[Node]
) -> GraphModule:
    r"""Construct a graph module from extracted producer nodes
    from `collect_producer_nodes` function
    Args:
        root: the root module for the original graph
        producer_nodes: a list of nodes we use to construct the graph
    Return:
        A graph module constructed from the producer nodes
    """
    assert len(producer_nodes) > 0, "list of producer nodes can not be empty"
    # since we traced back from node to getattr
    producer_nodes.reverse()
    graph = Graph()
    env: Dict[Any, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node])

    for producer_node in producer_nodes:
        env[producer_node] = graph.node_copy(producer_node, load_arg)
    graph.output(load_arg(producer_nodes[-1]))
    graph_module = GraphModule(root, graph)
    return graph_module


# TODO: delete
def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
    """
    Returns the unique device for a module, or None if no device is found.
    Throws an error if multiple devices are detected.
    """
    return _assert_and_get_unique_device(module)


def create_getattr_from_value(
    module: torch.nn.Module, graph: Graph, prefix: str, value: Any
) -> Node:
    """
    Given a value of any type, creates a getattr node corresponding to the value and
    registers the value as a buffer to the module.
    """
    get_new_attr_name = get_new_attr_name_with_prefix(prefix)
    attr_name = get_new_attr_name(module)
    device = assert_and_get_unique_device(module)
    new_value = (
        value.clone().detach()
        if isinstance(value, torch.Tensor)
        else torch.tensor(value, device=device)
    )
    module.register_buffer(attr_name, new_value)
    # Create get_attr with value
    attr_node = graph.create_node("get_attr", attr_name)
    return attr_node


def all_node_args_have_no_tensors(
    node: Node, modules: Dict[str, torch.nn.Module], cache: Dict[Node, bool]
) -> bool:
    """
    If we know for sure that all of this node's args have no
    tensors (are primitives), return True. If we either
    find a tensor or are not sure, return False. Note: this
    function is not exact.
    """
    if cache and node in cache:
        return cache[node]

    result = False  # will be overwritten
    if not isinstance(node, Node):
        result = True
    elif node.op == "placeholder":
        result = False
    elif node.op == "call_module":
        assert isinstance(node.target, str)
        if _is_activation_post_process(modules[node.target]):
            result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
    elif node.op == "call_module":
        result = False
    elif node.op == "call_function" and node.target is operator.getitem:
        result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
    elif node.op == "get_attr":
        result = False
    elif node.target is getattr and node.args[1] in ["ndim", "shape"]:
        # x1 = x0.ndim
        result = True
    elif node.op == "call_method" and node.target == "size":
        # x1 = x0.size(0)
        result = True
    else:
        found_one_tensor = False
        for arg in node.args:
            if isinstance(arg, list):
                for list_el in arg:
                    if isinstance(list_el, Node):
                        this_list_el_args_have_no_tensors = (
                            all_node_args_have_no_tensors(list_el, modules, cache)
                        )
                        found_one_tensor = found_one_tensor or (
                            not this_list_el_args_have_no_tensors
                        )
                        # If found_one_tensor is True, there is no point in
                        # recursing further as the end result will always
                        # be True.
                        # TODO(future PR): remove this entire function and
                        # change to dtype inference without recursion.
                        if found_one_tensor:
                            result = not found_one_tensor
                            if cache:
                                cache[node] = result
                            return result
            elif isinstance(arg, int):
                pass
            else:
                if isinstance(arg, Node):
                    this_arg_args_have_no_tensors = all_node_args_have_no_tensors(
                        arg, modules, cache
                    )
                    found_one_tensor = found_one_tensor or (
                        not this_arg_args_have_no_tensors
                    )
                    # If found_one_tensor is True, there is no point in
                    # recursing further as the end result will always
                    # be True.
                    # TODO(future PR): remove this entire function and
                    # change to dtype inference without recursion.
                    if found_one_tensor:
                        result = not found_one_tensor
                        if cache:
                            cache[node] = result
                        return result
                else:
                    found_one_tensor = True
        result = not found_one_tensor
    if cache:
        cache[node] = result
    return result


def all_node_args_except_first(node: Node) -> List[int]:
    """
    Returns all node arg indices after first
    """
    return list(range(1, len(node.args)))


def return_arg_list(arg_indices: List[int]) -> Callable[[Node], List[int]]:
    """
    Constructs a function that takes a node as arg and returns the arg_indices
    that are valid for node.args
    """

    def arg_indices_func(node: Node) -> List[int]:
        return [i for i in arg_indices if i < len(node.args)]

    return arg_indices_func


NodeInfo = namedtuple("NodeInfo", "op target")

# this dict identifies which indices of a node are non tensors
# so that they can be propagated correctly since inserting observers
# for them would cause errors

NON_OBSERVABLE_ARG_DICT: Dict[
    NodeInfo, Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]
] = {
    NodeInfo("call_method", "masked_fill"): {
        torch.bool: return_arg_list([1]),
        float: return_arg_list([2]),
    },
    NodeInfo("call_method", "permute"): {int: all_node_args_except_first},
    NodeInfo("call_method", "repeat"): {int: all_node_args_except_first},
    NodeInfo("call_method", "reshape"): {int: all_node_args_except_first},
    NodeInfo("call_method", "size"): {int: return_arg_list([1])},
    NodeInfo("call_method", "transpose"): {int: all_node_args_except_first},
    NodeInfo("call_method", torch.transpose): {int: all_node_args_except_first},
    NodeInfo("call_method", "unsqueeze"): {int: return_arg_list([1])},
    NodeInfo("call_method", "unsqueeze_"): {int: return_arg_list([1])},
    NodeInfo("call_method", torch.unsqueeze): {int: return_arg_list([1])},
    NodeInfo("call_method", "view"): {int: all_node_args_except_first},
}

EMPTY_ARG_DICT: Dict[Union[type, torch.dtype], Callable[[Node], List[int]]] = {}


def get_non_observable_arg_indexes_and_types(
    node: Node,
) -> Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]:
    """
    Returns a dict with non-float-tensor types as keys, mapping each type to a
    function (which takes the node as an argument) that retrieves the corresponding
    arg indices.
    """
    info = NodeInfo(node.op, node.target)

    return NON_OBSERVABLE_ARG_DICT.get(info, EMPTY_ARG_DICT)
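

# Illustrative example (hypothetical node): for a node `n` representing
# `x.masked_fill(mask, 0.0)` (op="call_method", target="masked_fill"),
# `get_non_observable_arg_indexes_and_types(n)` returns the entry
# {torch.bool: fn_bool, float: fn_float}, where fn_bool(n) == [1] marks the `mask`
# argument and fn_float(n) == [2] marks the fill value, so no observers are
# inserted for those args.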


def maybe_get_next_module(
    node: Node,
    modules: Dict[str, nn.Module],
    target_module_type: Optional[Type[nn.Module]] = None,
    target_functional_type: Any = None,
) -> Optional[Node]:
    """Gets the next user of `node` that matches the given module type or
    functional type, if one exists.

    Args:
        node: The node whose users we want to look at
        target_module_type: Module type that we want to check
        target_functional_type: Functional type that we want to check
    """

    for user in node.users.keys():
        if (
            user.op == "call_module"
            and target_module_type is not None
            and isinstance(modules[str(user.target)], target_module_type)
        ):
            return user
        elif (
            user.op == "call_function"
            and target_functional_type is not None
            and user.target == target_functional_type
        ):
            return user

    return None


def create_node_from_old_node_preserve_meta(
    quantized_graph: Graph,
    create_node_args: Tuple[Any, ...],
    old_node: Node,
) -> Node:
    """
    Creates `new_node` and copies the necessary metadata to it from `old_node`.
    """
    new_node = quantized_graph.create_node(*create_node_args)
    new_node.stack_trace = old_node.stack_trace
    return new_node


def get_skipped_module_name_and_classes(
    prepare_custom_config: PrepareCustomConfig, is_standalone_module: bool
) -> Tuple[List[str], List[Type[Any]]]:
    skipped_module_names = copy.copy(prepare_custom_config.non_traceable_module_names)
    skipped_module_classes = copy.copy(
        prepare_custom_config.non_traceable_module_classes
    )
    if not is_standalone_module:
        # standalone module and custom module config are applied in top level module
        skipped_module_names += list(
            prepare_custom_config.standalone_module_names.keys()
        )
        skipped_module_classes += list(
            prepare_custom_config.standalone_module_classes.keys()
        )
        skipped_module_classes += get_custom_module_class_keys(
            prepare_custom_config.float_to_observed_mapping
        )

    return skipped_module_names, skipped_module_classes


def _is_custom_module_lstm(
    node: Node,
    named_modules: Dict[str, torch.nn.Module],
    qconfig: QConfigAny = None,
    # QuantizeHandler, but we cannot include the type here due to circular imports
    qhandler: Optional[Any] = None,
) -> bool:
    """
    Return whether this refers to the custom module LSTM flow.
    """
    mod = _get_module(node, named_modules)
    if qconfig is not None and qhandler is not None:
        assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler)  # type: ignore[attr-defined]
        return (
            isinstance(mod, torch.nn.LSTM)
            and activation_is_statically_quantized(qconfig)
            and qhandler.is_custom_module()
        )
    else:
        return isinstance(mod, torch.ao.nn.quantizable.LSTM)


def _is_custom_module_mha(
    node: Node,
    named_modules: Dict[str, torch.nn.Module],
    qconfig: QConfigAny = None,
    # QuantizeHandler, but we cannot include the type here due to circular imports
    qhandler: Optional[Any] = None,
) -> bool:
    """
    Return whether this refers to the custom module MultiheadAttention flow.
    """
    mod = _get_module(node, named_modules)
    if qconfig is not None and qhandler is not None:
        assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler)  # type: ignore[attr-defined]
        return (
            isinstance(mod, torch.nn.MultiheadAttention)
            and activation_is_statically_quantized(qconfig)
            and qhandler.is_custom_module()
        )
    else:
        return isinstance(mod, torch.ao.nn.quantizable.MultiheadAttention)


def _get_module(
    node: Node, named_modules: Dict[str, torch.nn.Module]
) -> Optional[torch.nn.Module]:
    """
    If `node` refers to a call_module node, return the module, else None.
    """
    if node.op == "call_module" and str(node.target) in named_modules:
        return named_modules[str(node.target)]
    else:
        return None


def _insert_dequant_stub(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Attach a `DeQuantStub` to the model and create a node that calls this
    `DeQuantStub` on the output of `node`, similar to how observers are inserted.
    """
    prefix = "dequant_stub_"
    get_new_dequant_stub_name = get_new_attr_name_with_prefix(prefix)
    dequant_stub_name = get_new_dequant_stub_name(model)
    dequant_stub = DeQuantStub()
    setattr(model, dequant_stub_name, dequant_stub)
    named_modules[dequant_stub_name] = dequant_stub
    with graph.inserting_after(node):
        return graph.call_module(dequant_stub_name, (node,))


def _insert_dequant_stubs_for_custom_module_lstm_output(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Insert DeQuantStubs after each internal output node of custom module LSTM.

    Custom module LSTM outputs are nested tuples of the structure (output, (hidden0, hidden1)).
    Since we cannot dequantize a tuple as a whole, we must first break down the tuple into its
    components through `getitem`. This function transforms the graph as follows:

      (1) Split the LSTM node into (output, (hidden0, hidden1))
      (2) Insert a DeQuantStub after each internal node
      (3) Recombine the DeQuantStubs into the same structure as before
      (4) Reroute all consumers of the original LSTM node and its sub-nodes
          (e.g. lstm[0])

    Before:
                   lstm_output
                        |
                        v
                  original_user(s)
    After:
                   lstm_output
                  /           \\
                 /  (getitem)  \\
                /               \\
               v                 v
             output            hidden
               |               /   \\
         (DeQuantStub)  (getitem)
               |         /       \\
               v        v         v
           output_dq  hidden0   hidden1
               |         |         |
               |   (DeQuantStub) (DeQuantStub)
               |         |         |
               |         v         v
               |     hidden0_dq  hidden1_dq
               |         \\        /
               |          (tuple)
               |         \\      /
               |          v    v
               |         hidden_dq
                \\           /
                 \\ (tuple) /
                   v      v
                 lstm_output_dq
                      |
                      v
                original_user(s)

    For step (4), reroute all users of the original LSTM node(s) as follows:
      lstm_output -> lstm_output_dq
      lstm_output[0] -> output_dq
      lstm_output[1] -> hidden_dq
      lstm_output[1][0] -> hidden0_dq
      lstm_output[1][1] -> hidden1_dq

    Return the node `lstm_output_dq`.
    """
    # (1) Split the LSTM node into (output, (hidden0, hidden1))
    # (2) Insert a DeQuantStub after each internal node
    with graph.inserting_after(node):
        output = graph.call_function(operator.getitem, (node, 0))
        output_dq = _insert_dequant_stub(output, model, named_modules, graph)
    with graph.inserting_after(output_dq):
        hidden = graph.call_function(operator.getitem, (node, 1))
    with graph.inserting_after(hidden):
        hidden0 = graph.call_function(operator.getitem, (hidden, 0))
        hidden0_dq = _insert_dequant_stub(hidden0, model, named_modules, graph)
    with graph.inserting_after(hidden0_dq):
        hidden1 = graph.call_function(operator.getitem, (hidden, 1))
        hidden1_dq = _insert_dequant_stub(hidden1, model, named_modules, graph)

    # (3) Recombine the DeQuantStubs into the same structure as before
    with graph.inserting_after(hidden1_dq):
        hidden_dq = graph.call_function(tuple, ([hidden0_dq, hidden1_dq],))
    with graph.inserting_after(hidden_dq):
        lstm_output_dq = graph.call_function(tuple, ([output_dq, hidden_dq],))

    # (4) Reroute all consumers of the original LSTM node and its sub-nodes
    for user in list(node.users.keys()):
        if user != output and user != hidden:
            user.replace_input_with(node, lstm_output_dq)
    # The getitem and tuple nodes we added here may interfere with reference quantized
    # pattern matching, so we need to redirect the consumers of internal nodes to the
    # corresponding nodes with DeQuantStubs (e.g. lstm_output_dq[0] -> output_dq) attached,
    # in order to preserve reference patterns like "dequantize - consumer - quantize".
    _reroute_tuple_getitem_pattern(graph)
    return lstm_output_dq


def _maybe_get_custom_module_lstm_from_node_arg(
    arg: Node,
    named_modules: Dict[str, torch.nn.Module],
) -> Optional[Node]:
    """
    Given an argument of a node, if the argument refers to the path through which the node
    is a consumer of custom module LSTM, return the custom module LSTM node, or None otherwise.

    This is used to determine whether a node is a consumer of custom module LSTM, and, if so,
    skip inserting input observers for this node. This is because custom module LSTM produces
    quantized outputs, so inserting an input observer for the consumer of custom module LSTM
    would unnecessarily quantize the outputs again.

      lstm -> consumer

    In practice, however, custom module LSTM outputs a tuple (output, (hidden0, hidden1)) with
    DeQuantStubs attached to each internal node (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
    This tuple can be consumed in one of four ways:

      lstm -> getitem -> DeQuantStub -> consumer                        # consume lstm[0]
      lstm -> getitem -> getitem -> DeQuantStub -> tuple -> consumer    # consume lstm[1]
      lstm -> getitem -> getitem -> DeQuantStub -> consumer             # consume lstm[1][0] or lstm[1][1]
      lstm -> getitem -> DeQuantStub -> tuple -> consumer               # consume lstm

    Thus, we must match against the above patterns instead of simply checking the parent node
    to determine whether this node is a consumer of a custom module LSTM.
    """

    def match_dq(a):
        return isinstance(_get_module(a, named_modules), DeQuantStub)

    def match_lstm(a):
        return _is_custom_module_lstm(a, named_modules)

    def match_getitem(a):
        return a.op == "call_function" and a.target == operator.getitem

    def match_tuple(a):
        return a.op == "call_function" and a.target == tuple

    def _match_pattern(match_pattern: List[Callable]) -> Optional[Node]:
        """
        Traverse up the graph and match the args one by one.
        If there is a match, return the last matched node, or None otherwise.
        """
        a = arg
        for i, match in enumerate(match_pattern):
            if not match(a):
                return None
            # Match next arg, for tuple the arg is a tuple of a list, e.g. ([dq_1, other_node],)
            if i < len(match_pattern) - 1:
                if match == match_tuple:
                    a = a.args[0][0]  # type: ignore[assignment,index]
                else:
                    a = a.args[0]  # type: ignore[assignment]
        return a

    all_match_patterns = [
        [match_dq, match_getitem, match_lstm],
        [match_tuple, match_dq, match_getitem, match_getitem, match_lstm],
        [match_dq, match_getitem, match_getitem, match_lstm],
        [match_tuple, match_dq, match_getitem, match_lstm],
    ]

    for p in all_match_patterns:
        matched_node = _match_pattern(p)
        if matched_node is not None:
            return matched_node
    return None


def _reroute_tuple_getitem_pattern(graph: Graph):
    """
    Search for patterns where N consecutive `tuple` call_function nodes are followed by
    N consecutive `getitem` call_function nodes that are "reverses" of the `tuple` nodes.
    If we find this pattern, reroute the consumers of the last `getitem` to skip these
    N `tuple` and `getitem` nodes.

    Before:

        a   b     c
        |   \\   /
        \\   tuple
         \\   /
          tuple
            |
        getitem(1)
            |
        getitem(0)
            |
            d

    After:

        b
        |
        d
    """

    def find_patterns(
        node: Node,
        index_stack: List[int],
        current_pattern: List[Node],
        matched_patterns: List[List[Node]],
        seen: Set[Tuple[Node, Tuple[int, ...]]],
    ):
        """
        Traverse the graph recursively to match for the N-tuple - N-getitem patterns,
        starting at the given node.

        We use a stack to keep track of the expected `getitem` indices, since these are
        reversed from the `tuple` indices. In the above example, the stack after
        (b -> tuple -> tuple) will be [0, 1], which will be popped by getitem(1) first
        and then by getitem(0).

        TODO: traverse upwards from the output and handle the case when tuple is not a
        separate node, e.g. graph.call_function(operator.getitem, args=(a, (b, c)))
        """
        if len(index_stack) == 0 and len(current_pattern) > 0:
            matched_patterns.append(copy.copy(current_pattern))
            current_pattern.clear()

        # Avoid duplicating work
        state = (node, tuple(index_stack))
        if state in seen:
            return
        seen.add(state)

        # Iterate through users of this node to find tuple/getitem nodes to match
        for user in node.users:
            if user.op == "call_function" and user.target == tuple:
                for i, user_arg in enumerate(user.args[0]):  # type: ignore[arg-type]
                    if user_arg == node:
                        index_stack.append(i)
                        current_pattern.append(user)
                        find_patterns(
                            user, index_stack, current_pattern, matched_patterns, seen
                        )
            elif user.op == "call_function" and user.target == operator.getitem:
                if len(index_stack) > 0:
                    if user.args[1] == index_stack[-1]:
                        index_stack.pop()
                        current_pattern.append(user)
                        find_patterns(
                            user, index_stack, current_pattern, matched_patterns, seen
                        )
        return matched_patterns

    # Collect all matched patterns
    matched_patterns: List[List[Node]] = []
    seen: Set[Tuple[Node, Tuple[int, ...]]] = set()  # (node, index_stack)
    for node in graph.nodes:
        find_patterns(node, [], [], matched_patterns, seen)

    # For each pattern, redirect all consumers of the last getitem node to the correct input
    # of the first tuple node
    for pattern in matched_patterns:
        first_tuple = pattern[0]
        last_getitem = pattern[-1]
        assert first_tuple.op == "call_function" and first_tuple.target == tuple
        assert (
            last_getitem.op == "call_function"
            and last_getitem.target == operator.getitem
        )
        last_getitem_index = last_getitem.args[1]
        new_input = first_tuple.args[0][last_getitem_index]  # type: ignore[index]
        for user in list(last_getitem.users.keys()):
            user.replace_input_with(last_getitem, new_input)  # type: ignore[arg-type]


def _get_observer_from_activation_post_process(
    activation_post_process: Union[ObserverBase, FakeQuantizeBase],
) -> ObserverBase:
    """
    If `activation_post_process` is an observer, return the observer.
    If `activation_post_process` is a fake quantize, return the internal observer.
    """
    if isinstance(activation_post_process, ObserverBase):
        return activation_post_process
    else:
        assert isinstance(activation_post_process, FakeQuantizeBase)
        return activation_post_process.activation_post_process  # type: ignore[return-value]


def _qconfig_satisfies_dtype_config_constraints(
    qconfig: QConfigAny,
    dtype_with_constraints: DTypeWithConstraints,
    is_activation: bool = True,
) -> bool:
    """
    Return whether `qconfig` satisfies the following constraints from the backend,
    specified through the activation and weight DTypeWithConstraints.

      1. QConfig specified a quantization range that falls within the backend's, if any
      2. QConfig specified a min scale value that is >= the backend's, if any
      3. QConfig specified a FixedQParamsObserver or FixedQParamsFakeQuantize that has
         scale and zero point that match the backend's, if any

    If `is_activation` is True, we check `qconfig.activation`, else we check `qconfig.weight`.
    If `qconfig` or `dtype_with_constraints.dtype` is None, or the dtypes do not match, return True.
    """

    # TODO: log warnings only when the user enabled a debug flag
    def _activation_post_process_satisfies_dtype_config_constraints(
        activation_post_process: Union[ObserverBase, FakeQuantizeBase],
        dtype_with_constraints: DTypeWithConstraints,
        debug_string: str,
    ) -> bool:
        observer = _get_observer_from_activation_post_process(activation_post_process)
        app_quant_min = getattr(observer, "quant_min", None)
        app_quant_max = getattr(observer, "quant_max", None)
        # TODO: for now, just use the existing eps value as scale_min. In the future, we should
        # resolve the differences between the two, either by renaming eps or some other way
        app_scale_min = getattr(observer, "eps", None)
        backend_quant_min = dtype_with_constraints.quant_min_lower_bound
        backend_quant_max = dtype_with_constraints.quant_max_upper_bound
        backend_scale_min = dtype_with_constraints.scale_min_lower_bound
        backend_scale_exact_match = dtype_with_constraints.scale_exact_match
        backend_zero_point_exact_match = dtype_with_constraints.zero_point_exact_match
        # check quantization ranges
        if backend_quant_min is not None and backend_quant_max is not None:
            if app_quant_min is None or app_quant_max is None:
                warnings.warn(
                    f"QConfig {debug_string} must specify 'quant_min' and 'quant_max', ignoring {qconfig}"
                )
                return False
            elif app_quant_min < backend_quant_min or app_quant_max > backend_quant_max:
                warnings.warn(
                    f"QConfig {debug_string} quantization range must fall within the backend's:\n"
                    f"QConfig range = ({app_quant_min}, {app_quant_max}), "
                    f"BackendConfig range = ({backend_quant_min}, {backend_quant_max}), "
                    f"ignoring {qconfig}"
                )
                return False
        # check scale min
        if backend_scale_min is not None:
            if app_scale_min is None:
                warnings.warn(
                    f"QConfig {debug_string} must specify 'eps', ignoring {qconfig}"
                )
                return False
            if app_scale_min < backend_scale_min:
                warnings.warn(
                    f"QConfig {debug_string} eps ({app_scale_min}) must be greater than or equal to "
                    f"the backend's min scale value ({backend_scale_min}), ignoring {qconfig}"
                )
                return False
        # check fixed scale and zero point
        if (
            backend_scale_exact_match is not None
            and backend_zero_point_exact_match is not None
        ):
            # For tests only, accept the following qconfigs for now
            # TODO: handle fp16 qconfigs properly
            for accepted_qconfig in [float16_static_qconfig, float16_dynamic_qconfig]:
                if qconfig_equals(qconfig, accepted_qconfig):
                    return True
            suggestion_str = (
                "Please use torch.ao.quantization.get_default_qconfig_mapping or "
                "torch.ao.quantization.get_default_qat_qconfig_mapping. Example:\n"
                '    qconfig_mapping = get_default_qconfig_mapping("fbgemm")\n'
                "    model = prepare_fx(model, qconfig_mapping, example_inputs)"
            )
            if not isinstance(
                activation_post_process, FixedQParamsObserver
            ) and not isinstance(activation_post_process, FixedQParamsFakeQuantize):
                warnings.warn(
                    f"QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize "
                    f"for fixed qparams ops, ignoring {qconfig}.\n{suggestion_str}"
                )
                return False
            if (
                observer.scale != backend_scale_exact_match
                or observer.zero_point != backend_zero_point_exact_match
            ):
                warnings.warn(
                    f"QConfig fixed scale ({observer.scale}) and zero point ({observer.zero_point}) "
                    f"do not match the backend's ({backend_scale_exact_match} and {backend_zero_point_exact_match}), "
                    f"ignoring {qconfig}.\n{suggestion_str}"
                )
                return False
        return True

    if qconfig is None or dtype_with_constraints.dtype is None:
        return True

    activation_post_process_ctr = (
        qconfig.activation if is_activation else qconfig.weight
    )
    debug_string = "activation" if is_activation else "weight"
    satisfies_constraints = True
    if activation_post_process_ctr is not None:
        activation_post_process = activation_post_process_ctr()
        assert _is_activation_post_process(activation_post_process)
        # If dtypes don't match, don't check the activation_post_process and return True early
        if activation_post_process.dtype != dtype_with_constraints.dtype:
            return True
        satisfies_constraints = (
            _activation_post_process_satisfies_dtype_config_constraints(
                activation_post_process, dtype_with_constraints, debug_string
            )
        )
    return satisfies_constraints
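

# -----------------------------------------------------------------------------
# Illustrative usage sketch (not exercised anywhere by the library; the toy module
# below is a hypothetical example). It shows how `get_new_attr_name_with_prefix` and
# `create_getattr_from_value` are typically combined when rewriting an FX graph.
# Runs only when this module is executed directly, e.g.
# `python -m torch.ao.quantization.fx.utils`.
if __name__ == "__main__":

    class _ToyModule(torch.nn.Module):
        def forward(self, x):
            return x + 1

    toy = torch.fx.symbolic_trace(_ToyModule())
    # Find an unused attribute name on the module with the given prefix, e.g. "_scale_0".
    unused_name = get_new_attr_name_with_prefix("_scale_")(toy)
    # Register a buffer holding the value and create a corresponding get_attr node,
    # inserted right after the placeholder node.
    first_node = next(iter(toy.graph.nodes))
    with toy.graph.inserting_after(first_node):
        scale_node = create_getattr_from_value(toy, toy.graph, "_scale_", 0.5)
    toy.recompile()
    print(unused_name, scale_node.target)  # both resolve to the same new attribute name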