# mypy: allow-untyped-defs
import argparse
import copy
from collections import defaultdict
from dataclasses import dataclass
from typing import NamedTuple, Sequence, Iterable, Any, List, Dict, Optional, Tuple
import logging

import torch
from torch.fx.passes.graph_manipulation import get_size_of_node
from torch.fx.node import map_arg
from torch.fx._compatibility import compatibility

from .operator_support import (
    get_node_target,
    OperatorSupportBase,
)
from .graph_drawer import FxGraphDrawer
from .shape_prop import ShapeProp
from .split_utils import split_by_tags
from .tools_common import (
    FxNetAccFusionsFinder,
    CALLABLE_NODE_OPS,
    Tensors,
    NodeList,
    NodeSet,
    is_node_output_tensor,
)


__all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules']
_LOGGER = logging.getLogger(__name__)

DEFAULT_MIN_ACC_MODULE_SIZE = 1
DEFAULT_SKIP_FUSION = False
DEFAULT_ALLOW_NON_TENSOR = False


class _SplitterSettingBase:
    def __init__(
        self,
        min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
        skip_fusion=DEFAULT_SKIP_FUSION,
        allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR,
        max_acc_splits: int = -1,
    ):
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--min-acc-module-size",
            "--min_acc_module_size",
            required=False,
            type=int,
            help="Minimum size limit of an accelerator subgraph.",
        )
        parser.add_argument(
            "--max-acc-splits",
            "--max_acc_splits",
            required=False,
            type=int,
            help="Enforce a maximum number of split subgraphs.",
        )
        parser.add_argument(
            "--skip-fusion",
            "--skip_fusion",
            default=False,
            action="store_true",
            help="If true, no fusion groups are built. Fusion groups are used to "
            "enforce that no non-tensor data flows between submodules. If this "
            "constraint is not needed, setting this to false is recommended as it "
            "can reduce overhead.",
        )
        parser.add_argument(
            "--allow-non-tensor",
            "--allow_non_tensor",
            default=False,
            action="store_true",
            help="Some backends do not allow non-tensor data to flow between them "
            "and the cpu. Therefore, if a node is supported by the accelerator but "
            "has non-tensor inputs from or outputs to a cpu node, we would want to "
            "consider it a cpu node during splitting. For backends that do not care "
            "about non-tensor data flow, set this option to true to disable the "
            "functionality that prevents non-tensor data flow.",
        )
        args, unknown = parser.parse_known_args()

        self.min_acc_module_size: int = args.min_acc_module_size if args.min_acc_module_size else min_acc_module_size
        self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion
        self.allow_non_tensor: bool = args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
        self.max_acc_splits: int = max_acc_splits
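
# Illustrative sketch (not part of this module's API): constructing the settings
# programmatically. Keyword arguments act as defaults; matching command-line
# flags, when present in sys.argv, take precedence for the options parsed above.
#
#     settings = _SplitterSettingBase(
#         min_acc_module_size=3,   # drop acc subgraphs with fewer than 3 nodes
#         skip_fusion=False,       # keep fusion groups for non-tensor IO nodes
#         allow_non_tensor=False,  # keep non-tensor producers on the cpu side
#     )
#
# Running the program with `--min-acc-module-size 5` would override the
# keyword value of 3 above.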


@compatibility(is_backward_compatible=False)
class FxNetAccNodesFinder:
    """
    Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor
    input/output to cpu nodes, to prevent non-tensor data flow between backends and cpu.

    I.e. if we have a chain:

        ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1

    where every ACC node produces non-tensor output, then they all should be treated as CPU nodes.

    This behavior can be turned off by passing allow_non_tensor=True.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        operator_support: OperatorSupportBase,
        allow_non_tensor: bool,
    ):
        self.module = module
        self.operator_support = operator_support
        self.allow_non_tensor = allow_non_tensor
        self.acc_nodes: NodeSet = set()

    def reduce_acc_nodes_non_tensor_input_helper(
        self, cpu_worklist: NodeList
    ):
        """
        Transitively excludes nodes from the ACC supported set.
        For every node in the worklist:
        - removes its downstream ACC nodes from the ACC supported set,
        - if any downstream ACC node produces non-tensor output,
          then it gets added to the worklist.
        """
        while cpu_worklist:
            node = cpu_worklist.pop(0)

            for user in node.users:
                if user in self.acc_nodes:
                    self.acc_nodes.remove(user)
                    if not is_node_output_tensor(user):
                        cpu_worklist.append(user)

    def reduce_acc_nodes_non_tensor_input(self):
        """
        Excludes nodes from the ACC supported set that have direct
        upstream CPU nodes producing non-tensor outputs.
        """
        non_tensor_cpu_nodes: NodeList = []

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue
            if node in self.acc_nodes:
                continue
            if is_node_output_tensor(node):
                continue
            non_tensor_cpu_nodes.append(node)

        self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes)

    def reduce_acc_nodes_non_tensor_output(self):
        """
        Excludes nodes from the ACC supported set that produce non-tensor
        outputs and have downstream CPU nodes.
        """
        while True:
            new_cpu_nodes: NodeList = []

            for acc_node in self.acc_nodes:
                if is_node_output_tensor(acc_node):
                    continue
                for user in acc_node.users:
                    if user not in self.acc_nodes:
                        new_cpu_nodes.append(acc_node)
                        break

            if not new_cpu_nodes:
                break

            for new_cpu_node in new_cpu_nodes:
                self.acc_nodes.remove(new_cpu_node)

            self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes)

    def __call__(self) -> NodeSet:
        submodules = dict(self.module.named_modules())
        self.acc_nodes = {
            n
            for n in self.module.graph.nodes
            if n.op in CALLABLE_NODE_OPS
            and self.operator_support.is_node_supported(submodules, n)
        }

        if not self.allow_non_tensor:
            self.reduce_acc_nodes_non_tensor_input()
            self.reduce_acc_nodes_non_tensor_output()

        return self.acc_nodes
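
# Illustrative sketch (the `MyOperatorSupport` class is hypothetical): how the
# finder above is typically driven. With allow_non_tensor=False, an acc node
# whose non-tensor output eventually feeds a cpu node is demoted to cpu, along
# with the non-tensor chain leading into it.
#
#     finder = FxNetAccNodesFinder(traced_gm, MyOperatorSupport(), allow_non_tensor=False)
#     acc_nodes = finder()  # NodeSet of nodes safe to place on the accelerator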


@compatibility(is_backward_compatible=False)
class FxNetSplitterInternalError(Exception):
    pass


@compatibility(is_backward_compatible=False)
@dataclass
class Subgraph:
    is_acc: bool
    nodes: NodeList
    device_ordinal: Optional[int] = None


@compatibility(is_backward_compatible=False)
class SplitResult(NamedTuple):
    """
    Stores the results of the splitter.

    Attributes:
        split_module: root module after splitting.
        submodule_inputs: a dict that maps submodule name to its inputs.
        non_acc_submodule_prefix: the prefix for non-acc submodules. For
            acc submodules the prefix is always "_run_on_acc_".
    """

    split_module: torch.fx.GraphModule
    submodule_inputs: Dict[str, Any]
    non_acc_submodule_prefix: str


@compatibility(is_backward_compatible=False)
def generate_inputs_for_submodules(
    model: torch.nn.Module,
    inputs: Sequence[Any],
    target_submodules: Iterable[str],
    deepcopy: bool = False,
) -> Dict[str, Any]:
    """
    Generate inputs for the target submodules in the given model. Note that if two
    submodules refer to the same object, this function doesn't work.

    Args:
        model: root model.
        inputs: inputs to the root model.
        target_submodules: submodules that we want to generate inputs for.
        deepcopy: if True, store a deep copy of the captured inputs instead of live references.

    Returns:
        A dict that maps from submodule name to its inputs.
    """

    handles = []
    results = {}
    submodule_to_names = {mod: name for name, mod in model.named_modules()}

    def pre_forward(module, module_inputs):
        results[submodule_to_names[module]] = copy.deepcopy(module_inputs) if deepcopy else module_inputs

    for name, mod in model.named_modules():
        if name in target_submodules:
            handles.append(mod.register_forward_pre_hook(pre_forward))

    def clean_up_handles():
        for h in handles:
            h.remove()

    try:
        with torch.no_grad():
            model(*inputs)
    except Exception as e:
        clean_up_handles()
        raise e

    clean_up_handles()
    return results
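
# Illustrative sketch: capturing the inputs seen by two submodules during a
# single forward pass. The submodule names here ("_run_on_acc_0",
# "_run_on_cpu_1") are hypothetical; pass whatever names `named_modules()`
# reports for your split module.
#
#     sub_inputs = generate_inputs_for_submodules(
#         split_gm,
#         sample_inputs,
#         ["_run_on_acc_0", "_run_on_cpu_1"],
#         deepcopy=True,  # snapshot the inputs instead of keeping live references
#     )
#     acc_inputs = sub_inputs["_run_on_acc_0"]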


class _SplitterBase:
    """
    Splits a GraphModule into sub-GraphModules for execution on the CPU or the accelerator.
    The output is a GraphModule with supported and unsupported operators grouped into as few
    sub-GraphModules as possible. Assumes that only "call_module", "call_function" and
    "call_method" nodes from FX IR can potentially be executed on the accelerator.

    Given the following graph:
              ==> b ==>
            //         \\
           a             d
            \\         //
              ==> c ==>

    class SimpleModule(torch.nn.Module):
        def forward(self, a):
            b = torch.sin(a)
            c = torch.cos(a)
            d = b + c
            return d

    and providing an "operator_support" that indicates that 'b' and 'c' can be executed
    on the accelerator, we will get the following split result:

    main:
    def forward(self, a):
        run_on_acc_0_0 = self._run_on_acc_0_0(a)
        getitem = run_on_acc_0_0[0]
        getitem_1 = run_on_acc_0_0[1]
        run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1)
        return run_on_cpu_1_1

    _run_on_acc_0_0:
    def forward(self, a):
        sin_1 = torch.sin(a)
        cos_1 = torch.cos(a)
        return (sin_1, cos_1)

    _run_on_cpu_1_1:
    def forward(self, sin_1, cos_1):
        add_1 = sin_1 + cos_1
        return add_1
    """

    # PCIe bandwidth for the backend, defaults to 100 GB/s
    PCIe_BW = 100 * 2 ** 30

    def __init__(
        self,
        module: torch.fx.GraphModule,
        sample_input: Sequence[Any],
        operator_support: OperatorSupportBase,
        settings: _SplitterSettingBase,
        non_acc_submodule_name: str = "_run_on_cpu_",
        return_tuple: bool = False,
    ):
        """
        Preprocesses the graph before splitting:
        - finds nodes supported by ACC,
        - finds fusion groups for ACC nodes having non-tensor IO,
        - builds a graph of direct dependencies,
        - builds a map of fused nodes to their fusions.
        As a result we get self.acc_nodes, self.deps and self.fusions.
        """
        assert isinstance(module, torch.fx.GraphModule)

        self.module = module
        ShapeProp(self.module).propagate(*sample_input)

        self.settings = settings
        self.operator_support = operator_support
        self.sample_input = sample_input
        self.acc_nodes = FxNetAccNodesFinder(self.module, self.operator_support, self.settings.allow_non_tensor)()

        if self.settings.skip_fusion:
            self.fusions = {}
        else:
            self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)()

        # Modify deps to add more deps for fused nodes
        self.deps = self.find_deps()
        self.update_deps_for_fusions()

        self.non_acc_submodule_name = non_acc_submodule_name
        self._node_submodule_map: Dict[str, str] = {}
        self._return_tuple = return_tuple

        self.tags: List[str] = []

    # ===============================================================
    # Helpers for ctor and initial state
    # ===============================================================

    def get_node_submodule_map(self) -> Dict[str, str]:
        """Returns a map from node name to submodule name, e.g.
        node:
            main_module_impl_impl_over_arch_unary_multiple_embedding
            _pooling_embedding_pooling_sparse_entity_equivalence_key
            _proxy_embedding_bag
        maps to submodule name of: _run_on_acc_1
        """
        return self._node_submodule_map

    def find_deps(self) -> Dict[torch.fx.Node, NodeSet]:
        """
        Builds a graph of node dependencies. Leaf nodes don't have any
        dependencies and the "output" node doesn't have nodes depending on it.

        The resulting graph has only direct dependencies, i.e. there are no
        transitive dependencies.
        """
        deps: Dict[torch.fx.Node, NodeSet] = defaultdict(set)
        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            for user in node.users:
                if user.op != "output":
                    deps[user].add(node)
        return deps

    def update_deps_for_fusions(self):
        """
        Updates the graph of dependencies so that:
        - nodes from the same fusion depend on the same set of outer nodes,
        - outer nodes depending on a fusion depend on all nodes in that fusion.
        """
        for node in self.fusions:
            fusion = self.fusions[node]
            for fused_neighbor in fusion:
                self.deps[node].update(self.deps[fused_neighbor] - fusion)

                for user in fused_neighbor.users:
                    if user not in fusion:
                        self.deps[user].add(node)
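
    # Worked example (based on the SimpleModule graph in the class docstring):
    # only nodes whose op is in CALLABLE_NODE_OPS are tracked, so the
    # placeholder `a` and the `output` node never appear. find_deps() would
    # conceptually yield
    #
    #     deps == {d: {b, c}}    # b and c have no recorded dependencies
    #
    # which is exactly the "direct dependencies only" shape that
    # put_nodes_into_subgraphs() later consumes.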

    # ===============================================================
    # Helpers for preview
    # ===============================================================

    def _lower_model_to_backend(
        self, mod: torch.fx.GraphModule, inputs: Tensors
    ) -> torch.nn.Module:
        """
        Lower the model to a backend.
        """

        return mod

    def _find_culprit(
        self, mod: torch.fx.GraphModule, inputs: Tensors
    ) -> str:
        """
        When an error occurs during lowering or running the lowered mod, we use this
        function to find the culprits in `mod` that cause the error.
        """

        return "Unable to find a culprit because the _find_culprit() function is not implemented."

    def _draw_graph_based_on_node_support(
        self, mod: torch.fx.GraphModule, supported_nodes: NodeList
    ):
        color_map = {
            "default": "AliceBlue",
            "supported": "chartreuse1",
            "unsupported": "crimson",
        }

        class CustomDrawer(FxGraphDrawer):
            def _get_node_style(self, node):
                template = super()._get_node_style(node)
                if node in supported_nodes:
                    template["fillcolor"] = color_map["supported"]
                elif node.op in CALLABLE_NODE_OPS:
                    template["fillcolor"] = color_map["unsupported"]
                else:
                    template["fillcolor"] = color_map["default"]

                return template

        drawer = CustomDrawer(mod, "node_support", ignore_getattr=True)
        dot_graph = drawer.get_main_dot_graph()
        # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`.
        dot_graph.write_raw("node_support.dot")

    def node_support_preview(self, dump_graph: bool = False):
        submodules = dict(self.module.named_modules())

        supported_nodes: NodeList = []
        supported_node_types = defaultdict(set)
        unsupported_node_types = defaultdict(set)

        def get_dtype(arg):
            tensor_meta = arg.meta.get("tensor_meta")
            return getattr(tensor_meta, "dtype", None)

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            target = get_node_target(submodules, node)

            # Store the dtype of each arg in node.args. If an arg doesn't have a dtype,
            # i.e. it is not a tensor, we store None.
            arg_dtypes = [
                get_dtype(arg) if isinstance(arg, torch.fx.Node) else None
                for arg in node.args
            ]

            # Find the index right after the last non-None element; it is 0 if
            # all elements are None.
            last_index = len(arg_dtypes) - next(
                (
                    i
                    for i, dtype in enumerate(reversed(arg_dtypes))
                    if dtype is not None
                ),
                len(arg_dtypes),
            )

            # Strip None elements at the end.
            arg_dtypes_tuple = tuple(arg_dtypes[:last_index])
            kwarg_dtypes_tuple = tuple(
                (k, get_dtype(arg))
                for k, arg in node.kwargs.items()
                if isinstance(arg, torch.fx.Node)
            )

            if self.operator_support.is_node_supported(submodules, node):
                supported_nodes.append(node)
                supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
            else:
                unsupported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))

        if dump_graph:
            self._draw_graph_based_on_node_support(self.module, supported_nodes)

        reports = "\nSupported node types in the model:\n"
        for t, dtypes in supported_node_types.items():
            for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
                reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"

        reports += "\nUnsupported node types in the model:\n"
        for t, dtypes in unsupported_node_types.items():
            for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
                reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"

        print(reports)

        # Return the reports for testing purposes
        return reports
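
    # Illustrative example of the dtype bookkeeping above: for a node whose
    # positional args map to dtypes
    #
    #     [torch.float32, None, torch.int64, None, None]
    #
    # last_index is 3, so arg_dtypes_tuple becomes
    # (torch.float32, None, torch.int64): trailing Nones (non-tensor args) are
    # stripped, while interior Nones are kept to preserve argument positions.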

    def split_preview(self, dump_graph: bool = False):
        reports = ""
        subgraphs = self.put_nodes_into_subgraphs()
        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
        reports += f"Before removing small acc subgraphs, a total of {len(subgraphs)} subgraphs were created:"
        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
        reports += f"After removing small acc subgraphs, a total of {len(subgraphs)} subgraphs remain:"
        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

        for i, subgraph in enumerate(subgraphs):
            reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"{self.non_acc_submodule_name}{i}: "
            reports += f"{len(subgraph.nodes)} node(s)\n"

        self.tag(subgraphs)
        split_mod = self.split(remove_tag=True)
        split_mod.eval()

        if dump_graph:
            drawer = FxGraphDrawer(
                split_mod, "preview", ignore_getattr=True
            )
            dot_graphs = drawer.get_all_dot_graphs()
            for name, dot_graph in dot_graphs.items():
                # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`.
                dot_graph.write_raw(f"{name}.dot")

        max_qps: float = self.PCIe_BW
        bottleneck_module = ""

        for node in split_mod.graph.nodes:
            if node.op == "call_module" and "acc" in node.target:
                reports += f"\nProcessing acc submodule {node.target}\n"

                submod = getattr(split_mod, node.target)

                def get_submod_inputs(main_mod, submod, example_inputs):
                    sub_inputs = None

                    def get_inputs(self, inputs):
                        nonlocal sub_inputs
                        sub_inputs = inputs

                    handle = submod.register_forward_pre_hook(get_inputs)
                    main_mod(*example_inputs)
                    handle.remove()
                    return sub_inputs

                submod_inputs = get_submod_inputs(
                    split_mod, submod, self.sample_input
                )
                ShapeProp(submod).propagate(*submod_inputs)

                total_input_bytes = 0
                total_output_bytes = 0

                reports += "Checking inputs...\n"
                for n in submod.graph.nodes:
                    if n.op == "placeholder":
                        if not is_node_output_tensor(n):
                            reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
                        else:
                            total_input_bytes += get_size_of_node(submod, n)[0]
                    if n.op == "output":
                        output_node = n

                reports += "Checking outputs...\n"

                def get_bytes(node: torch.fx.Node):
                    nonlocal total_output_bytes
                    nonlocal reports
                    if not is_node_output_tensor(node):
                        reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n"
                    else:
                        total_output_bytes += get_size_of_node(submod, node)[0]

                map_arg(output_node.args, get_bytes)  # type: ignore[possibly-undefined]
                qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes)
                reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
                reports += f" theoretical max qps (bounded by PCIe bandwidth) for this submodule is {qps}.\n"

                if qps < max_qps:
                    max_qps = qps
                    bottleneck_module = node.target

                try:
                    lowered_submod = self._lower_model_to_backend(submod, submod_inputs)
                except RuntimeError:
                    reports += "Ran into an error during lowering!\n"
                    reports += self._find_culprit(submod, submod_inputs)
                    continue

                try:
                    lowered_submod(*submod_inputs)
                except RuntimeError:
                    reports += "Ran into an error during inference!\n"
                    reports += self._find_culprit(submod, submod_inputs)
                else:
                    reports += "Lowering and running succeeded!\n"

        reports += f"\nTheoretical max qps (bounded by PCIe bandwidth) for this model is {max_qps},"
        reports += f" bottleneck is submodule {bottleneck_module}."
        print(reports)

        # Return the reports for testing purposes
        return reports
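
    # Worked example of the qps estimate above: with the default
    # PCIe_BW = 100 * 2**30 bytes/s and a submodule whose inputs total 64 MiB
    # and outputs total 16 MiB, the bound is
    #
    #     qps = (100 * 2**30) / max(64 * 2**20, 16 * 2**20) = 1600 queries/s
    #
    # The submodule with the smallest such bound is reported as the bottleneck.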

    # ===============================================================
    # Helpers for extend_acc_subgraph() method
    # ===============================================================

    def find_reverse_deps(
        self, tag_id: Optional[int] = None
    ) -> Dict[torch.fx.Node, NodeSet]:
        """
        Builds reversed topological node dependencies. If tag_id is specified,
        we ignore nodes that are in a later subgraph, i.e. nodes that have a greater tag_id.
        """
        result: Dict[torch.fx.Node, NodeSet] = defaultdict(set)

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            for user in node.users:
                if user.op not in CALLABLE_NODE_OPS:
                    continue

                if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id):
                    result[node].add(user)

        return result

    def update_reverse_deps_for_fusions(
        self, deps: Dict[torch.fx.Node, NodeSet]
    ):
        processed_node = set()

        for node, fusion in self.fusions.items():
            if node in processed_node:
                continue

            new_dep = set()

            # Create a new dependency set which includes all the
            # dependencies of the nodes in the fusion group
            for n in fusion:
                new_dep.update(deps[n])

            # Exclude nodes in the fusion
            new_dep.difference_update(fusion)

            # Update dependencies
            for n in fusion:
                deps[n] = new_dep

                for arg in n.all_input_nodes:
                    if arg not in fusion:
                        deps[arg].update(fusion)

                processed_node.add(n)

    def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet:
        """
        Finds parent nodes of the `tag` subgraph.

        Traverses the inputs of nodes in the subgraph; if an input doesn't belong to the
        subgraph and is not a placeholder, we consider it a parent node of the subgraph.
        """
        parent_nodes = set()

        for node in self.module.graph.nodes:
            if node.op in CALLABLE_NODE_OPS and node.tag == tag:
                for arg in node.all_input_nodes:
                    if arg.op in CALLABLE_NODE_OPS and arg.tag != tag:
                        parent_nodes.add(arg)

        return parent_nodes
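
    # Illustrative example: for a chain x -> y -> z (all callable nodes),
    # find_reverse_deps() maps each node to its users, conceptually
    #
    #     {x: {y}, y: {z}}
    #
    # and, when tag_id is given, users whose tag ends in a number >= tag_id are
    # dropped, so extend_acc_subgraph() below never pulls a node across a later
    # subgraph boundary.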

    def extend_acc_subgraph(self, tag: str):
        """
        Extends the acc subgraph with `tag`, walking in the reversed topological direction.
        """
        # Dict that maps a node to its users, ignoring users that
        # are in subgraphs with a greater tag
        deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1]))
        self.update_reverse_deps_for_fusions(deps)

        # Parent nodes of the subgraph
        parent_nodes = self.find_parent_nodes_of_subgraph(tag)

        visited_nodes: NodeSet = set()

        while parent_nodes:
            node = None

            # Find an acc node that depends on visited nodes only
            for n in parent_nodes:
                if deps[n] <= visited_nodes and n in self.acc_nodes:
                    node = n
                    break

            if node is None:
                break

            # Put the node into the `tag` subgraph
            node.tag = tag  # type: ignore[attr-defined]
            parent_nodes.remove(node)
            visited_nodes.add(node)

            # If the node is in a fusion group, add all fusion buddies to parent nodes
            if node in self.fusions:
                for fusion_node in self.fusions[node]:
                    if fusion_node not in visited_nodes:
                        parent_nodes.add(fusion_node)

            # Add inputs of the node to parent nodes
            for arg in node.all_input_nodes:
                if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes:
                    parent_nodes.add(arg)

    # ===============================================================
    # Helpers for split() method
    # ===============================================================

    def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
        """
        Finds nodes that consume module inputs or get_attr nodes.
        """
        starter_cpu_nodes: NodeSet = set()
        starter_acc_nodes: NodeSet = set()
        for node in self.module.graph.nodes:
            if node.op not in {"placeholder", "get_attr"}:
                continue
            for user in node.users:
                if user in self.acc_nodes:
                    starter_acc_nodes.add(user)
                else:
                    starter_cpu_nodes.add(user)
        return starter_cpu_nodes, starter_acc_nodes
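
    # Illustrative example (SimpleModule from the class docstring): `b` and `c`
    # both consume the placeholder `a`, so if operator_support marks them as
    # acc nodes, starter_nodes() returns
    #
    #     (set(), {b, c})   # (starter_cpu_nodes, starter_acc_nodes)
    #
    # put_nodes_into_subgraphs() below seeds its traversal from these sets.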

    def put_nodes_into_subgraphs(self) -> List[Subgraph]:
        # We start graph traversal from leaf nodes
        current_cpu_nodes, current_acc_nodes = self.starter_nodes()
        visited_nodes: NodeSet = set()

        # Determine which subgraph to start from based on which subgraph has
        # a node with zero dependencies
        acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes)

        current_subgraph_nodes: NodeList = []

        # Result accumulator
        subgraphs: List[Subgraph] = []
        while current_cpu_nodes or current_acc_nodes:
            # Find the first node that should belong to the current subgraph and has all dependencies resolved
            current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes
            node = next(
                (n for n in current_nodes if self.deps[n] <= visited_nodes),
                None,
            )

            # If nothing was found, then it's time to flip the mode and start a new subgraph
            if node is None:
                if not current_subgraph_nodes:
                    raise FxNetSplitterInternalError("Subgraph can't be empty")

                subgraphs.append(
                    Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
                )
                acc_subgraph = not acc_subgraph
                current_subgraph_nodes = []
                continue

            current_nodes.remove(node)
            visited_nodes.add(node)
            current_subgraph_nodes.append(node)

            # Add fusion buddies
            if node in self.fusions:
                if node in self.acc_nodes:
                    current_acc_nodes.update(self.fusions[node] - visited_nodes)
                else:
                    current_cpu_nodes.update(self.fusions[node] - visited_nodes)

            # Put dependent nodes into the queue
            for user in node.users:
                if user.op not in CALLABLE_NODE_OPS:
                    continue

                # Add downstream nodes
                if user in self.acc_nodes:
                    current_acc_nodes.add(user)
                else:
                    current_cpu_nodes.add(user)

        # Add the last subgraph if it hasn't been added yet
        if current_subgraph_nodes:
            subgraphs.append(
                Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
            )

        if not subgraphs:
            raise FxNetSplitterInternalError("Couldn't create subgraphs")

        return subgraphs

    def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
        """
        This pass finds ACC submodules with fewer nodes than the specified minimum size
        and merges them with adjacent CPU submodules.
        """
        result: List[Subgraph] = []
        for subgraph in subgraphs:
            if subgraph.is_acc:
                if len(subgraph.nodes) >= self.settings.min_acc_module_size:
                    result.append(subgraph)
                else:
                    print(
                        "Eliminating acc subgraph because it's smaller than the threshold: "
                        f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
                    )
                    if result:
                        result[-1].nodes.extend(subgraph.nodes)
                    else:
                        subgraph.is_acc = False
                        result.append(subgraph)
            else:
                if result and not result[-1].is_acc:
                    result[-1].nodes.extend(subgraph.nodes)
                else:
                    result.append(subgraph)
        return result
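
    # Illustrative example: with min_acc_module_size=2 and subgraphs laid out as
    #
    #     [cpu(3 nodes), acc(1 node), cpu(4 nodes)]
    #
    # the single-node acc subgraph is folded into the preceding cpu subgraph,
    # and the following cpu subgraph is then merged in as well, producing one
    # cpu subgraph with all 8 nodes.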

    def tag(self, subgraphs: List[Subgraph]):
        self.tags = []
        for subgraph in subgraphs:
            tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}"
            self.tags.append(tag)
            for node in subgraph.nodes:
                if hasattr(node, "tag"):
                    raise FxNetSplitterInternalError(f"Node {node} was already tagged")

                node.tag = tag  # type: ignore[attr-defined]
                self._node_submodule_map[node.name] = tag

    def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
        split_module = split_by_tags(self.module, self.tags, return_tuple=self._return_tuple)
        if remove_tag:
            for node in self.module.graph.nodes:
                if hasattr(node, "tag"):
                    del node.tag
        return split_module  # type: ignore[return-value]

    def __call__(self) -> torch.fx.GraphModule:
        subgraphs = self.put_nodes_into_subgraphs()
        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
        acc_subgraphs_count = len([s for s in subgraphs if s.is_acc])
        non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count
        print(f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs")
        self.tag(subgraphs)
        return self.split()

    def generate_split_results(self) -> SplitResult:
        split_module = self()
        submodule_names = []
        for name, mod in split_module.named_children():
            submodule_names.append(name)
        if (
            self.settings.max_acc_splits > 0
            and len(submodule_names) > self.settings.max_acc_splits
        ):
            raise ValueError(
                "Cannot fulfill max_acc_splits limit. "
                "This may cause split fragmentation and "
                "result in performance issues."
            )

        submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names)
        return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)
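
# Illustrative end-to-end sketch (names such as MySplitter, MyOperatorSupport,
# traced_gm and sample_inputs are hypothetical; concrete backends typically
# subclass _SplitterBase and override _lower_model_to_backend/_find_culprit):
#
#     class MySplitter(_SplitterBase):
#         pass
#
#     splitter = MySplitter(
#         traced_gm,                 # torch.fx.GraphModule
#         sample_inputs,             # sequence of example inputs
#         MyOperatorSupport(),       # OperatorSupportBase implementation
#         _SplitterSettingBase(min_acc_module_size=2),
#     )
#     splitter.node_support_preview()            # print supported/unsupported ops
#     split_result = splitter.generate_split_results()
#     split_gm = split_result.split_module       # root module with _run_on_acc_* / _run_on_cpu_* children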