# mypy: allow-untyped-defs
import argparse
import copy
from collections import defaultdict
from dataclasses import dataclass
from typing import NamedTuple, Sequence, Iterable, Any, List, Dict, Optional, Tuple
import logging

import torch
from torch.fx.passes.graph_manipulation import get_size_of_node
from torch.fx.node import map_arg
from torch.fx._compatibility import compatibility

from .operator_support import (
    get_node_target,
    OperatorSupportBase,
)
from .graph_drawer import FxGraphDrawer
from .shape_prop import ShapeProp
from .split_utils import split_by_tags
from .tools_common import (
    FxNetAccFusionsFinder,
    CALLABLE_NODE_OPS,
    Tensors,
    NodeList,
    NodeSet,
    is_node_output_tensor,
)


__all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules']
_LOGGER = logging.getLogger(__name__)

DEFAULT_MIN_ACC_MODULE_SIZE = 1
DEFAULT_SKIP_FUSION = False
DEFAULT_ALLOW_NON_TENSOR = False

class _SplitterSettingBase:
    def __init__(
        self,
        min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
        skip_fusion=DEFAULT_SKIP_FUSION,
        allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR,
        max_acc_splits: int = -1,
    ):
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--min-acc-module-size",
            "--min_acc_module_size",
            required=False,
            type=int,
            help="Minimum size limit of an accelerator subgraph.",
        )
        parser.add_argument(
            "--max-acc-splits",
            "--max_acc_splits",
            required=False,
            type=int,
            help="Enforce a maximum number of split subgraphs.",
        )
        parser.add_argument(
            "--skip-fusion",
            "--skip_fusion",
            default=False,
            action="store_true",
            help="If true, no fusion groups are created. Fusion groups are used to "
            "enforce that no non-tensor data flows between submodules. If the backend "
            "doesn't have this constraint, setting this flag is recommended as it "
            "can reduce overhead.",
        )
        parser.add_argument(
            "--allow-non-tensor",
            "--allow_non_tensor",
            default=False,
            action="store_true",
            help="Some backends don't allow non-tensor data flow between them and "
            "the cpu. Therefore, if a node is supported by the accelerator but has "
            "non-tensor inputs from or outputs to a cpu node, we treat it as a cpu "
            "node during splitting. However, for backends that don't care about "
            "non-tensor data flow, this option can be set to true to disable the "
            "functionality that prevents non-tensor data flow.",
        )
        args, unknown = parser.parse_known_args()

        self.min_acc_module_size: int = args.min_acc_module_size if args.min_acc_module_size else min_acc_module_size
        self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion
        self.allow_non_tensor: bool = args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
        self.max_acc_splits: int = max_acc_splits
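
    # Example (minimal sketch): constructing settings programmatically. Values
    # passed on the command line (e.g. --min-acc-module-size) take precedence
    # over the constructor arguments used here.
    #
    #     settings = _SplitterSettingBase(min_acc_module_size=10, skip_fusion=True)
    #     settings.min_acc_module_size  # 10, unless overridden via a CLI flag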


@compatibility(is_backward_compatible=False)
class FxNetAccNodesFinder:
    """
    Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor
    input/output to cpu nodes to prevent non-tensor data flow between backends and cpu.

    I.e. if we have a chain:

    ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1

    where every ACC node produces non-tensor output, then they all should be treated as CPU nodes.

    This behavior can be turned off by passing allow_non_tensor=True.
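
    Example (illustrative sketch, assuming a traced GraphModule `gm` and an
    OperatorSupportBase implementation `op_support`):

        finder = FxNetAccNodesFinder(gm, op_support, allow_non_tensor=False)
        acc_nodes = finder()  # set of nodes considered ACC-supported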
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        operator_support: OperatorSupportBase,
        allow_non_tensor: bool,
    ):
        self.module = module
        self.operator_support = operator_support
        self.allow_non_tensor = allow_non_tensor
        self.acc_nodes: NodeSet = set()

    def reduce_acc_nodes_non_tensor_input_helper(
        self, cpu_worklist: NodeList
    ):
        """
        Transitively excludes nodes from ACC supported set.
        For every node in the worklist:
        - removes its downstream ACC nodes from ACC supported set,
        - if any downstream ACC node produces non-tensor output,
          then it gets added into the worklist.
        """
        while cpu_worklist:
            node = cpu_worklist.pop(0)

            for user in node.users:
                if user in self.acc_nodes:
                    self.acc_nodes.remove(user)
                    if not is_node_output_tensor(user):
                        cpu_worklist.append(user)

    def reduce_acc_nodes_non_tensor_input(self):
        """
        Excludes nodes from ACC supported set that have direct
        upstream CPU nodes that produce non-tensor outputs.
        """
        non_tensor_cpu_nodes: NodeList = []

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue
            if node in self.acc_nodes:
                continue
            if is_node_output_tensor(node):
                continue
            non_tensor_cpu_nodes.append(node)

        self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes)

    def reduce_acc_nodes_non_tensor_output(self):
        """
        Excludes nodes from ACC supported set that produce non-tensor
        outputs and have downstream CPU nodes.
        """
        while True:
            new_cpu_nodes: NodeList = []

            for acc_node in self.acc_nodes:
                if is_node_output_tensor(acc_node):
                    continue
                for user in acc_node.users:
                    if user not in self.acc_nodes:
                        new_cpu_nodes.append(acc_node)
                        break

            if not new_cpu_nodes:
                break

            for new_cpu_node in new_cpu_nodes:
                self.acc_nodes.remove(new_cpu_node)

            self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes)

    def __call__(self) -> NodeSet:
        submodules = dict(self.module.named_modules())
        self.acc_nodes = {
            n
            for n in self.module.graph.nodes
            if n.op in CALLABLE_NODE_OPS
            and self.operator_support.is_node_supported(submodules, n)
        }

        if not self.allow_non_tensor:
            self.reduce_acc_nodes_non_tensor_input()
            self.reduce_acc_nodes_non_tensor_output()

        return self.acc_nodes

@compatibility(is_backward_compatible=False)
class FxNetSplitterInternalError(Exception):
    pass

@compatibility(is_backward_compatible=False)
@dataclass
class Subgraph:
    is_acc: bool
    nodes: NodeList
    device_ordinal: Optional[int] = None

@compatibility(is_backward_compatible=False)
class SplitResult(NamedTuple):
    """
    Stores the results of the splitter.

    Attributes:
        split_module: root module after splitting.
        submodule_inputs: a dict that maps submodule name to its inputs.
        non_acc_submodule_prefix: the prefix for non-acc submodules. For
            acc submodules the prefix is always "_run_on_acc_".
    """

    split_module: torch.fx.GraphModule
    submodule_inputs: Dict[str, Any]
    non_acc_submodule_prefix: str


@compatibility(is_backward_compatible=False)
def generate_inputs_for_submodules(
    model: torch.nn.Module,
    inputs: Sequence[Any],
    target_submodules: Iterable[str],
    deepcopy: bool = False,
) -> Dict[str, Any]:
    """
    Generate inputs for the target submodules in the given model. Note that if two
    submodules refer to the same object, this function doesn't work.

    Args:
        model: root model.
        inputs: inputs to the root model.
        target_submodules: submodules that we want to generate inputs for.

    Returns:
        A dict that maps from submodule name to its inputs.
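
    Example (illustrative sketch; the submodule name "_run_on_acc_0" is just a
    placeholder for whatever `model.named_modules()` reports):

        submodule_inputs = generate_inputs_for_submodules(
            model, (sample_x,), ["_run_on_acc_0"]
        )
        acc_args = submodule_inputs["_run_on_acc_0"]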
    """

    handles = []
    results = {}
    submodule_to_names = {mod: name for name, mod in model.named_modules()}

    def pre_forward(module, module_inputs):
        results[submodule_to_names[module]] = copy.deepcopy(module_inputs) if deepcopy else module_inputs

    for name, mod in model.named_modules():
        if name in target_submodules:
            handles.append(mod.register_forward_pre_hook(pre_forward))

    def clean_up_handles():
        for h in handles:
            h.remove()

    try:
        with torch.no_grad():
            model(*inputs)
    except Exception as e:
        clean_up_handles()
        raise e

    clean_up_handles()
    return results


class _SplitterBase:
    """
    Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator.
    Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible.
    Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator.

    Given the following graph:
          ==> b ==>
        //         \\
       a             d
        \\         //
          ==> c ==>

    class SimpleModule(torch.nn.Module):
        def forward(self, a):
            b = torch.sin(a)
            c = torch.cos(a)
            d = b + c
            return d

    and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator,
    we will get the following split result:

    main:
    def forward(self, a):
        run_on_acc_0_0 = self._run_on_acc_0_0(a)
        getitem = run_on_acc_0_0[0]
        getitem_1 = run_on_acc_0_0[1]
        run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1)
        return run_on_cpu_1_1

    _run_on_acc_0_0:
    def forward(self, a):
        sin_1 = torch.sin(a)
        cos_1 = torch.cos(a)
        return (sin_1, cos_1)

    _run_on_cpu_1_1:
    def forward(self, sin_1, cos_1):
        add_1 = sin_1 + cos_1
        return add_1
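
    Typical driver (illustrative sketch, assuming an OperatorSupportBase
    implementation `op_support` and sample inputs `sample_input`):

        splitter = _SplitterBase(gm, sample_input, op_support, _SplitterSettingBase())
        splitter.node_support_preview()   # optional: print per-op support report
        split_gm = splitter()             # GraphModule with acc/cpu child submodules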
    """

    # PCIe bandwidth for the backend, default to 100 GB/s
    PCIe_BW = 100 * 2 ** 30
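    # For example, at 100 * 2**30 B/s, a submodule whose max(input, output)
    # transfer is 1 MiB per query has a PCIe-bound ceiling of roughly
    # (100 * 2**30) / 2**20 = 102,400 qps (illustrative numbers only).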

    def __init__(
        self,
        module: torch.fx.GraphModule,
        sample_input: Sequence[Any],
        operator_support: OperatorSupportBase,
        settings: _SplitterSettingBase,
        non_acc_submodule_name: str = "_run_on_cpu_",
        return_tuple: bool = False,
    ):
        """
        Preprocesses graph before splitting:
        - finds nodes supported by ACC,
        - finds fusion groups for ACC nodes having non-tensor IO,
        - builds a graph of direct dependencies,
        - builds a map of fused nodes to their fusions.
        As a result we get self.acc_nodes, self.deps and self.fusions.
        """
        assert isinstance(module, torch.fx.GraphModule)

        self.module = module
        ShapeProp(self.module).propagate(*sample_input)

        self.settings = settings
        self.operator_support = operator_support
        self.sample_input = sample_input
        self.acc_nodes = FxNetAccNodesFinder(self.module, self.operator_support, self.settings.allow_non_tensor)()

        if self.settings.skip_fusion:
            self.fusions = {}
        else:
            self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)()

        # Modify deps to add more deps for fused nodes
        self.deps = self.find_deps()
        self.update_deps_for_fusions()

        self.non_acc_submodule_name = non_acc_submodule_name
        self._node_submodule_map: Dict[str, str] = {}
        self._return_tuple = return_tuple

        self.tags: List[str] = []

    # ===============================================================
    # Helpers for ctor and initial state
    # ===============================================================

    def get_node_submodule_map(self) -> Dict[str, str]:
        """ Returns a map from node name to submodule name, e.g.
            node: main_module_impl_impl_over_arch_unary_multiple_embedding
              _pooling_embedding_pooling_sparse_entity_equivalence_key
              _proxy_embedding_bag
            maps to submodule name of: _run_on_acc_1
        """
        return self._node_submodule_map

    def find_deps(self) -> Dict[torch.fx.Node, NodeSet]:
        """
        Builds a graph of node dependencies. Leaf nodes don't have any
        dependencies and the "output" node doesn't have nodes depending on it.

        Resulting graph has only direct dependencies, i.e. there are no
        transitive dependencies.
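
        For the example graph in the class docstring (b = sin(a), c = cos(a),
        d = b + c), this yields deps[d] == {b, c}; b and c have no recorded
        dependencies because their only input is the placeholder a.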
        """
        deps: Dict[torch.fx.Node, NodeSet] = defaultdict(set)
        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            for user in node.users:
                if user.op != "output":
                    deps[user].add(node)
        return deps

    def update_deps_for_fusions(self):
        """
        Updates graph of dependencies so that:
        - nodes from the same fusion depend on the same set of outer nodes,
        - outer nodes depending on a fusion depend on all nodes in that fusion.
        """
        for node in self.fusions:
            fusion = self.fusions[node]
            for fused_neighbor in fusion:
                self.deps[node].update(self.deps[fused_neighbor] - fusion)

                for user in fused_neighbor.users:
                    if user not in fusion:
                        self.deps[user].add(node)

    # ===============================================================
    # Helpers for preview
    # ===============================================================

    def _lower_model_to_backend(
        self, mod: torch.fx.GraphModule, inputs: Tensors
    ) -> torch.nn.Module:
        """
        Lower the model to a backend.
        """

        return mod

    def _find_culprit(
        self, mod: torch.fx.GraphModule, inputs: Tensors
    ) -> str:
        """
        When an error occurs during lowering or running the lowered mod, we use this
        function to find the culprits in `mod` that cause the error.
        """

        return "Unable to find a culprit because the _find_culprit() function is not implemented."
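
    # Backend-specific subclasses are expected to override the two hooks above.
    # A hedged sketch (the backend API below is hypothetical):
    #
    #     class MyBackendSplitter(_SplitterBase):
    #         def _lower_model_to_backend(self, mod, inputs):
    #             return my_backend.compile(mod, inputs)               # hypothetical API
    #
    #         def _find_culprit(self, mod, inputs):
    #             return my_backend.minimize_and_report(mod, inputs)   # hypothetical API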

    def _draw_graph_based_on_node_support(
        self, mod: torch.fx.GraphModule, supported_nodes: NodeList
    ):
        color_map = {
            "default": "AliceBlue",
            "supported": "chartreuse1",
            "unsupported": "crimson",
        }

        class CustomDrawer(FxGraphDrawer):
            def _get_node_style(self, node):
                template = super()._get_node_style(node)
                if node in supported_nodes:
                    template["fillcolor"] = color_map["supported"]
                elif node.op in CALLABLE_NODE_OPS:
                    template["fillcolor"] = color_map["unsupported"]
                else:
                    template["fillcolor"] = color_map["default"]

                return template

        drawer = CustomDrawer(mod, "node_support", ignore_getattr=True)
        dot_graph = drawer.get_main_dot_graph()
        # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`.
        dot_graph.write_raw("node_support.dot")

    def node_support_preview(self, dump_graph: bool = False):
        submodules = dict(self.module.named_modules())

        supported_nodes: NodeList = []
        supported_node_types = defaultdict(set)
        unsupported_node_types = defaultdict(set)

        def get_dtype(arg):
            tensor_meta = arg.meta.get("tensor_meta")
            return getattr(tensor_meta, "dtype", None)

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            target = get_node_target(submodules, node)

            # Store the dtype of each arg in node.args. If an arg doesn't have a dtype,
            # i.e. it's not a tensor, store None.
            arg_dtypes = [
                get_dtype(arg) if isinstance(arg, torch.fx.Node) else None
                for arg in node.args
            ]

            # Find the index just past the last non-None element; it is 0 if all
            # elements are None (the next() default is len(arg_dtypes)).
            last_index = len(arg_dtypes) - next(
                (
                    i
                    for i, dtype in enumerate(reversed(arg_dtypes))
                    if dtype is not None
                ),
                len(arg_dtypes),
            )
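            # E.g. arg_dtypes == [float32, None, int64, None, None] gives
            # last_index == 3, so the trailing Nones are stripped below.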

            # Strip None elements at the end.
            arg_dtypes_tuple = tuple(arg_dtypes[:last_index])
            kwarg_dtypes_tuple = tuple(
                (k, get_dtype(arg))
                for k, arg in node.kwargs.items()
                if isinstance(arg, torch.fx.Node)
            )

            if self.operator_support.is_node_supported(submodules, node):
                supported_nodes.append(node)
                supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
            else:
                unsupported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))

        if dump_graph:
            self._draw_graph_based_on_node_support(self.module, supported_nodes)

        reports = "\nSupported node types in the model:\n"
        for t, dtypes in supported_node_types.items():
            for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
                reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"

        reports += "\nUnsupported node types in the model:\n"
        for t, dtypes in unsupported_node_types.items():
            for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
                reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"

        print(reports)

        # Return reports for testing purposes
        return reports

    def split_preview(self, dump_graph: bool = False):
        reports = ""
        subgraphs = self.put_nodes_into_subgraphs()
        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
        reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
        reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

        for i, subgraph in enumerate(subgraphs):
            reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"{self.non_acc_submodule_name}{i}: "
            reports += f"{len(subgraph.nodes)} node(s)\n"

        self.tag(subgraphs)
        split_mod = self.split(remove_tag=True)
        split_mod.eval()

        if dump_graph:
            drawer = FxGraphDrawer(
                split_mod, "preview", ignore_getattr=True
            )
            dot_graphs = drawer.get_all_dot_graphs()
            for name, dot_graph in dot_graphs.items():
                # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`.
                dot_graph.write_raw(f"{name}.dot")

        max_qps: float = self.PCIe_BW
        bottleneck_module = ""

        for node in split_mod.graph.nodes:
            if node.op == "call_module" and "acc" in node.target:
                reports += f"\nProcessing acc submodule {node.target}\n"

                submod = getattr(split_mod, node.target)

                def get_submod_inputs(main_mod, submod, example_inputs):
                    sub_inputs = None

                    def get_inputs(self, inputs):
                        nonlocal sub_inputs
                        sub_inputs = inputs

                    handle = submod.register_forward_pre_hook(get_inputs)
                    main_mod(*example_inputs)
                    handle.remove()
                    return sub_inputs

                submod_inputs = get_submod_inputs(
                    split_mod, submod, self.sample_input
                )
                ShapeProp(submod).propagate(*submod_inputs)

                total_input_bytes = 0
                total_output_bytes = 0

                reports += "Checking inputs...\n"
                for n in submod.graph.nodes:
                    if n.op == "placeholder":
                        if not is_node_output_tensor(n):
                            reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
                        else:
                            total_input_bytes += get_size_of_node(submod, n)[0]
                    if n.op == "output":
                        output_node = n

                reports += "Checking outputs...\n"

                def get_bytes(node: torch.fx.Node):
                    nonlocal total_output_bytes
                    nonlocal reports
                    if not is_node_output_tensor(node):
                        reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n"
                    else:
                        total_output_bytes += get_size_of_node(submod, node)[0]

                map_arg(output_node.args, get_bytes)  # type: ignore[possibly-undefined]
                qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes)
                reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
                reports += f" theoretical max qps (bounded by PCIe bandwidth) for this submodule is {qps}.\n"

                if qps < max_qps:
                    max_qps = qps
                    bottleneck_module = node.target

                try:
                    lowered_submod = self._lower_model_to_backend(submod, submod_inputs)
                except RuntimeError:
                    reports += "Ran into an error during lowering!\n"
                    reports += self._find_culprit(submod, submod_inputs)
                    continue

                try:
                    lowered_submod(*submod_inputs)
                except RuntimeError:
                    reports += "Ran into an error during inference!\n"
                    reports += self._find_culprit(submod, submod_inputs)
                else:
                    reports += "Lowering and running succeeded!\n"

        reports += f"\nTheoretical max qps (bounded by PCIe bandwidth) for this model is {max_qps},"
        reports += f" bottleneck is submodule {bottleneck_module}."
        print(reports)

        # return the reports for testing purposes
        return reports

    # ===============================================================
    # Helpers for extend_acc_subgraph() method
    # ===============================================================

    def find_reverse_deps(
        self, tag_id: Optional[int] = None
    ) -> Dict[torch.fx.Node, NodeSet]:
        """
        Builds reversed topological node dependencies. If tag_id is specified,
        we ignore nodes that are in a later subgraph, i.e. nodes that have a greater tag_id.
        """
        result: Dict[torch.fx.Node, NodeSet] = defaultdict(set)

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            for user in node.users:
                if user.op not in CALLABLE_NODE_OPS:
                    continue

                if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id):
                    result[node].add(user)

        return result

    def update_reverse_deps_for_fusions(
        self, deps: Dict[torch.fx.Node, NodeSet]
    ):
        processed_node = set()

        for node, fusion in self.fusions.items():
            if node in processed_node:
                continue

            new_dep = set()

            # Create a new dependency set which includes all the
            # dependencies of the nodes in the fusion group
            for n in fusion:
                new_dep.update(deps[n])

            # Exclude nodes in the fusion
            new_dep.difference_update(fusion)

            # Update dependency
            for n in fusion:
                deps[n] = new_dep

                for arg in n.all_input_nodes:
                    if arg not in fusion:
                        deps[arg].update(fusion)

                processed_node.add(n)

    def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet:
        """
        Finds parent nodes of the `tag` subgraph.

        Traverses the inputs of the nodes in the subgraph; if an input doesn't belong to
        the subgraph and is not a placeholder, we consider it a parent node of the subgraph.
        """
        parent_nodes = set()

        for node in self.module.graph.nodes:
            if node.op in CALLABLE_NODE_OPS and node.tag == tag:
                for arg in node.all_input_nodes:
                    if arg.op in CALLABLE_NODE_OPS and arg.tag != tag:
                        parent_nodes.add(arg)

        return parent_nodes

    def extend_acc_subgraph(self, tag: str):
        """
        Extends the acc subgraph with `tag` in the reverse topological direction.
        """
        # Dict that maps a node to its users, ignoring users that
        # are in a subgraph with a greater tag
        deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1]))
        self.update_reverse_deps_for_fusions(deps)

        # Parent nodes of the subgraph
        parent_nodes = self.find_parent_nodes_of_subgraph(tag)

        visited_nodes: NodeSet = set()

        while parent_nodes:
            node = None

            # Find an acc node that depends on visited nodes only
            for n in parent_nodes:
                if deps[n] <= visited_nodes and n in self.acc_nodes:
                    node = n
                    break

            if node is None:
                break

            # Put the node into `tag` subgraph
            node.tag = tag  # type: ignore[attr-defined]
            parent_nodes.remove(node)
            visited_nodes.add(node)

            # If node is in a fusion group, add all fusion buddies to parent nodes
            if node in self.fusions:
                for fusion_node in self.fusions[node]:
                    if fusion_node not in visited_nodes:
                        parent_nodes.add(fusion_node)

            # Add inputs of the node to parent nodes
            for arg in node.all_input_nodes:
                if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes:
                    parent_nodes.add(arg)

    # ===============================================================
    # Helpers for split() method
    # ===============================================================

    def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
        """
        Finds nodes that consume module inputs or get_attr nodes.
        """
        starter_cpu_nodes: NodeSet = set()
        starter_acc_nodes: NodeSet = set()
        for node in self.module.graph.nodes:
            if node.op not in {"placeholder", "get_attr"}:
                continue
            for user in node.users:
                if user in self.acc_nodes:
                    starter_acc_nodes.add(user)
                else:
                    starter_cpu_nodes.add(user)
        return starter_cpu_nodes, starter_acc_nodes

    def put_nodes_into_subgraphs(self) -> List[Subgraph]:
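        # For the example in the class docstring (sin and cos supported, add not),
        # this produces, in order:
        #   [Subgraph(is_acc=True, nodes=[sin, cos]), Subgraph(is_acc=False, nodes=[add])]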
        # We start graph traversal from leaf nodes
        current_cpu_nodes, current_acc_nodes = self.starter_nodes()
        visited_nodes: NodeSet = set()

        # Determine which subgraph to start from based on which subgraph has
        # a 0-dep node
        acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes)

        current_subgraph_nodes: NodeList = []

        # Result accumulator
        subgraphs: List[Subgraph] = []
        while current_cpu_nodes or current_acc_nodes:
            # Find the first node that should belong to the current subgraph and has all dependencies resolved
            current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes
            node = next(
                (n for n in current_nodes if self.deps[n] <= visited_nodes),
                None,
            )

            # If nothing was found, then it's time to flip the mode and start a new subgraph
            if node is None:
                if not current_subgraph_nodes:
                    raise FxNetSplitterInternalError("Subgraph can't be empty")

                subgraphs.append(
                    Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
                )
                acc_subgraph = not acc_subgraph
                current_subgraph_nodes = []
                continue

            current_nodes.remove(node)
            visited_nodes.add(node)
            current_subgraph_nodes.append(node)

            # Add fusion buddies
            if node in self.fusions:
                if node in self.acc_nodes:
                    current_acc_nodes.update(self.fusions[node] - visited_nodes)
                else:
                    current_cpu_nodes.update(self.fusions[node] - visited_nodes)

            # Put depending nodes into the queue
            for user in node.users:
                if user.op not in CALLABLE_NODE_OPS:
                    continue

                # Add downstream nodes
                if user in self.acc_nodes:
                    current_acc_nodes.add(user)
                else:
                    current_cpu_nodes.add(user)

        # Append the last subgraph if it hasn't been added yet
        if current_subgraph_nodes:
            subgraphs.append(
                Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
            )

        if not subgraphs:
            raise FxNetSplitterInternalError("Couldn't create subgraphs")

        return subgraphs

    def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
        """
        This pass finds ACC submodules with fewer nodes than the specified minimum
        size and merges them with adjacent CPU submodules.
        """
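        # E.g. with min_acc_module_size=2, a 1-node acc subgraph is folded into the
        # neighbouring CPU subgraph (or re-labelled as CPU if it is the first subgraph).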
        result: List[Subgraph] = []
        for subgraph in subgraphs:
            if subgraph.is_acc:
                if len(subgraph.nodes) >= self.settings.min_acc_module_size:
                    result.append(subgraph)
                else:
                    print(
                        "Eliminating acc subgraph because it's smaller than the threshold: "
                        f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
                    )
                    if result:
                        result[-1].nodes.extend(subgraph.nodes)
                    else:
                        subgraph.is_acc = False
                        result.append(subgraph)
            else:
                if result and not result[-1].is_acc:
                    result[-1].nodes.extend(subgraph.nodes)
                else:
                    result.append(subgraph)
        return result

    def tag(self, subgraphs: List[Subgraph]):
        self.tags = []
        for subgraph in subgraphs:
            tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}"
            self.tags.append(tag)
            for node in subgraph.nodes:
                if hasattr(node, "tag"):
                    raise FxNetSplitterInternalError(f"Node {node} was already tagged")

                node.tag = tag  # type: ignore[attr-defined]
                self._node_submodule_map[node.name] = tag

    def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
        split_module = split_by_tags(self.module, self.tags, return_tuple=self._return_tuple)
        if remove_tag:
            for node in self.module.graph.nodes:
                if hasattr(node, "tag"):
                    del node.tag
        return split_module  # type: ignore[return-value]

    def __call__(self) -> torch.fx.GraphModule:
        subgraphs = self.put_nodes_into_subgraphs()
        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
        acc_subgraphs_count = len([s for s in subgraphs if s.is_acc])
        non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count
        print(f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs")
        self.tag(subgraphs)
        return self.split()

    def generate_split_results(self) -> SplitResult:
        split_module = self()
        submodule_names = []
        for name, mod in split_module.named_children():
            submodule_names.append(name)
        if (
            self.settings.max_acc_splits > 0
            and len(submodule_names) > self.settings.max_acc_splits
        ):
            raise ValueError(
                "Cannot fulfill max_acc_splits limit. "
                "This may cause split fragmentation and "
                "result in performance issues."
            )

        submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names)
        return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)
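
# End-to-end usage (illustrative sketch; `MyBackendSplitter` and `op_support`
# stand in for a backend-specific _SplitterBase subclass and an
# OperatorSupportBase implementation):
#
#     splitter = MyBackendSplitter(gm, (sample_x,), op_support, _SplitterSettingBase())
#     split_result = splitter.generate_split_results()
#     split_result.split_module       # root GraphModule after splitting
#     split_result.submodule_inputs   # dict: submodule name -> captured inputs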
899