# mypy: allow-untyped-defs
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
from torch.ao.ns.fx.mappings import get_node_type_to_io_type_map
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.ao.quantization.observer import _is_activation_post_process
from torch.fx import GraphModule, map_arg
from torch.fx.graph import Graph, Node

from .ns_types import NSNodeTargetType, NSSingleResultValuesType, NSSubgraph
from .utils import (
    get_arg_indices_of_inputs_to_log,
    get_node_first_input_and_output_type,
    get_node_input_qparams,
    get_normalized_nth_input,
    get_number_of_non_param_args,
    get_target_type_str,
    getattr_from_fqn,
    NodeInputOrOutputType,
    op_type_supports_shadowing,
    return_first_non_observer_node,
)


def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
    fqn = None
    if hasattr(gm, "_node_name_to_scope"):
        # fqns are not recorded for observers, because observers do not
        # exist yet when the fqns are created during tracing. If this is
        # an observer, get the fqn of the node being observed.
        node_to_use_for_fqn = node
        if node.op == "call_module":
            assert isinstance(node.target, str)
            module = getattr_from_fqn(gm, node.target)
            if _is_activation_post_process(module):
                node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
        fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0]  # type: ignore[index]
    return fqn  # type: ignore[return-value]


def _insert_logger_after_node(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    logger_node_name_suffix: str,
    ref_node_name: str,
    model_name: str,
    ref_name: str,
    ref_node_target_type: str,
    results_type: str,
    index_within_arg: int,
    index_of_arg: int,
    fqn: Optional[str],
) -> Node:
    """
    Given a starting graph of

    prev_node -> node -> next_node

    this function creates a new logger_cls object, attaches it to gm,
    and adds it after node, resulting in

    prev_node -> node -> logger_obj -> next_node
    """
    # create new name
    logger_node_name = get_new_attr_name_with_prefix(
        node.name + logger_node_name_suffix
    )(gm)
    target_type = get_target_type_str(node, gm)
    # create the logger object
    logger_obj = logger_cls(
        ref_node_name,
        node.name,
        model_name,
        ref_name,
        target_type,
        ref_node_target_type,
        results_type,
        index_within_arg,
        index_of_arg,
        fqn,
    )
    # attach the logger object to the parent module
    setattr(gm, logger_node_name, logger_obj)
    logger_node = node.graph.create_node("call_module", logger_node_name, (node,), {})
    return logger_node


def add_loggers_to_model(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm and adds loggers to the inputs of each node in
    node_to_instrument_inputs_to_ref_node_name and to the outputs of each
    node in node_to_instrument_outputs_to_ref_node_name. Returns a
    GraphModule with the new graph.
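
    Example (an illustrative sketch: `linear_node` is a hypothetical Node in
    gm's graph, and `OutputLogger` refers to the logger class defined in
    torch.ao.ns._numeric_suite_fx):

        gm_instrumented = add_loggers_to_model(
            gm,
            node_to_instrument_inputs_to_ref_node_name={},
            node_to_instrument_outputs_to_ref_node_name={
                linear_node: ("linear1", "F.linear"),
            },
            logger_cls=OutputLogger,
            model_name="int8_model",
        )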
101 """ 102 103 new_graph = Graph() 104 env: Dict[str, Any] = {} 105 modules = dict(gm.named_modules()) 106 107 def load_arg(a): 108 return map_arg(a, lambda node: env[node.name]) 109 110 for node in gm.graph.nodes: 111 if node.op == "output": 112 new_graph.output(map_arg(get_normalized_nth_input(node, gm, 0), load_arg)) 113 continue 114 115 if (node in node_to_instrument_inputs_to_ref_node_name) or ( 116 node in node_to_instrument_outputs_to_ref_node_name 117 ): 118 fqn = _maybe_get_fqn(node, gm) 119 120 if node in node_to_instrument_inputs_to_ref_node_name: 121 ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[ 122 node 123 ] 124 # Ops such add and mul are special because either 125 # one or two of the first two arguments can be tensors, 126 # and if one argument is a tensor it can be first or 127 # second (x + 1 versus 1 + x). 128 arg_indices_to_log = get_arg_indices_of_inputs_to_log(node) 129 for node_arg_idx in arg_indices_to_log: 130 node_arg = get_normalized_nth_input(node, gm, node_arg_idx) 131 if type(node_arg) == Node: 132 # create a single input logger 133 prev_node = env[node_arg.name] 134 env[node_arg.name] = _insert_logger_after_node( 135 prev_node, 136 gm, 137 logger_cls, 138 "_ns_logger_", 139 node.name, 140 model_name, 141 ref_name, 142 ref_node_type, 143 NSSingleResultValuesType.NODE_INPUT.value, 144 index_within_arg=0, 145 index_of_arg=node_arg_idx, 146 fqn=fqn, 147 ) 148 elif ( 149 type(node_arg) == torch.fx.immutable_collections.immutable_list 150 ): 151 # create N input loggers, one for each node 152 for arg_idx, arg in enumerate(node_arg): # type: ignore[var-annotated, arg-type] 153 prev_node = env[arg.name] 154 env[prev_node.name] = _insert_logger_after_node( 155 prev_node, 156 gm, 157 logger_cls, 158 "_ns_logger_", 159 node.name, 160 model_name, 161 ref_name, 162 ref_node_type, 163 NSSingleResultValuesType.NODE_INPUT.value, 164 index_within_arg=arg_idx, 165 index_of_arg=node_arg_idx, 166 fqn=fqn, 167 ) 168 else: 169 pass 170 171 # ensure env is populated with base node 172 # Note: runs for both inputs and outputs 173 env[node.name] = new_graph.node_copy(node, load_arg) 174 175 if node in node_to_instrument_outputs_to_ref_node_name: 176 ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[ 177 node 178 ] 179 # add the logger after the base node 180 env[node.name] = _insert_logger_after_node( 181 env[node.name], 182 gm, 183 logger_cls, 184 "_ns_logger_", 185 node.name, 186 model_name, 187 ref_name, 188 ref_node_type, 189 NSSingleResultValuesType.NODE_OUTPUT.value, 190 index_within_arg=0, 191 index_of_arg=0, 192 fqn=fqn, 193 ) 194 195 else: 196 env[node.name] = new_graph.node_copy(node, load_arg) 197 198 new_gm = GraphModule(gm, new_graph) 199 return new_gm 200 201 202def _insert_quantize_per_tensor_node( 203 prev_node_c: Node, 204 node_a: Node, 205 gm_b: GraphModule, 206 graph_c: Graph, 207 scale: Union[torch.Tensor, float], 208 zero_point: Union[torch.Tensor, int], 209 dtype_cast_name: str, 210) -> Node: 211 # copy scale 212 scale_node_name = get_new_attr_name_with_prefix(node_a.name + "_input_scale_")(gm_b) 213 setattr(gm_b, scale_node_name, scale) 214 scale_node = graph_c.create_node( 215 "get_attr", scale_node_name, (), {}, scale_node_name 216 ) 217 # copy zero_point 218 zero_point_node_name = get_new_attr_name_with_prefix( 219 node_a.name + "_input_zero_point_" 220 )(gm_b) 221 setattr(gm_b, zero_point_node_name, zero_point) 222 zero_point_node = graph_c.create_node( 223 "get_attr", zero_point_node_name, (), {}, 
    # create the quantize_per_tensor call
    return graph_c.create_node(
        "call_function",
        torch.quantize_per_tensor,
        (prev_node_c, scale_node, zero_point_node, torch.quint8),
        {},
        dtype_cast_name,
    )


def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

            dtype_cast
           /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.
    """
    dtype_cast_op = None
    dtype_cast_mod_cls = None
    dtype_cast_method = None
    dtype_cast_method_dtype = None
    dtype_cast_scale = None
    dtype_cast_zero_point = None
    node_input_type_a, _node_output_type_a = get_node_first_input_and_output_type(
        node_a, gm_a, logger_cls, node_type_to_io_type_map
    )
    node_input_type_c, _node_output_type_c = get_node_first_input_and_output_type(
        node_c, gm_b, logger_cls, node_type_to_io_type_map
    )

    if (
        (
            node_input_type_a == NodeInputOrOutputType.FP32
            and node_input_type_c == NodeInputOrOutputType.INT8
        )
        or (
            node_input_type_a == NodeInputOrOutputType.FP32
            and node_input_type_c == NodeInputOrOutputType.FP16
        )
        or
        # TODO(future PR): determine the actual dtype of node_c,
        # the current code only works because dequantize works with
        # multiple input dtypes.
        (
            node_input_type_a == NodeInputOrOutputType.FP32
            and node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8
        )
    ):
        dtype_cast_op = torch.dequantize
    elif (
        node_input_type_a == node_input_type_c
        and node_input_type_a != NodeInputOrOutputType.UNKNOWN
    ):
        dtype_cast_mod_cls = torch.nn.Identity
    elif (
        node_input_type_a == NodeInputOrOutputType.INT8
        and node_input_type_c == NodeInputOrOutputType.FP32
    ):
        # int8 shadows fp32, the dtype cast needs to quantize to int8
        # with the right qparams.
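        # The qparams are inferred from the node (or observer) feeding node_a.
        # If they cannot be inferred, no cast is configured here, and the
        # `assert dtype_cast_mod_cls` in the insertion logic below will fail;
        # callers such as create_a_shadows_b pre-check this case via
        # get_node_input_qparams and skip shadowing instead.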
        node_a_input_qparams = get_node_input_qparams(
            node_a, gm_a, node_type_to_io_type_map
        )
        if node_a_input_qparams is not None:
            dtype_cast_op = torch.quantize_per_tensor  # type: ignore[assignment]
            dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams
    elif (
        node_input_type_a == NodeInputOrOutputType.FP16
        and node_input_type_c == NodeInputOrOutputType.FP32
    ):
        dtype_cast_method = "to"
        dtype_cast_method_dtype = torch.float16
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} {node_c.format_node()} to "
            + f"{node_input_type_a} {node_a.format_node()} needs to be implemented"
        )

    if isinstance(prev_node_c, Node):
        new_dtype_cast_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        if dtype_cast_op:
            if dtype_cast_scale is not None and dtype_cast_zero_point is not None:
                return _insert_quantize_per_tensor_node(
                    prev_node_c,
                    node_a,
                    gm_b,
                    graph_c,
                    dtype_cast_scale,
                    dtype_cast_zero_point,
                    new_dtype_cast_name,
                )
            else:
                return graph_c.create_node(
                    "call_function",
                    dtype_cast_op,
                    (prev_node_c,),
                    {},
                    new_dtype_cast_name,
                )
        elif dtype_cast_method:
            return graph_c.create_node(
                "call_method",
                dtype_cast_method,
                (prev_node_c, dtype_cast_method_dtype),
                {},
                new_dtype_cast_name,
            )
        else:
            assert dtype_cast_mod_cls
            dtype_cast_mod = dtype_cast_mod_cls()
            setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
            return graph_c.create_node(
                "call_module",
                new_dtype_cast_name,
                (prev_node_c,),
                {},
                new_dtype_cast_name,
            )
    elif isinstance(prev_node_c, list):
        results = []
        for prev_node_c_inner in prev_node_c:
            new_dtype_cast_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
            if dtype_cast_op:
                # TODO(future PR): add handling for quantize_per_tensor
                new_dtype_cast_node = graph_c.create_node(
                    "call_function",
                    dtype_cast_op,
                    (prev_node_c_inner,),
                    {},
                    new_dtype_cast_name,
                )
                results.append(new_dtype_cast_node)
            else:
                assert dtype_cast_mod_cls
                dtype_cast_mod = dtype_cast_mod_cls()
                setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
                new_dtype_cast_node = graph_c.create_node(
                    "call_module",
                    new_dtype_cast_name,
                    (prev_node_c_inner,),
                    {},
                    new_dtype_cast_name,
                )
                results.append(new_dtype_cast_node)
        return results
    else:
        raise AssertionError(f"type {type(prev_node_c)} is not handled")


# TODO(future PR): look into using copy_node API instead
def _copy_node_from_a_to_c(
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
) -> Node:
    """
    Simple copy of node_a to graph_c.
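
    Note: only `get_attr` nodes and the `dequantize` / `to` call_methods
    are handled; tensor attributes are detached before being attached to gm_b.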
403 """ 404 if node_a.op == "get_attr": 405 node_a_copy_name = get_new_attr_name_with_prefix(node_a.name + "_shadow_copy_")( 406 gm_b 407 ) 408 node_a_obj = getattr_from_fqn(gm_a, node_a.target) # type: ignore[arg-type] 409 if torch.is_tensor(node_a_obj): 410 node_a_obj = node_a_obj.detach() 411 setattr(gm_b, node_a_copy_name, node_a_obj) 412 node_a_copy = graph_c.create_node( 413 node_a.op, node_a_copy_name, (), {}, node_a_copy_name 414 ) 415 return node_a_copy 416 elif node_a.op == "call_method": 417 assert node_a.target in ( 418 "dequantize", 419 "to", 420 ), f"target {node_a.target} is not implemented" 421 if node_a.target == "dequantize": 422 arg_copy = _copy_node_from_a_to_c( 423 get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c 424 ) # type: ignore[arg-type] 425 node_a_copy_name = get_new_attr_name_with_prefix( 426 node_a.name + "_shadow_copy_" 427 )(gm_b) 428 node_a_copy = graph_c.create_node( 429 node_a.op, node_a.target, (arg_copy,), {}, node_a_copy_name 430 ) 431 return node_a_copy 432 else: # to 433 arg_copy = _copy_node_from_a_to_c( 434 get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c 435 ) # type: ignore[arg-type] 436 node_a_copy_name = get_new_attr_name_with_prefix( 437 node_a.name + "_shadow_copy_" 438 )(gm_b) 439 node_a_copy = graph_c.create_node( 440 node_a.op, 441 node_a.target, 442 (arg_copy, get_normalized_nth_input(node_a, gm_a, 1)), 443 {}, 444 node_a_copy_name, 445 ) 446 return node_a_copy 447 448 else: 449 raise AssertionError( 450 f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented" 451 ) 452 453 454def _can_insert_copy_of_subgraph_a( 455 subgraph_a: NSSubgraph, 456 gm_a: GraphModule, 457 num_non_param_args_node_a: int, 458) -> bool: 459 """ 460 This function returns `False` if the input subgraph cannot be copied by 461 `_insert_copy_of_subgraph_a_after_input_node_c`. This usually means 462 that there is a corner case logic for which copy is not yet implemented. 463 """ 464 # populate the list of nodes we need to check 465 nodes = [] 466 cur_node = subgraph_a.end_node 467 while cur_node != subgraph_a.start_node: 468 nodes.append(cur_node) 469 cur_node = get_normalized_nth_input(cur_node, gm_a, 0) # type: ignore[assignment] 470 nodes.append(cur_node) 471 nodes.reverse() 472 473 def _can_insert(node_a_arg, gm_a): 474 if isinstance(node_a_arg, Node): 475 arg_a = return_first_non_observer_node(node_a_arg, gm_a) 476 if arg_a.op == "call_method": 477 return arg_a.target in ("dequantize", "to") 478 elif arg_a.op == "get_attr": 479 return True 480 else: 481 return False 482 elif isinstance(node_a_arg, (list, tuple)): 483 for el in node_a_arg: 484 if not isinstance(el, Node): 485 return False 486 return True 487 488 # For each node, check if we handle the copy behavior. This follows the 489 # logic in `_insert_copy_of_subgraph_a_after_input_node_c`. 
    for node_a in nodes:
        local_num_non_param_args_node_a = (
            num_non_param_args_node_a if node_a is nodes[0] else 1
        )

        norm_args_kwargs = node_a.normalized_arguments(
            gm_a, normalize_to_only_use_kwargs=True
        )
        if norm_args_kwargs is not None:
            norm_args, norm_kwargs = norm_args_kwargs
        else:
            norm_args, norm_kwargs = node_a.args, node_a.kwargs

        cur_idx = 0

        while cur_idx < len(norm_args):
            if cur_idx == 0:
                pass
            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
                pass
            else:
                if not _can_insert(norm_args[cur_idx], gm_a):
                    return False
            cur_idx += 1

        for kwarg_val in norm_kwargs.values():
            # stitch the inputs from base graph
            if cur_idx == 0:
                pass
            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
                pass
            else:
                if not _can_insert(kwarg_val, gm_a):
                    return False
            cur_idx += 1

    return True


def _insert_copy_of_subgraph_a_after_input_node_c(
    input_node_c: Union[Node, List[Node]],
    input_node_c_2: Optional[Union[Node, List[Node]]],
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Inserts a copy of the nodes of subgraph_a (from gm_a) into the graph of
    input_node_c, wiring input_node_c (and, if provided, input_node_c_2) in
    as the non-param inputs of the first copied node. Returns the last node
    of the inserted copy.
    """
    if isinstance(input_node_c, Node):
        graph_c = input_node_c.graph
    else:
        assert isinstance(input_node_c, list)
        graph_c = input_node_c[0].graph

    # create a sequential list of the subgraph's nodes from start to end,
    # because we need to add the nodes to graph C in non-reverse order
    nodes_of_a = [subgraph_a.end_node]
    cur_node = subgraph_a.end_node
    while cur_node != subgraph_a.start_node:
        cur_node = get_normalized_nth_input(cur_node, gm_a, 0)  # type: ignore[assignment]
        nodes_of_a.insert(0, cur_node)

    # go through nodes of a in order, and insert them into the graph of c
    # sequentially
    cur_node_a = nodes_of_a[0]
    cur_node_c = _insert_copy_of_node_a_after_input_node_c(
        input_node_c, input_node_c_2, cur_node_a, gm_a, gm_b, node_name_prefix
    )
    for cur_idx_a in range(1, len(nodes_of_a)):
        cur_node_a = nodes_of_a[cur_idx_a]
        prev_node_c = cur_node_c  # previous added node is the input to next node
        cur_node_c = _insert_copy_of_node_a_after_input_node_c(
            prev_node_c,
            # TODO(future PR): enable multiple inputs for nodes which are not at start of subgraph
            None,
            cur_node_a,
            gm_a,
            gm_b,
            node_name_prefix,
        )
    # return the last inserted node
    return cur_node_c


def _insert_copy_of_node_a_after_input_node_c(
    input_node_c: Union[Node, List[Node]],
    input_node_c_2: Optional[Union[Node, List[Node]]],
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Assume that node_a from graph_a has
    args (input, (input2)?, arg1, ...), and
    kwargs {kw0: kwarg0, ...}

    Note: input2 is optional. If it is None, we assume that the op
    has a single non-param input. If it is specified, we assume that the op
    has two non-param inputs.

    Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b,
    and creates the corresponding nodes in graph_c. Note: observers are ignored,
    so if an arg is an observer we navigate up until we find a non-observer parent.

    If node_a is a call_module, attaches the module that node_a points to
    onto gm_b and points the copied node at it.
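
    (The module object is attached to gm_b by reference rather than deep
    copied, so the shadow copy shares parameters with the original module
    in gm_a.)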

    Creates the copy of node_a in graph_c, with input as the first arg,
    and all other args and kwargs pointing to the copies of the objects
    in gm_b created above.

    An example in pictures:

    graph A:
    ========

    input -------------> node_a
                         / / /
    (input_2)?----------/ / /
                         / /
    weight -> weight_obs  /
                         /
    bias ----------------

    graph C (derived from B):
    =========================

    input_node_c --> node_a_copy
                     / / /
    (input_node_c_2)? / /
                     / /
    weight_copy ----/ /
                     /
    bias_copy ------/
    """
    if isinstance(input_node_c, Node):
        graph_c = input_node_c.graph
    else:
        assert isinstance(input_node_c, list)
        graph_c = input_node_c[0].graph

    norm_args_kwargs = node_a.normalized_arguments(
        gm_a, normalize_to_only_use_kwargs=True
    )
    if norm_args_kwargs is not None:
        norm_args, norm_kwargs = norm_args_kwargs
    else:
        norm_args, norm_kwargs = node_a.args, node_a.kwargs

    new_args = []
    new_kwargs = {}

    def _copy_arg(arg):
        # copy the other inputs from the other graph
        if isinstance(arg, Node):
            arg = return_first_non_observer_node(arg, gm_a)
            arg = _copy_node_from_a_to_c(arg, gm_a, gm_b, graph_c)
            return arg
        elif isinstance(arg, (int, float, torch.dtype)):
            return arg
        elif isinstance(arg, (list, tuple)):
            for el in arg:
                assert not isinstance(
                    el, Node
                ), "handling of Node inside list is not implemented"
            return arg
        else:
            raise AssertionError(
                f"handling for arg of type {type(arg)} is not implemented"
            )

    cur_idx = 0

    while cur_idx < len(norm_args):
        if cur_idx == 0:
            new_arg = input_node_c
        elif cur_idx == 1 and input_node_c_2 is not None:
            new_arg = input_node_c_2
        else:
            new_arg = _copy_arg(norm_args[cur_idx])
        new_args.append(new_arg)
        cur_idx += 1

    for kwarg_name, kwarg_val in norm_kwargs.items():
        # stitch the inputs from base graph
        if cur_idx == 0:
            new_kwargs[kwarg_name] = input_node_c
        elif cur_idx == 1 and input_node_c_2 is not None:
            new_kwargs[kwarg_name] = input_node_c_2
        else:
            new_kwargs[kwarg_name] = _copy_arg(kwarg_val)
        cur_idx += 1

    new_args = tuple(new_args)  # type: ignore[assignment]

    node_a_shadows_c_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)

    if node_a.op == "call_module":
        # if target is a module, we point to the module from gm_b
        new_mod_copy_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        # fetch the corresponding module from gm_a
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        setattr(gm_b, new_mod_copy_name, mod_a)
        node_a_shadows_c = graph_c.create_node(
            node_a.op, new_mod_copy_name, new_args, new_kwargs, node_a_shadows_c_name  # type: ignore[arg-type]
        )
        return node_a_shadows_c
    else:
        assert node_a.op in ("call_function", "call_method")
        node_a_shadows_c = graph_c.create_node(
            node_a.op, node_a.target, new_args, new_kwargs, node_a_shadows_c_name  # type: ignore[arg-type]
        )
        return node_a_shadows_c


def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
    logger_cls: Callable,
    should_log_inputs: bool,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B. For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_subgraph_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b
    """

    if node_type_to_io_type_map is None:
        node_type_to_io_type_map = get_node_type_to_io_type_map()

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    start_node_b_to_matched_subgraph_a_and_name = {}
    end_node_b_to_matched_subgraph_a_and_name = {}
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
        start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = (
            subgraph_a,
            match_name,
            ref_node_type_a,
            ref_node_type_b,
        )
        end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = (
            subgraph_a,
            match_name,
            ref_node_type_a,
            ref_node_type_b,
        )

    for node_b in gm_b.graph.nodes:
        if node_b.op == "output":
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        # calculate the flags to determine what to do with this node
        node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name
        node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name

        if node_b_is_start_node or node_b_is_end_node:
            if node_b_is_start_node:
                (
                    subgraph_a,
                    ref_name,
                    ref_node_type_a,
                    ref_node_type_b,
                ) = start_node_b_to_matched_subgraph_a_and_name[node_b]
            else:
                assert node_b_is_end_node
                (
                    subgraph_a,
                    ref_name,
                    ref_node_type_a,
                    ref_node_type_b,
                ) = end_node_b_to_matched_subgraph_a_and_name[node_b]

            all_op_types_support_shadowing = op_type_supports_shadowing(
                subgraph_a.start_node
            ) and op_type_supports_shadowing(node_b)
            if not all_op_types_support_shadowing:
                print(
                    f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
                    + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
                    + ", unsupported"
                )
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            # For both start_node and end_node verify that we know how to do
            # the dtype cast. If we do not, skip.
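            # Note: each of the checks below degrades gracefully: on failure,
            # node_b is copied as-is (without a shadow) and we continue with
            # the next node instead of raising.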
            (
                node_input_type_a,
                node_output_type_a,
            ) = get_node_first_input_and_output_type(
                subgraph_a.start_node, gm_a, logger_cls, node_type_to_io_type_map
            )
            (
                node_input_type_b,
                node_output_type_b,
            ) = get_node_first_input_and_output_type(
                node_b, gm_b, logger_cls, node_type_to_io_type_map
            )
            node_io_types_known_a_and_b = (
                node_input_type_a != NodeInputOrOutputType.UNKNOWN
                and node_output_type_a != NodeInputOrOutputType.UNKNOWN
                and node_input_type_b != NodeInputOrOutputType.UNKNOWN
                and node_output_type_b != NodeInputOrOutputType.UNKNOWN
            )
            if not node_io_types_known_a_and_b:
                print(
                    f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
                    + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
                    + ", unknown dtype cast"
                )
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            # If we are shadowing from fp32 to int8, we need to insert
            # a quantize_per_tensor call with qparams from the previous node.
            # Only do this if we are able to infer these qparams from the graph.
            if (
                node_input_type_a == NodeInputOrOutputType.INT8
                and node_input_type_b == NodeInputOrOutputType.FP32
            ):
                node_a_input_qparams = get_node_input_qparams(
                    subgraph_a.start_node, gm_a, node_type_to_io_type_map
                )
                if not node_a_input_qparams:
                    print(
                        f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
                        + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
                        + ", unknown input qparams"
                    )
                    env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                    continue

            num_non_param_args_node_a = get_number_of_non_param_args(
                subgraph_a.start_node, gm_a
            )
            if not _can_insert_copy_of_subgraph_a(
                subgraph_a, gm_a, num_non_param_args_node_a
            ):
                print(
                    f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
                    + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
                    + ", unhandled logic in subgraph copy"
                )
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
            fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b)  # type: ignore[possibly-undefined]

            if node_b_is_start_node:
                # if necessary, log the input of node_c
                if should_log_inputs:
                    prev_node_b = get_normalized_nth_input(node_b, gm_b, 0)
                    if isinstance(prev_node_b, Node):
                        prev_node_c = env_c[prev_node_b.name]
                        env_c[prev_node_c.name] = _insert_logger_after_node(
                            prev_node_c,
                            gm_b,
                            logger_cls,
                            "_ns_logger_b_inp_",
                            node_b.name,
                            name_b,
                            ref_name,
                            ref_node_type_b,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0,
                            index_of_arg=0,
                            fqn=fqn_base_b,
                        )
                    elif isinstance(prev_node_b, list):
                        # first, save the prev_node instances, because they
                        # will be overwritten in the env after the first logger
                        # is added
                        prev_node_c_list = [env_c[arg.name] for arg in prev_node_b]

                        for arg_idx, arg in enumerate(prev_node_b):
                            prev_node_c = prev_node_c_list[arg_idx]
                            env_c[prev_node_c.name] = _insert_logger_after_node(
                                prev_node_c,
                                gm_b,
                                logger_cls,
                                "_ns_logger_b_inp_",
                                node_b.name,
                                name_b,
                                ref_name,
                                ref_node_type_b,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx,
                                index_of_arg=0,
                                fqn=fqn_base_b,
                            )
                    else:
                        # logging of inputs which are neither Nodes nor lists
                        # is not supported yet
                        raise AssertionError(
                            f"type {type(prev_node_b)} is not handled yet"
                        )
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)?

            # Note: this if statement is always True, spelling it out to clarify code
            # intent.
            if node_b_is_start_node or node_b_is_end_node:
                # ensure env_c is populated with base node
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                node_c = env_c[node_b.name]

                # after this point,
                #
                # node_a is the original node from graph_a, with parent module gm_a
                # node_b is the original node from graph_b, with parent module gm_b
                # node_c is the copy of node_b in graph_c
                #
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if node_b_is_start_node:
                    # cast dtype from the dtype of node_c's input to the dtype of
                    # node_a's input (dequant, etc)
                    # prev_node_c = node_c.args[0]
                    prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)  # type: ignore[possibly-undefined]
                    if should_log_inputs:
                        # skip the input logger when inserting a dtype cast
                        if isinstance(prev_node_c, Node):
                            prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
                        elif isinstance(prev_node_c, list):
                            prev_node_c = [
                                get_normalized_nth_input(arg, gm_b, 0)
                                for arg in prev_node_c
                            ]
                    dtype_cast_node = _insert_dtype_cast_after_node(
                        subgraph_a.start_node,
                        node_c,
                        prev_node_c,
                        gm_a,
                        gm_b,
                        graph_c,
                        node_b.name + "_dtype_cast_",
                        logger_cls,
                        node_type_to_io_type_map,
                    )
                    # note: not inserting to env_c because all nodes which use the dtype
                    # casts are copied from graph_a
                    #
                    # subgraph so far:
                    #
                    #       (dtype_cast_node)+
                    #      /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                    # if input logging is enabled, log the input to the subgraph
                    if should_log_inputs:
                        # TODO: explain this
                        ref_node_name = ""
                        if isinstance(dtype_cast_node, Node):
                            dtype_cast_node = _insert_logger_after_node(
                                dtype_cast_node,
                                gm_b,
                                logger_cls,
                                "_ns_logger_a_inp_",
                                ref_node_name,
                                name_a,
                                ref_name,
                                ref_node_type_a,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=0,
                                index_of_arg=0,
                                fqn=fqn_base_a,
                            )
                            input_logger: Union[Node, List[Node]] = dtype_cast_node
                        else:
                            assert isinstance(dtype_cast_node, list)
                            new_loggers = []
                            for dtype_cast_idx, dtype_cast_node_inner in enumerate(
                                dtype_cast_node
                            ):
                                dtype_cast_logger = _insert_logger_after_node(
                                    dtype_cast_node_inner,
                                    gm_b,
                                    logger_cls,
                                    "_ns_logger_a_inp_",
                                    ref_node_name,
                                    name_a,
                                    ref_name,
                                    ref_node_type_a,
                                    NSSingleResultValuesType.NODE_INPUT.value,
                                    index_within_arg=dtype_cast_idx,
                                    index_of_arg=0,
                                    fqn=fqn_base_a,
                                )
                                new_loggers.append(dtype_cast_logger)
                            dtype_cast_node = new_loggers
                            input_logger = dtype_cast_node
                    # subgraph so far:
                    #
                    #       (dtype_cast_node)+ -> (logger_a_input)?
                    #      /
                    # prev_node_c -> (logger_c_input)? -> node_start_c

                    # hook up the new mod_a copy to be in the graph, receiving the
                    # same inputs as mod_b does, with dtype cast to match a
                    # Some ops, such as LSTMs, have two non-param inputs. If we have
                    # such an op, pass the second param as well. Note: dtype casting
                    # for the second param is not implemented yet, it can be added
                    # later if there is a use case.
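                    # (For example, torch.nn.LSTM takes both an input tensor
                    # and a hidden-state tuple as non-param inputs.)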
                    node_c_second_non_param_arg = None
                    num_non_param_args_node_a = get_number_of_non_param_args(
                        subgraph_a.start_node, gm_a
                    )
                    if num_non_param_args_node_a == 2:
                        # node_c_second_non_param_arg = node_c.args[1]
                        node_c_second_non_param_arg = get_normalized_nth_input(
                            node_c, gm_b, 1
                        )
                    node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                        dtype_cast_node,
                        node_c_second_non_param_arg,
                        subgraph_a,
                        gm_a,
                        gm_b,
                        node_c.name + "_shadow_copy_",
                    )
                    env_c[node_a_shadows_c.name] = node_a_shadows_c
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
                    #      /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                    if should_log_inputs:
                        # When we created the input logger, we left the ref_node_name
                        # as an empty string, because the subgraph copy did not exist
                        # yet. Now that the subgraph copy exists, we modify this name
                        # to its true value.
                        # Note: the alternative to this is to create the input logger
                        # after creating the subgraph, which is slightly more
                        # complicated. This is the lesser of two evils.
                        # input_logger = env_c[dtype_cast_node.name]
                        # Find the first node in the subgraph
                        cur_node = node_a_shadows_c
                        while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger:  # type: ignore[possibly-undefined]
                            cur_node = get_normalized_nth_input(cur_node, gm_b, 0)  # type: ignore[assignment]
                        if isinstance(input_logger, Node):
                            input_logger_mod = getattr(gm_b, input_logger.name)
                            input_logger_mod.ref_node_name = cur_node.name
                        else:
                            assert isinstance(input_logger, list)
                            for input_logger_inner in input_logger:
                                input_logger_mod = getattr(gm_b, input_logger_inner.name)
                                input_logger_mod.ref_node_name = cur_node.name

                    # hook up a logger to the mod_a copy
                    env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                        env_c[node_a_shadows_c.name],
                        gm_b,
                        logger_cls,
                        "_ns_logger_a_",
                        node_a_shadows_c.name,
                        name_a,
                        ref_name,
                        ref_node_type_a,
                        NSSingleResultValuesType.NODE_OUTPUT.value,
                        index_within_arg=0,
                        index_of_arg=0,
                        fqn=fqn_base_a,
                    )
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                    #      /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if node_b_is_end_node:
                    # hook up a logger to the mod_b copy
                    env_c[node_b.name] = _insert_logger_after_node(
                        env_c[node_b.name],
                        gm_b,
                        logger_cls,
                        "_ns_logger_b_",
                        node_b.name,
                        name_b,
                        ref_name,
                        ref_node_type_b,
                        NSSingleResultValuesType.NODE_OUTPUT.value,
                        index_within_arg=0,
                        index_of_arg=0,
                        fqn=fqn_base_b,
                    )
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                    #      /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
                    #
                    # Note: node_start_c may be the same node as node_end_c, or they
                    # may have nodes in between.

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
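

# Example usage of create_a_shadows_b (an illustrative sketch, not part of
# this module; assumes gm_a / gm_b are prepared fp32 / int8 GraphModules,
# `matched_pairs` comes from
# torch.ao.ns.fx.graph_matcher.get_matching_subgraph_pairs(gm_a, gm_b), and
# OutputLogger is the logger class from torch.ao.ns._numeric_suite_fx):
#
#   gm_a_shadows_b = create_a_shadows_b(
#       name_a="fp32",
#       gm_a=gm_a,
#       name_b="int8",
#       gm_b=gm_b,
#       matched_subgraph_pairs=matched_pairs,
#       logger_cls=OutputLogger,
#       should_log_inputs=False,
#   )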