1# mypy: allow-untyped-defs 2from typing import Any, Dict, Optional, Tuple, Union 3 4import torch 5from torch._subclasses import FakeTensor 6from torch.ao.quantization import ( 7 CUSTOM_KEY, 8 NUMERIC_DEBUG_HANDLE_KEY, 9 ObserverOrFakeQuantize, 10 QConfigMapping, 11) 12from torch.ao.quantization.fx.custom_config import PrepareCustomConfig 13from torch.ao.quantization.fx.prepare import ( 14 _create_obs_or_fq_from_qspec, 15 _insert_obs_or_fq, 16 _is_activation_post_process_node, 17 _save_state, 18) 19from torch.ao.quantization.qconfig import QConfigAny 20from torch.ao.quantization.quantizer import ( 21 EdgeOrNode, 22 QuantizationSpecBase, 23 SharedQuantizationSpec, 24) 25from torch.fx import Graph, GraphModule, Node 26from torch.fx.node import Argument 27 28 29# TODO: make pt2e folder private? 30__all__ = [ 31 "prepare", 32] 33 34 35def _find_root_edge_or_node( 36 edge_or_node: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode] 37) -> EdgeOrNode: 38 """Find the root node for the sharing tree 39 Args: 40 edge_or_node: edge/node that we want to find the root 41 shared_with_map: each edge/node points to the parent, the root node will points to itself 42 43 Returns: 44 root edge/node 45 """ 46 parent = shared_with_map[edge_or_node] 47 if parent == edge_or_node: 48 return edge_or_node 49 root = _find_root_edge_or_node(parent, shared_with_map) 50 # path compression 51 shared_with_map[edge_or_node] = root 52 return root 53 54 55def _union( 56 parent: EdgeOrNode, 57 child: EdgeOrNode, 58 shared_with_map: Dict[EdgeOrNode, EdgeOrNode], 59) -> None: 60 """Merge the subtree for `child` with `parent`, the order is important here""" 61 root_parent = _find_root_edge_or_node(parent, shared_with_map) 62 root_child = _find_root_edge_or_node(child, shared_with_map) 63 # union the two trees by pointing the root of child to root of parent 64 shared_with_map[root_child] = root_parent 65 66 67def _update_shared_with( 68 child: EdgeOrNode, 69 qspec: QuantizationSpecBase, 70 shared_with_map: Dict[EdgeOrNode, EdgeOrNode], 71): 72 """Update the `shared_with_map` based on the qspec, this applies the `SharedQuantizationSpec` 73 configuration and established the relationship between `edge_or_node` with the edge/node that it 74 is pointing to, we'll use this information in the end to get the group id 75 """ 76 if isinstance(qspec, SharedQuantizationSpec): 77 parent = qspec.edge_or_node 78 # we point from edge_or_node to the node that it is sharing_with, e.g. 79 # qspec for a = SharedQuantizationSpec(b) means `a` points to `b` 80 _union(parent, child, shared_with_map) 81 82 83def _unwrap_shared_qspec( 84 qspec: QuantizationSpecBase, 85 edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase], 86 shared_with_map: Dict[EdgeOrNode, EdgeOrNode], 87) -> QuantizationSpecBase: 88 """Unwraps qspec to get the final root qspec (non SharedQuantizationSpec) 89 if qspec is SharedQuantizationSpec 90 (1). tries to find the root edge or node for the node that the qspec points to 91 (2). recursively find the root qspec based on the qspec for the root node 92 """ 93 if isinstance(qspec, SharedQuantizationSpec): 94 sharing_with = qspec.edge_or_node 95 root = _find_root_edge_or_node(sharing_with, shared_with_map) 96 qspec = edge_or_node_to_qspec[root] 97 return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map) 98 return qspec 99 100 101def _has_same_dtype(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase): 102 return ( 103 hasattr(qspec_a, "dtype") 104 and hasattr(qspec_b, "dtype") 105 and qspec_a.dtype == qspec_b.dtype 106 ) 107 108 109def _has_same_is_dynamic(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase): 110 return ( 111 hasattr(qspec_a, "is_dynamic") 112 and hasattr(qspec_b, "is_dynamic") 113 and qspec_a.is_dynamic == qspec_b.is_dynamic 114 ) 115 116 117def _get_edge_or_node_to_qspec( 118 model: torch.fx.GraphModule, 119) -> Dict[EdgeOrNode, QuantizationSpecBase]: 120 """Get a map from EdgeOrNode to quantization spec based on annotations on the nodes""" 121 edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase] = {} 122 for n in model.graph.nodes: 123 if hasattr(n, "meta") and "quantization_annotation" in n.meta: 124 qa = n.meta["quantization_annotation"] 125 for input_to_n, qspec in qa.input_qspec_map.items(): 126 input_edge = (input_to_n, n) 127 edge_or_node_to_qspec[input_edge] = qspec 128 if qa.output_qspec is not None: 129 output_node = n 130 qspec = qa.output_qspec 131 edge_or_node_to_qspec[output_node] = qspec 132 return edge_or_node_to_qspec 133 134 135def _union_input_edge_with( 136 input_edge, 137 input_edge_root_qspec, 138 edge_or_node, 139 edge_or_node_to_qspec, 140 shared_with_map, 141): 142 """Union input edge with another edge or node, used in implicit sharing to point the current input 143 edge to other user edges of the producer node, or the output of producer node since these are 144 referring to the same Tensor 145 """ 146 root_qspec = None 147 if edge_or_node in edge_or_node_to_qspec: 148 qspec = edge_or_node_to_qspec[edge_or_node] 149 root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map) 150 # TODO: add assertions for types of root qspecs 151 if ( 152 root_qspec is not None 153 and _has_same_dtype(root_qspec, input_edge_root_qspec) 154 and _has_same_is_dynamic(root_qspec, input_edge_root_qspec) 155 ): 156 # the input arg to the node should reuse the existing output observer for arg 157 # since dtype is the same (we may want to extend this to be a more strict check 158 # in the future) 159 # so we point from `input_edge` to `arg` (output of the argument) 160 _union(edge_or_node, input_edge, shared_with_map) 161 162 163def _get_edge_or_node_to_group_id( 164 edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase] 165) -> Dict[EdgeOrNode, int]: 166 """Map from edge/node to the group ID, generated from quantization annotations, 167 edge/node with the same group ID should use the same observer/fake_quant instance 168 169 This is applying SharedQuantizationSpec configuration and map each edge/node to a group 170 There is another implicit sharing that's built in the quantization, when we have the following: 171 * op1 -> op2 172 * output of op1: int8_qspec 173 * (op1 -> op2) input edge: int8_qspec 174 we'll assume sharing between the output of op1 and input of (op1 -> op2) since these are the same Tensor. 175 176 Figuring out the correct group ID for all edge/node is a standard union find problem: 177 https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/ 178 179 Args: 180 edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations 181 Returns: 182 edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int), all edge or node that 183 belongs to the same group should have the same id 184 185 Example: 186 op2 -> cat1 -> cat2 187 op1 / / 188 op3 189 edge_or_node_to_qspec: { 190 op1: int8_qspec, 191 op2: int8_qspec, 192 (op1, cat1): int8_qspc, 193 (op2, cat1): SharedQuantizationSpec((op1, cat1)), 194 cat1: SharedQuantizationSpec((op1, cat1)), 195 (op3, cat2): int8_qspec, 196 (cat1, cat2): SharedQuantizationSpec((op3, cat2)), 197 cat2: SharedQuantizationSpec((op3, cat2)), 198 } 199 200 edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec) 201 edge_or_node_to_group_id: { 202 op1: 1, 203 op2: 1, 204 (op1, cat1): 1, 205 (op2, cat1): 1, 206 cat1: 1, 207 (op3, cat2): 1, 208 (cat1, cat2): 1, 209 cat2: 1, 210 } 211 # everything are in the same group because (cat1) and (cat1, cat2) are implicitly shared, which 212 # connects the two sharing group around cat1 and cat2 op due to transitive sharing 213 """ 214 # means the observer of key should be shared with observer with value, by default it will 215 # be shared with itself 216 shared_with_map: Dict[EdgeOrNode, EdgeOrNode] = { 217 k: k for k in edge_or_node_to_qspec.keys() 218 } 219 for edge_or_node, qspec in edge_or_node_to_qspec.items(): 220 if isinstance(edge_or_node, torch.fx.Node): 221 output_node = edge_or_node 222 _update_shared_with(output_node, qspec, shared_with_map) 223 else: 224 input_edge = edge_or_node 225 input_edge_root_qspec = _unwrap_shared_qspec( 226 qspec, edge_or_node_to_qspec, shared_with_map 227 ) 228 229 assert isinstance(input_edge, tuple) 230 arg, n = input_edge 231 if n.meta["quantization_annotation"].allow_implicit_sharing: 232 # NOTE: the order is important here, we first share with other users and then share with previous 233 # output because the reverse order could cause circular dependency 234 # e.g node1 -> node2 235 # \ -> node3 236 # when processing (node1, node2), if we first point (node1, node2) to node1 237 # Step 1. shared_map = {(node1, node2): node1} 238 # Step 2. after that, we point the (node1, node2) to its other user (node1, node3) , 239 # which means shared_map = {(node1, node2): node1, node1: (node1, node3)} 240 # because we will point the root of (node1, node2) (in this case node1) to the root of (node1, node3) 241 # Step 3. and when we process (node1, node3), it can try to point to node1 as well, then we'll 242 # have a circular dependency 243 # the following order works around this issue, but this does not allow arbitrary configuration 244 # of sharing so it might break in a different case in the future, when it breaks 245 # quantizer writer can check the notes here to debug the issue 246 247 # sharing with other users of the producer node 248 # (arg, user) 249 if not isinstance(arg, Node) or not isinstance(n, Node): 250 raise Exception( # noqa: TRY002 251 f"Expected input_edge to have type Tuple[Node, Node], but got: {arg, n}" 252 ) 253 for user in arg.users: 254 if user is n: 255 continue 256 arg_to_user_edge = (arg, user) 257 _union_input_edge_with( 258 input_edge, 259 input_edge_root_qspec, 260 arg_to_user_edge, 261 edge_or_node_to_qspec, 262 shared_with_map, 263 ) 264 265 # sharing with output of producer node 266 _union_input_edge_with( 267 input_edge, 268 input_edge_root_qspec, 269 arg, 270 edge_or_node_to_qspec, 271 shared_with_map, 272 ) 273 274 _update_shared_with(input_edge, qspec, shared_with_map) 275 276 # now that we get the sharing relations between all edges and nodes, we can assingn group ids 277 cur_group_id = 0 278 edge_or_node_to_group_id: Dict[EdgeOrNode, int] = {} 279 for edge_or_node in shared_with_map.keys(): 280 root = _find_root_edge_or_node(edge_or_node, shared_with_map) 281 if root not in edge_or_node_to_group_id: 282 edge_or_node_to_group_id[root] = cur_group_id 283 cur_group_id += 1 284 edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root] 285 286 return edge_or_node_to_group_id 287 288 289def _get_obs_or_fq_map( 290 edge_or_node_to_group_id: Dict[EdgeOrNode, int], 291 edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase], 292 is_qat: bool, 293) -> Dict[EdgeOrNode, ObserverOrFakeQuantize]: 294 """Generates the EdgeOrNode to observer/fake_quant instances 295 Makes sure that for EdgeOrNode that has the same group_id should have the same observer or fake quant 296 instances 297 """ 298 obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {} 299 group_id_to_obs_or_fq: Dict[int, ObserverOrFakeQuantize] = {} 300 for edge_or_node, qspec in edge_or_node_to_qspec.items(): 301 group_id = edge_or_node_to_group_id[edge_or_node] 302 if group_id not in group_id_to_obs_or_fq: 303 # TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec, this will simplify 304 # the implementation for _create_obs_or_fq_from_qspec 305 group_id_to_obs_or_fq[group_id] = _create_obs_or_fq_from_qspec( 306 qspec, obs_or_fq_map, is_qat 307 ) 308 obs_or_fq_map[edge_or_node] = group_id_to_obs_or_fq[group_id] 309 return obs_or_fq_map 310 311 312def _maybe_insert_input_observer_for_arg_or_kwarg( 313 node: Union[Node, Any], 314 arg: Argument, 315 qconfig: QConfigAny, 316 model: torch.nn.Module, 317 named_modules: Dict[str, torch.nn.Module], 318 obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], 319 is_qat: bool, 320) -> Argument: 321 """ 322 Given a `node` and an `arg`, inserts an input observer between 323 `node` and `arg` if necessary. 324 """ 325 # for ops such as torch.cat([x0, x1]), 326 # traverse through the list 327 if isinstance(arg, (list, tuple)): 328 new_arg_to_return = [] 329 for inner_arg in arg: 330 new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg( 331 node, 332 inner_arg, 333 qconfig, 334 model, 335 named_modules, 336 obs_or_fq_map, 337 is_qat, 338 ) 339 new_arg_to_return.append(new_inner_arg) 340 return type(arg)(new_arg_to_return) 341 342 if not isinstance(arg, Node): 343 return arg 344 assert isinstance(arg, Node) 345 # default (no observer) 346 new_arg = arg 347 348 # find the original `arg` node to the current node, skipping inserted observer/fake_quant nodes 349 original_arg = arg 350 while _is_activation_post_process_node(original_arg, named_modules): 351 original_arg = original_arg.args[0] # type: ignore[assignment] 352 assert isinstance( 353 original_arg, Node 354 ), f"expect original argument to be a Node, but got: {type(original_arg)}" 355 356 input_edge = (original_arg, node) 357 if input_edge not in obs_or_fq_map: 358 return new_arg 359 # input_edge needs to be observed 360 input_edge_obs_or_fq = obs_or_fq_map[input_edge] 361 if input_edge_obs_or_fq is None: 362 return new_arg 363 364 arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None) 365 # the arg is observed as the output and is using the same instance as the input_edge 366 # we'll reuse the inserted observer/fake_quant 367 if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id( 368 input_edge_obs_or_fq 369 ): 370 return new_arg 371 372 # otherwise, we'll insert a new observer/fake_quant node 373 374 existing_obs_node = None 375 # skip inserting new observers if the same observer instance is inserted before for another user 376 # Example: 377 # conv1 -> obs1 -> existing_obs -> conv2 378 # \ -> conv3 379 # 380 # instead of inserting new observers we will have: 381 # conv1 -> obs1 -> existing_obs -> conv2 382 # \ -> conv3 383 for maybe_obs_node in arg.users.keys(): 384 if not _is_activation_post_process_node(maybe_obs_node, named_modules): 385 continue 386 maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] 387 if id(maybe_obs_mod) == id(input_edge_obs_or_fq): 388 return maybe_obs_node 389 390 new_arg = _insert_obs_or_fq( 391 arg, input_edge_obs_or_fq, model, named_modules, model.graph 392 ) 393 return new_arg 394 395 396def _maybe_insert_input_observers_for_node( 397 node: Node, 398 qconfig: QConfigAny, 399 model: torch.nn.Module, 400 named_modules: Dict[str, torch.nn.Module], 401 obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], 402 is_qat: bool, 403) -> None: 404 """ 405 If needed, inserts observers to the input args and kwargs of `node`. 406 Note: modifies `node` inplace. 407 408 For example, if cur_node needs an observer after prev_node, we change from 409 410 prev_node -> cur_node 411 412 To 413 414 prev_node -> obs -> cur_node 415 416 """ 417 # Look through every input arg. If that arg's target dtype does not 418 # match the current node's target dtype, insert an observer. 419 new_args = [] 420 for arg in node.args: 421 new_arg = _maybe_insert_input_observer_for_arg_or_kwarg( 422 node, 423 arg, 424 qconfig, 425 model, 426 named_modules, 427 obs_or_fq_map, 428 is_qat, 429 ) 430 new_args.append(new_arg) 431 432 # Clone has a memory_format kwarg, zeros_like has a pin_memory kwarg, and 433 # gelu has a has an approximate kwarg that persist in exported graph. 434 # This is just a work around for these. 435 assert ( 436 node.target == torch.ops.aten.clone.default 437 or node.target == torch.ops.aten.zeros_like.default 438 or node.target == torch.ops.aten.gelu.default 439 or len(node.kwargs) == 0 440 ), " expecting kwargs for aten op IR to be empty" 441 442 # assign the new args to the node, inplace 443 node.args = tuple(new_args) 444 445 446def _maybe_insert_output_observer_for_node( 447 node: Node, 448 model: torch.nn.Module, 449 named_modules: Dict[str, torch.nn.Module], 450 graph: Graph, 451 obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], 452 is_qat: bool, 453) -> Optional[Node]: 454 if node in obs_or_fq_map: 455 output_act_obs_or_fq = obs_or_fq_map[node] 456 new_output = _insert_obs_or_fq( 457 node, output_act_obs_or_fq, model, named_modules, graph 458 ) 459 # propagate numeric debug handle from original node to observer/fake_quant node 460 if ( 461 isinstance(node, Node) 462 and isinstance(new_output, Node) 463 and CUSTOM_KEY in node.meta 464 and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] 465 ): 466 if CUSTOM_KEY not in new_output.meta: 467 new_output.meta[CUSTOM_KEY] = {} 468 new_output.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ 469 CUSTOM_KEY 470 ][NUMERIC_DEBUG_HANDLE_KEY] 471 return new_output 472 return None 473 474 475def _maybe_insert_input_and_output_observers_for_node( 476 node: Node, 477 model: torch.fx.GraphModule, 478 obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], 479 is_qat: bool, 480): 481 this_node_quantization_annotation = ( 482 node.meta["quantization_annotation"] 483 if "quantization_annotation" in node.meta 484 else None 485 ) 486 if this_node_quantization_annotation is None: 487 return 488 489 named_modules = dict(model.named_modules(remove_duplicate=False)) 490 _maybe_insert_input_observers_for_node( 491 node, 492 None, # qconfig 493 model, 494 named_modules, 495 obs_or_fq_map, 496 is_qat, 497 ) 498 499 output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor) 500 if not output_is_a_tensor: 501 return 502 503 # this returns the new observer node if it was needed 504 maybe_output_obs_node = _maybe_insert_output_observer_for_node( 505 node, model, named_modules, model.graph, obs_or_fq_map, is_qat 506 ) 507 508 if maybe_output_obs_node is None: 509 return 510 # Update users of original node to use the output observer 511 # instead. For example, change 512 # 513 # next_node 514 # / 515 # cur_node -> obs 516 # 517 # to 518 # 519 # next_node 520 # / 521 # cur_node -> obs 522 # 523 # We need to save orig users before updating uses because 524 # the list of users will change as we update uses 525 orig_users = list(node.users.keys()) 526 for user_node in orig_users: 527 if user_node is maybe_output_obs_node: 528 continue 529 user_node.replace_input_with(node, maybe_output_obs_node) 530 531 532def prepare( 533 model: GraphModule, 534 node_name_to_scope: Dict[str, Tuple[str, type]], 535 is_qat: bool, 536) -> GraphModule: 537 # Since we are mutating the graph as we go, we iterate over the original 538 # nodes before observer insertion, instead of model.graph.nodes. 539 nodes_before_observation = list(model.graph.nodes) 540 541 # At the high level we construct a map from EdgeOrNode to a observer_or_fake_quant instance 542 # all edge/nodes that belongs to the same group will use the same instance 543 # and when we insert observers we'll just query this map to get the correct observer_or_fake_quant 544 # instance 545 edge_or_node_to_qspec = _get_edge_or_node_to_qspec(model) 546 edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec) 547 obs_or_fq_map = _get_obs_or_fq_map( 548 edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat 549 ) 550 551 for node in nodes_before_observation: 552 # TODO: simplify logic for inserting observers 553 _maybe_insert_input_and_output_observers_for_node( 554 node, model, obs_or_fq_map, is_qat 555 ) 556 557 model = GraphModule(model, model.graph) 558 559 _save_state( 560 model, 561 {}, # node_name_to_qconfig 562 node_name_to_scope, 563 PrepareCustomConfig(), 564 {}, # equalization_node_name_to_qconfig 565 QConfigMapping(), 566 is_qat, 567 set(), # observed_node_names 568 ) 569 return model 570