# xref: /aosp_15_r20/external/pytorch/torch/_inductor/fx_passes/split_cat.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import itertools
import logging
import operator
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing_extensions import TypeAlias

import torch
from torch._dynamo.utils import counters

from ..pattern_matcher import (
    Arg,
    CallFunction,
    CallFunctionVarArgs,
    CallMethodVarArgs,
    FailedMatch,
    get_arg_value,
    Ignored,
    KeywordArg,
    ListOf,
    Match,
    MatchContext,
    MULTIPLE,
    PatternExpr,
    PatternMatcherPass,
    register_graph_pattern,
    RepeatedExpr,
)
from .group_batch_fusion import is_node_meta_valid, POST_GRAD_FUSIONS, PRE_GRAD_FUSIONS


log = logging.getLogger(__name__)

_Arguments: TypeAlias = Tuple[torch.fx.node.Argument, ...]
_TransformParam: TypeAlias = Tuple[
    Optional[_Arguments],
    Optional[_Arguments],
    Optional[_Arguments],
    Optional[_Arguments],
]
_Range: TypeAlias = Tuple[int, int]


PRE_GRAD_PATTERNS: Dict[str, PatternMatcherPass] = {}
POST_GRAD_PATTERNS: Dict[str, PatternMatcherPass] = {}

pre_grad_pass_names = [
    "normalization_pass",
    "remove_split_with_size_one_pass",
    "merge_getitem_cat_pass",
    "merge_stack_tahn_unbind_pass",
    "merge_splits_pass",
    "mutate_cat_pass",
    "split_cat_pass",
    "unbind_stack_pass",
    "split_cat_to_slices_pass",
    "unbind_cat_to_view_pass",
    "split_stack_to_cats_pass",
    "unbind_stack_to_slices_pass",
    "move_reshape_out_of_split_stack_pass",
]

post_grad_pass_names = [
    "normalization_aten_pass",
    "decompose_mm_pass",
    "unbind_stack_aten_pass",
    "shape_padding_multiplier",
]

for pass_name in pre_grad_pass_names:
    # exclude the passes that belong to the group batch fusion,
    # since they do not use the pattern matcher
    if pass_name in PRE_GRAD_FUSIONS:
        continue
    PRE_GRAD_PATTERNS[pass_name] = PatternMatcherPass(
        pass_name=pass_name,
    )

for pass_name in post_grad_pass_names:
    # exclude the passes that belong to the group batch fusion,
    # since they do not use the pattern matcher
    if pass_name in POST_GRAD_FUSIONS:
        continue
    POST_GRAD_PATTERNS[pass_name] = PatternMatcherPass(
        pass_name=pass_name,
    )


def construct_pattern_matcher_pass(pass_name: str):
    """
    Return the specific pattern_matcher_pass given the pass name.
    """
    if pass_name in PRE_GRAD_PATTERNS:
        return PRE_GRAD_PATTERNS[pass_name]
    else:
        return POST_GRAD_PATTERNS[pass_name]


def _get_split_args_default(split_node):
    input_kwarg = "tensor"
    split_size_kwarg = "split_size_or_sections"
    dim_kwarg = "dim"
    default_dim_value = 0
    if split_node.op == "call_method":
        split_size_kwarg = "split_size"
    return (
        get_arg_value(split_node, 0, input_kwarg),
        get_arg_value(split_node, 1, split_size_kwarg),
        get_arg_value(split_node, 2, dim_kwarg) or default_dim_value,
    )


def _get_dim(node: Any):
    assert isinstance(node, torch.fx.Node)
    if "dim" in node.kwargs:
        assert isinstance(node.kwargs["dim"], int)
        return node.kwargs["dim"]
    if node.target == torch.unbind:
        if len(node.args) == 2:
            assert isinstance(node.args[-1], int)
            return node.args[-1]
        return 0  # defaults to dim=0
    if node.target == torch.split:
        if len(node.args) == 3:
            assert isinstance(node.args[-1], int)
            return node.args[-1]
        return 0  # defaults to dim=0
    raise AssertionError(
        f"Can't extract `dim` from {node.target} {node.args} {node.kwargs}"
    )


# noqa: W605
# ############The pattern to be optimized is#########
#         unbind (dim=0)
#       /   ...    \
# getitem      getitem   -> user=1
#    |            |
#  split         split  -> dim=1, user=1, split_section_size=1
#    |            |
#  getitem       getitem  -> user=1
#    \           /
#        cat (dim=1)  -> user=1
#          |

# ################After transformation#############
#          unbind (dim=0)
#        /    ...   \
#    getitem       getitem  -> user=1
#       \          /
#        cat (dim=1)  -> user=1
#         |


def normalize_split_base(
    match: Match,
    _get_split_args: Callable[
        [torch.fx.Node], Tuple[Optional[torch.fx.Node], Optional[Any], Optional[int]]
    ],
):
    """
    Normalize a split with an int split_size into a split with a list of sizes
    (split_with_sizes semantics), so that subsequent optimizations only have to
    deal with one type of split.
    """
    split_node = match.nodes[0]
    graph = match.graph
    split_input, split_size, split_dim = _get_split_args(split_node)
    if split_input is None or split_dim is None or split_size is None:
        log.debug("couldn't find split args")
        return
    if not is_node_meta_valid(split_node):
        log.debug("example value absent for node: %s", split_node)
        return
    assert isinstance(split_node.meta["example_value"], (list, tuple))
    split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]]

    if any(isinstance(section, torch.SymInt) for section in split_sections):
        # TODO: dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing.
        return
    if split_dim < 0:  # Normalize split dim
        split_dim += split_input.meta["example_value"].dim()

    new_args = (split_input, split_sections)
    new_kwargs = {"dim": split_dim}
    if (
        split_node.args == new_args
        and split_node.kwargs == new_kwargs
        and split_node.op == "call_function"
    ):
        return

    with graph.inserting_after(split_node):
        new_split_node = graph.call_function(
            torch.split,
            args=new_args,
            kwargs=new_kwargs,  # type: ignore[arg-type]
        )
    split_node.replace_all_uses_with(new_split_node)
    new_split_node.meta.update(split_node.meta)
    graph.erase_node(split_node)
    counters["inductor"]["normalization_pass"] += 1

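# Illustrative sketch of the normalization (an assumed example, not executed
# here): a graph traced from
#   x = torch.randn(6, 4)
#   chunks = x.split(2)  # call_method with an int split_size, implicit dim=0
# is rewritten into the normalized call_function form
#   chunks = torch.split(x, [2, 2, 2], dim=0)
# where the section list is recovered from the example_value metadata.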

@register_graph_pattern(
    CallFunctionVarArgs(torch.split, users=MULTIPLE),
    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
)
@register_graph_pattern(
    CallMethodVarArgs("split", users=MULTIPLE),
    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
)
def normalize_split_default(match: Match, *args, **kwargs):
    return normalize_split_base(match, _get_split_args_default)


@register_graph_pattern(
    CallFunctionVarArgs(torch.split, users=MULTIPLE),
    pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"),
)
@register_graph_pattern(
    CallMethodVarArgs("split", users=MULTIPLE),
    pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"),
)
def remove_split_with_size_one(match: Match, *args, **kwargs):
    graph = match.graph
    split_node = match.nodes[0]
    split_input, split_size, split_dim = _get_split_args_default(split_node)
    if split_input is None or split_dim is None or split_size is None:
        log.debug("couldn't find split args")
        return
    if not is_node_meta_valid(split_node):
        log.debug("example value absent for node: %s", split_node)
        return
    assert isinstance(split_node.meta["example_value"], (list, tuple))
    split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]]

    if any(isinstance(section, torch.SymInt) for section in split_sections):
        # TODO: dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing.
        return
    # remove the dummy split whose split_sections list has length one
    if len(split_sections) == 1:
        # find the grandchildren of the split_node
        next_users = find_next_users(split_node)
        user = next(iter(split_node.users.keys()))
        # in each grandchild, replace the getitem input with the split input
        for next_user in next_users:
            next_user.replace_input_with(user, split_input)
        # erase the split node and its child
        graph.erase_node(user)
        graph.erase_node(split_node)
        counters["inductor"]["remove_split_with_size_one_pass"] += 1

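# Illustrative sketch (assumed shapes): a split that yields a single chunk is
# a no-op, e.g. with x of shape (8, 4),
#   out = torch.split(x, [4], dim=1)
#   y = out[0].relu()
# is rewired so that `y = x.relu()` reads the split input directly, and the
# split/getitem pair is erased.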

@register_graph_pattern(
    CallFunctionVarArgs(torch.unbind, users=MULTIPLE),
    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
)
@register_graph_pattern(
    CallMethodVarArgs("unbind", users=MULTIPLE),
    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
)
def normalize_unbind_default(match: Match, *args, **kwargs):
    node = match.nodes[0]
    graph = match.graph
    input = get_arg_value(node, 0, "input")
    dim = get_arg_value(node, 1, "dim")
    if dim is None:
        axis = node.kwargs.get("axis")
        if axis is not None:
            dim = axis
        else:
            dim = 0
    if input is None:
        log.debug("couldn't find unbind args")
        return
    if not is_node_meta_valid(input):
        log.debug("example value absent for node: %s", input)
        return
    ndim = input.meta["example_value"].ndim
    if dim < 0:  # Normalize unbind dim
        dim += ndim
    with graph.inserting_after(node):
        new_node = graph.call_function(
            torch.unbind,
            args=(input,),
            kwargs={"dim": dim},
        )
    node.replace_all_uses_with(new_node)
    new_node.meta.update(node.meta)
    graph.erase_node(node)
    counters["inductor"]["normalization_pass"] += 1


@register_graph_pattern(
    CallFunctionVarArgs(torch.cat, users=MULTIPLE),
    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
)
def normalize_cat_default(match: Match, *args, **kwargs):
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    cat_node = match.nodes[0]
    graph = match.graph
    tensors = get_arg_value(cat_node, 0, "tensors")
    cat_dim = get_arg_value(cat_node, 1, "dim")
    if cat_dim is None:
        cat_axis = cat_node.kwargs.get("axis")
        if cat_axis is not None:
            cat_dim = cat_axis
        else:
            cat_dim = 0
    if tensors is None or cat_dim is None:
        log.debug("couldn't find cat args")
        return
    assert isinstance(tensors, (list, tuple))
    for tensor in itertools.chain([cat_node], tensors):
        if not is_node_meta_valid(tensor):
            log.debug("example value absent for node: %s", tensor)
            return

    ndim = cat_node.meta["example_value"].dim()

    def is_empty_tensor(x):
        # special case where torch.cat supports cat'ing with an empty tensor
        x_shape = x.meta["example_value"].shape
        return len(x_shape) == 1 and guard_size_oblivious(x_shape[0] == 0)

    assert all(
        ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors
    )

    if cat_dim < 0:  # Normalize cat dim
        cat_dim += ndim

    new_args = (tensors,)
    new_kwargs = {"dim": cat_dim}
    if (
        cat_node.args == new_args
        and cat_node.kwargs == new_kwargs
        and cat_node.op == "call_function"
    ):
        return

    with graph.inserting_after(cat_node):
        new_cat_node = graph.call_function(
            torch.cat,
            args=new_args,
            kwargs=new_kwargs,
        )
    cat_node.replace_all_uses_with(new_cat_node)
    new_cat_node.meta.update(cat_node.meta)
    graph.erase_node(cat_node)
    counters["inductor"]["normalization_pass"] += 1

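# Illustrative sketch (assumed 2-D inputs): torch.cat([a, b], dim=-1) and
# torch.cat([a, b], axis=1) are both rewritten to the canonical form
# torch.cat([a, b], dim=1), so downstream passes only need to match one form.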

@register_graph_pattern(
    CallFunctionVarArgs(torch.stack, users=MULTIPLE),
    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
)
def normalize_stack_default(match: Match, *args, **kwargs):
    node = match.nodes[0]
    graph = match.graph
    tensors = get_arg_value(node, 0, "tensors")
    dim = get_arg_value(node, 1, "dim") or 0
    if tensors is None or dim is None:
        log.debug("couldn't find stack args")
        return
    assert isinstance(tensors, (list, tuple))

    # Due to a bug in PyTorch, some nodes may be missing the example_value metadata
    for tensor in itertools.chain([node], tensors):
        if not is_node_meta_valid(tensor):
            log.debug("example value absent for node: %s", tensor)
            return

    ndim = node.meta["example_value"].dim()
    if dim < 0:  # Normalize dim
        dim += ndim

    with graph.inserting_after(node):
        new_node = graph.call_function(
            node.target,  # type: ignore[arg-type]
            args=(tensors,),
            kwargs={"dim": dim},
        )
    node.replace_all_uses_with(new_node)
    new_node.meta.update(node.meta)
    graph.erase_node(node)
    counters["inductor"]["normalization_pass"] += 1


def find_next_users(split_node: torch.fx.Node) -> List[torch.fx.Node]:
    next_users = []
    for getitem_node in split_node.users.keys():
        for getitem_user in getitem_node.users.keys():
            if getitem_user not in next_users:
                next_users.append(getitem_user)
    return next_users

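# Sketch of the traversal: for a chain split -> getitem_0 / getitem_1 -> cat,
# find_next_users returns the deduplicated users of the getitems (here the
# single cat node), i.e. the consumers one hop past the split's getitems.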

@register_graph_pattern(
    CallMethodVarArgs("squeeze", users=MULTIPLE),
    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
)
def normalize_squeeze_default(match: Match, *args, **kwargs):
    squeeze_node = match.nodes[0]
    squeeze_input = get_arg_value(squeeze_node, 0)

    if "dim" in squeeze_node.kwargs:
        assert len(squeeze_node.args) == 1
        dim = squeeze_node.kwargs["dim"]
    elif len(squeeze_node.args) == 1:
        # squeeze(Tensor)
        dim = None
    elif len(squeeze_node.args) == 2:
        # squeeze(Tensor self, int dim)
        # squeeze(Tensor self, int[] dim)
        dim = squeeze_node.args[1]
    else:
        # squeeze(Tensor self, int[] dim) (called with varargs)
        dim = squeeze_node.args[1:]

    if isinstance(dim, Sequence) and len(dim) == 1:
        dim = dim[0]

    with match.graph.inserting_after(squeeze_node):
        if dim is None:
            new_squeeze_node = match.graph.call_function(
                torch.squeeze, args=(squeeze_input,)
            )
        else:
            new_squeeze_node = match.graph.call_function(
                torch.squeeze, args=(squeeze_input,), kwargs={"dim": dim}
            )
    squeeze_node.replace_all_uses_with(new_squeeze_node)
    new_squeeze_node.meta.update(squeeze_node.meta)
    match.graph.erase_node(squeeze_node)


@register_graph_pattern(
    CallMethodVarArgs("reshape", users=MULTIPLE),
    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
)
def normalize_reshape_default(match: Match, *args, **kwargs):
    reshape_node = match.nodes[0]
    if not is_node_meta_valid(reshape_node):
        log.debug("example value absent for node: %s", reshape_node)
        return
    reshape_input = get_arg_value(reshape_node, 0)

    with match.graph.inserting_after(reshape_node):
        new_reshape_node = match.graph.call_function(
            torch.reshape,
            args=(reshape_input, tuple(reshape_node.meta["example_value"].shape)),
        )
    reshape_node.replace_all_uses_with(new_reshape_node)
    new_reshape_node.meta.update(reshape_node.meta)
    match.graph.erase_node(reshape_node)


class TorchSplit(CallFunction):
    """
    Matches a call to torch.split if it is in a normalized form. Ensures that all users of
    the split are unique getitems.
    """

    def __init__(self, arg, sizes, func=torch.split) -> None:
        # using KeywordArg("dim") for `dim` checks that they all match
        super().__init__(func, arg, sizes, _users=MULTIPLE, dim=KeywordArg("dim"))

    def _match(self, node: torch.fx.Node, ctx: MatchContext):
        m = super()._match(node, ctx)
        if not m:
            return m
        split_sections = node.args[1]
        if not isinstance(split_sections, (list, tuple)):
            return FailedMatch("split not normalized")
        # check users are all unique getitems
        seen_idxs = set()
        for user in node.users:
            if not CallFunction(operator.getitem, Arg(), Arg()).match(user):
                # This should ideally never happen. A split user should always be a getitem
                return FailedMatch(f"user of split not a getitem: {user}")
            if not isinstance(user.args[1], int):
                return FailedMatch("only integer getitems are handled")
            if user.args[1] in seen_idxs:
                return FailedMatch(f"duplicate getitem {user.args[1]}")
            if user.args[-1] < 0:  # type: ignore[operator]
                # This ideally shouldn't happen, as dynamo normalizes indices to be positive
                return FailedMatch("negative index")
            seen_idxs.add(user.args[1])
        return m


@register_graph_pattern(
    TorchSplit(
        CallFunction(
            operator.getitem,
            TorchSplit(
                KeywordArg("first_split_input"),
                KeywordArg("first_split_sections"),
            ),
            Ignored(),
        ),
        KeywordArg("next_split_sections"),
    ),
    pass_dict=construct_pattern_matcher_pass("merge_splits_pass"),
)
def merge_splits(
    match: Match,
    first_split_input: torch.fx.Node,
    first_split_sections: List[int],
    next_split_sections: List[int],
    # Note: dim is implicitly passed by TorchSplit, as it internally uses a pattern with dim
    dim: int,
):
    node = match.output_node()
    # It is possible that the split has no users, so we check for this
    # corner case and skip the pattern.
    if len(node.users.keys()) == 0:
        return
    graph = match.graph
    first_split = node.args[0].args[0]  # type: ignore[union-attr]
    next_split_index = node.args[0].args[1]  # type: ignore[union-attr]

    new_split_sections = list(first_split_sections)
    new_split_sections[next_split_index : next_split_index + 1] = next_split_sections  # type: ignore[operator, misc]

    first_split_dim = _get_dim(first_split)

    to_remove = []

    with graph.inserting_before(first_split):  # type: ignore[arg-type]
        # Add the new split node
        new_split = graph.call_function(
            torch.split,
            args=(first_split_input, new_split_sections),
            kwargs={"dim": first_split_dim},
        )
        if is_node_meta_valid(first_split_input):
            new_split.meta["example_value"] = torch.split(
                first_split_input.meta["example_value"],
                new_split_sections,
                dim=first_split_dim,
            )
        first_split_num_to_user = {
            user.args[1]: user for user in first_split.users.keys()  # type: ignore[union-attr]
        }

        new_split_num = 0
        for split_num in range(len(first_split_sections)):
            if split_num not in first_split_num_to_user:
                new_split_num += 1
                continue
            old_getitem = first_split_num_to_user[split_num]
            if split_num != next_split_index:
                old_getitem.update_arg(0, new_split)
                old_getitem.update_arg(1, new_split_num)
                new_split_num += 1
            else:
                next_split_num_to_user = {
                    user.args[1]: user for user in node.users.keys()
                }
                # Not all getitems from the split node are necessarily used, so
                # we use the number of users to find the getitems to merge.
                for next_split_num in range(len(node.users.keys())):
                    with graph.inserting_after(new_split):
                        new_getitem = graph.call_function(
                            operator.getitem, args=(new_split, new_split_num)
                        )
                    new_split_num += 1
                    next_getitem = next_split_num_to_user[next_split_num]
                    new_getitem.meta.update(next_getitem.meta)
                    next_getitem.replace_all_uses_with(new_getitem)
                    to_remove.append(next_getitem)
                to_remove.append(node)
                to_remove.append(old_getitem)

        to_remove.append(first_split)  # type: ignore[arg-type]
    for node in to_remove:
        graph.erase_node(node)

    counters["inductor"]["merge_splits_pass"] += 1

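# Hedged example of the merge (assumed shapes): for
#   s1 = torch.split(x, [4, 4], dim=1)
#   s2 = torch.split(s1[0], [2, 2], dim=1)
# the inner split is folded into the outer one,
#   s = torch.split(x, [2, 2, 4], dim=1)
# with the getitems renumbered so that s[0], s[1] replace s2[0], s2[1] and
# s[2] replaces s1[1].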

class SplitCatSimplifier:
    """
    Helper class to simplify the split-cat pattern. In simple cases, both the split and the cat
    node can be removed in a "split->cat" pattern. However, there are various cases where they
    can't, and we then need to simplify the split or add transforms before the cat.
    Some such cases are:
        1. The final node has additional args (not coming from the initial split)
        2. Shuffling of args between split/cat
        3. Some final nodes are non-(cat/stack)
        4. Split-dim != cat-dim (but equal split)

    Note that any combination of the above cases can happen.

    To deal with 1, 2, & 3 - we iterate over all users of the split and figure out common
    "ranges" that can be merged. Then, we simplify the split accordingly. In the best case,
    the split can be entirely removed.

    To deal with 4, we add some transformations (unflatten + movedim) (See `get_transform_params`).

    Finally, depending on whether the final node is a cat or a stack, an unsqueeze/flatten
    needs to be added.

    """

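    # In the simplest case (split dim == cat dim, all getitems consumed in
    # order by a single cat), the whole pattern collapses; a hedged sketch:
    #   parts = torch.split(x, [2, 2, 2], dim=1)
    #   y = torch.cat([parts[0], parts[1], parts[2]], dim=1)
    # is equivalent to `y = x`, so both the split and the cat are removed.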
    def simplify(
        self,
        graph: torch.fx.Graph,
        split_node: torch.fx.Node,
        split_sections: List[int],
    ):
        # Find the next users (i.e. users after the getitem)
        next_users = find_next_users(split_node)
        # Gather inputs of the next users. When inputs come from `split_node`, they are instead represented by
        # a tuple indicating the split ranges. See `get_user_input_list` for more details
        user_inputs_list = self.get_user_input_list(split_node, next_users)
        # Simplify the split_sections based on user_inputs_list. In simpler cases, len(simplified_split_ranges) == 1 and
        # we can simply replace the split node. Otherwise, we simplify it.
        simplified_split_ranges = self.get_simplified_split_ranges(
            split_sections, next_users, user_inputs_list
        )
        if not simplified_split_ranges:  # Simplification not possible
            return
        transform_params_list = self.get_transform_params(
            split_node, next_users, user_inputs_list
        )
        if not transform_params_list:
            return

        # Start the actual replacement
        user_inputs_list_new = self.replace_split(
            graph, split_node, split_sections, user_inputs_list, simplified_split_ranges
        )
        self.replace_cat(
            graph, split_node, next_users, user_inputs_list_new, transform_params_list  # type: ignore[arg-type]
        )
        self.erase_old_nodes(graph, split_node, next_users)  # type: ignore[arg-type]
        counters["inductor"]["unbind_stack_pass"] += 1

    def get_user_input_list(
        self, split_node: torch.fx.Node, next_users: List[torch.fx.Node]
    ) -> List[List[Union[torch.fx.Node, _Range]]]:
        """
        Returns the list of inputs to the following user nodes, in order. The outer list
        represents the user nodes. The inner list represents the inputs to one particular
        node. This list can either contain
          - a tuple representing the range of getitems that should go into the cat (closed interval)
          - a torch.fx.Node representing an "other" input (which is not coming from our split)
        """
        user_inputs_list: List[List[Union[torch.fx.Node, _Range]]] = []
        for user in next_users:
            if user.target in {torch.cat, torch.stack}:
                user_inputs_list.append(self.get_merged_user_inputs(split_node, user))
            else:
                user_inputs_list.append(self.get_non_cat_node_input(split_node, user))  # type: ignore[arg-type]
        return user_inputs_list

    def get_merged_user_inputs(
        self, split_node: torch.fx.Node, cat_node: torch.fx.Node
    ) -> List[Union[torch.fx.Node, _Range]]:
        user_inputs = get_arg_value(cat_node, 0, "tensors")
        simplified_user_inputs = []
        split_users = set(split_node.users.keys())
        for user_input in user_inputs:
            if user_input not in split_users:
                simplified_user_inputs.append(user_input)
            else:
                # Add the index of the "getitem" that the cat depends on
                simplified_user_inputs.append(user_input.args[1])
        return self.merge_consecutive_inputs(simplified_user_inputs)

    def get_non_cat_node_input(
        self, split_node: torch.fx.Node, node: torch.fx.Node
    ) -> List[_Range]:
        """
        Get the input for a non-cat node in the same format as `get_merged_user_inputs`
        """
        node_input = []
        split_users = set(split_node.users.keys())
        for node_arg in node.all_input_nodes:
            if node_arg in split_users:
                getitem_num = get_arg_value(node_arg, 1)
                node_input.append((getitem_num, getitem_num))
        return node_input

    def merge_consecutive_inputs(
        self, inputs: List[Union[torch.fx.Node, int]]
    ) -> List[Union[torch.fx.Node, _Range]]:
        """
        Merge consecutive inputs going into a user node.

        For example:
        [arg0, 0, 1, 2, arg1] -> [arg0, (0, 2), arg1]
        """
        merged_ranges = []
        cur_range = None
        for input_ in inputs:
            if isinstance(input_, int):
                if not cur_range:
                    cur_range = [input_, input_]
                elif input_ == cur_range[1] + 1:
                    cur_range[1] += 1
                else:
                    merged_ranges.append(tuple(cur_range))
                    cur_range = [input_, input_]
            else:
                if cur_range:
                    merged_ranges.append(tuple(cur_range))
                    cur_range = None
                merged_ranges.append(input_)  # type: ignore[arg-type]
        if cur_range:
            merged_ranges.append(tuple(cur_range))
        return merged_ranges  # type: ignore[return-value]

    def get_simplified_split_ranges(
        self,
        split_sections,
        next_users,
        user_inputs_list: List[List[Union[torch.fx.Node, _Range]]],
    ) -> Optional[List[_Range]]:
        ranges = set()
        for user_node, user_inputs in zip(next_users, user_inputs_list):
            ranges |= {
                user_input
                for user_input in user_inputs
                if isinstance(user_input, tuple)
            }
        cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist()
        split_ranges = sorted(
            [(cumulative_sizes[r[0]], cumulative_sizes[r[1] + 1]) for r in ranges]
        )

        if not self.has_non_overlapping_ranges(
            split_ranges,
        ):  # This need not be a strict condition.
            # However, we keep it for now for simplicity.
            return None
        split_ranges = self.fill_gaps(split_ranges, 0, cumulative_sizes[-1])
        if len(split_sections) == len(split_ranges):  # Simplification not possible
            return None
        counters["inductor"]["scmerge_split_sections_removed"] = len(
            split_sections
        ) - len(split_ranges)
        return split_ranges

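    # Hedged worked example: with split_sections = [2, 3, 5], cumulative_sizes
    # is [0, 2, 5, 10]; a user consuming getitems (1, 2) maps to the element
    # range (2, 10), and after fill_gaps the simplified split becomes
    # torch.split(x, [2, 8], dim=...).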
    def has_non_overlapping_ranges(self, ranges: List[_Range]) -> bool:
        for range_, next_range in zip(ranges, ranges[1:]):
            if range_[1] > next_range[0]:
                return False
        return True

    def fill_gaps(self, ranges: List[_Range], min_: int, max_: int) -> List[_Range]:
        cur = min_
        filled_ranges = []
        for a, b in ranges:
            if cur < a:
                filled_ranges.append((cur, a))
            filled_ranges.append((a, b))
            cur = b
        if filled_ranges[-1][1] < max_:
            filled_ranges.append((filled_ranges[-1][1], max_))
        return filled_ranges

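    # e.g. fill_gaps([(2, 5)], 0, 10) == [(0, 2), (2, 5), (5, 10)]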
    def get_transform_params(
        self,
        split_node: torch.fx.Node,
        next_users: List[torch.fx.Node],
        user_inputs_list: List[List[Union[torch.fx.Node, _Range]]],
    ) -> Optional[List[List[_TransformParam]]]:
        """
        Figure out what transforms are needed for each input to each cat node.

        We replace a split node with an unflatten followed by a movedim.
        """
        split_dim = _get_dim(split_node)
        split_sections = split_node.args[1]
        transform_params_list: List[List[_TransformParam]] = []

        for user_node, user_inputs in zip(next_users, user_inputs_list):
            if user_node.target not in {torch.cat, torch.stack}:
                transform_params_list.append([])
                continue

            cat_dim = get_arg_value(user_node, 1, "dim")
            transform_params: List[_TransformParam] = []
            for user_input in user_inputs:
                if split_dim == cat_dim and user_node.target == torch.cat:
                    # No transform needed
                    transform_params.append((None, None, None, None))
                elif isinstance(user_input, tuple):  # Split being simplified
                    # Verify equal split
                    subset_split_sections = split_sections[  # type: ignore[index]
                        user_input[0] : user_input[1] + 1
                    ]
                    # All sections should be equal
                    if len(set(subset_split_sections)) != 1:
                        return None

                    num_splits = len(subset_split_sections)
                    unflatten_params = (split_dim, (num_splits, -1))
                    movedim_params = (
                        (split_dim, cat_dim) if split_dim != cat_dim else None
                    )
                    transform_params.append(
                        (unflatten_params, movedim_params, None, None)
                    )
                elif (
                    user_node.target == torch.stack or split_dim != cat_dim
                ):  # We need to unsqueeze inputs not coming through the split
                    transform_params.append((None, None, (cat_dim,), None))
                else:  # Non-split inputs
                    transform_params.append((None, None, None, None))
            transform_params_list.append(transform_params)
        return transform_params_list

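    # Hedged sketch of the split-dim != cat-dim rewrite for an equal split:
    # with x of shape (8, 6), torch.cat(torch.split(x, 2, dim=1), dim=0) is
    # equivalent to
    #   x.unflatten(1, (3, -1)).movedim(1, 0).flatten(0, 1)
    # i.e. unflatten on the split dim, movedim to the cat dim, and a final
    # flatten (inserted by `replace_cat` after the new cat).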
    def replace_split(
        self,
        graph: torch.fx.Graph,
        split_node: torch.fx.Node,
        split_sections: List[int],
        user_inputs_list: List[List[Union[torch.fx.Node, _Range]]],
        split_ranges: List[_Range],
    ) -> List[List[torch.fx.Node]]:
        """
        Replace the split node. This either removes the split node if len(split_ranges) == 1,
        or simplifies it into a split with fewer sections if len(split_ranges) > 1.

        Returns the new `user_inputs_list`, with tuples replaced with new getitems from the
        new split node.
        """
        split_input = split_node.args[0]
        split_dim = _get_dim(split_node)
        if len(split_ranges) == 1:  # We can completely eliminate the split node
            split_items = [split_input]
        else:
            with graph.inserting_after(split_node):
                new_split = graph.call_function(
                    torch.split,
                    args=(
                        split_input,
                        [r[1] - r[0] for r in split_ranges],
                    ),
                    kwargs={"dim": split_dim},
                )
                if is_node_meta_valid(split_input):  # type: ignore[arg-type, union-attr]
                    new_split.meta["example_value"] = torch.split(
                        split_input.meta["example_value"], [r[1] - r[0] for r in split_ranges], dim=split_dim  # type: ignore[union-attr]
                    )
                counters["inductor"]["scmerge_split_added"] += 1
            split_items = []
            with graph.inserting_after(new_split):
                for i in range(len(split_ranges)):
                    getitem = graph.call_function(operator.getitem, args=(new_split, i))
                    if is_node_meta_valid(new_split):
                        getitem.meta["example_value"] = new_split.meta["example_value"][
                            i
                        ]
                        split_items.append(getitem)
        # Now assign the right getitem to the right input
        cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist()
        new_user_inputs_list = []
        for user_inputs in user_inputs_list:
            new_user_inputs = []
            for user_input in user_inputs:
                if isinstance(user_input, tuple):
                    # Find the correct new getitem (present in split_items)
                    new_user_inputs.append(
                        split_items[
                            split_ranges.index(
                                (
                                    cumulative_sizes[user_input[0]],
                                    cumulative_sizes[user_input[1] + 1],
                                )
                            )
                        ]
                    )
                else:
                    new_user_inputs.append(user_input)
            new_user_inputs_list.append(new_user_inputs)
        return new_user_inputs_list  # type: ignore[return-value]

    def replace_cat(
        self,
        graph: torch.fx.GraphModule,
        split_node: torch.fx.Node,
        next_users: List[torch.fx.Node],
        user_inputs_list_new,
        transform_params_list: List[List[_TransformParam]],
    ):
        split_dim = _get_dim(split_node)
        split_users = split_node.users.keys()
        new_cats = []
        for user_node, user_inputs_new, transform_params in zip(
            next_users, user_inputs_list_new, transform_params_list
        ):
            if user_node.target not in {torch.cat, torch.stack}:
                # Change the args and kwargs of non-cat/stack nodes. Replace old getitems
                # (belonging to the original split node) with the new getitems
                next_cat_input = 0
                for input_node in user_node.all_input_nodes:
                    if input_node in split_users:
                        user_node.replace_input_with(
                            input_node, user_inputs_new[next_cat_input]
                        )
                        next_cat_input += 1
                continue

            # Handle cat/stack user nodes
            cat_dim = get_arg_value(user_node, 1, "dim")
            user_inputs_new_transformed, user_inputs_new_transformed_meta = [], []
            # For the `unsqueeze` transform, we combine consecutive inputs with the same
            # unsqueeze params and stack them
            to_stack, to_stack_meta = [], []
            stack_dim = None
            with graph.inserting_before(user_node):
                for user_input_new, transform_param in zip(
                    user_inputs_new, transform_params
                ):
                    if not is_node_meta_valid(user_input_new):
                        log.debug("example value absent for node: %s", user_input_new)
                        return
                    # Apply transforms
                    (
                        unflatten_params,
                        movedim_params,
                        unsqueeze_params,
                        flatten_params,
                    ) = transform_param
                    if unsqueeze_params and (
                        stack_dim is None or stack_dim == unsqueeze_params[0]
                    ):
                        to_stack.append(user_input_new)
                        to_stack_meta.append(user_input_new.meta["example_value"])
                        stack_dim = unsqueeze_params[0]
                        continue
                    elif to_stack:
                        stacked_input = graph.call_function(
                            torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
                        )
                        stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim)  # type: ignore[arg-type, union-attr]
                        to_stack, to_stack_meta = [], []
                        stack_dim = None
                        user_inputs_new_transformed.append(stacked_input)
                        user_inputs_new_transformed_meta.append(
                            stacked_input.meta["example_value"]
                        )
                        if unsqueeze_params:
                            to_stack.append(user_input_new)
                            stack_dim = unsqueeze_params[0]
                            to_stack_meta.append(user_input_new.meta["example_value"])
                            continue

                    if unflatten_params:
                        user_input_new_meta = user_input_new.meta["example_value"]
                        user_input_new = graph.call_function(
                            torch.unflatten, args=(user_input_new, *unflatten_params)
                        )
                        user_input_new.meta["example_value"] = torch.unflatten(user_input_new_meta, *unflatten_params)  # type: ignore[arg-type, possibly-undefined, union-attr]
                    if movedim_params:
                        user_input_new_meta = user_input_new.meta["example_value"]
                        user_input_new = graph.call_function(
                            torch.movedim, args=(user_input_new, *movedim_params)
                        )
                        user_input_new.meta["example_value"] = torch.movedim(user_input_new_meta, *movedim_params)  # type: ignore[arg-type, possibly-undefined, union-attr]
                    if flatten_params:
                        user_input_new_meta = user_input_new.meta["example_value"]
                        user_input_new = graph.call_function(
                            torch.flatten, args=(user_input_new, *flatten_params)
                        )
                        user_input_new.meta["example_value"] = torch.flatten(user_input_new_meta, *flatten_params)  # type: ignore[arg-type, possibly-undefined, union-attr]
                    user_inputs_new_transformed.append(user_input_new)
                    user_inputs_new_transformed_meta.append(
                        user_input_new.meta["example_value"]
                    )
                if to_stack:
                    stacked_input = graph.call_function(
                        torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
                    )
                    stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim)  # type: ignore[arg-type, union-attr]
                    user_inputs_new_transformed.append(stacked_input)
                    user_inputs_new_transformed_meta.append(
                        stacked_input.meta["example_value"]
                    )

            with graph.inserting_after(user_node):
                if len(user_inputs_new_transformed) > 1:
                    new_cat_node = graph.call_function(
                        torch.cat,
                        args=(user_inputs_new_transformed,),
                        kwargs={"dim": cat_dim},
                    )
                    new_cat_node.meta["example_value"] = torch.cat(
                        user_inputs_new_transformed_meta, dim=cat_dim
                    )
                    counters["inductor"]["scmerge_cat_added"] += 1
                else:
                    new_cat_node = user_inputs_new_transformed[-1]
                    new_cat_node.meta[
                        "example_value"
                    ] = user_inputs_new_transformed_meta[-1]

            if (
                user_node.target == torch.cat
                and split_dim != cat_dim
                and split_node.target == torch.split
            ):
                with graph.inserting_after(new_cat_node):
                    new_cat_node_meta = new_cat_node.meta["example_value"]
                    new_cat_node = graph.call_function(
                        torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1)
                    )
                    new_cat_node.meta["example_value"] = torch.flatten(new_cat_node_meta, cat_dim, cat_dim + 1)  # type: ignore[possibly-undefined, union-attr]
            user_node.replace_all_uses_with(new_cat_node)
            new_cats.append(new_cat_node)

    def erase_old_nodes(
        self,
        graph: torch.fx.GraphModule,
        split_node: torch.fx.Node,
        next_users: List[torch.fx.Node],
    ):
        to_remove = [split_node]
        counters["inductor"]["scmerge_split_removed"] += 1
        to_remove.extend(split_node.users.keys())
        for next_user in next_users:
            if next_user.target not in {torch.cat, torch.stack}:
                continue
            counters["inductor"]["scmerge_cat_removed"] += 1
            to_remove.append(next_user)
        for node in reversed(to_remove):
            if len(node.users.keys()) == 0:
                graph.erase_node(node)


class UnbindCatRemover(SplitCatSimplifier):
    """
    Helper class to merge Unbind -> Cat/Stack. Many of the cases are similar to SplitCatSimplifier.

    An unbind can't be simplified the way a split can, so we can only remove the unbind node
    entirely. Other than this, cases like multiple users, additional args, and dim mismatch
    are similar to `SplitCatSimplifier`, hence we extend that class.
    """

    def remove_unbind(
        self,
        graph: torch.fx.Graph,
        unbind_node: torch.fx.Node,
    ):
        if not is_node_meta_valid(unbind_node):
            return
        # We need to check that the getitem indices from the unbind are consecutive and that
        # all of them go to the same cat node before we remove the unbind; otherwise we would
        # hit an error when only part of the unbind outputs is merged.
        getitem_indices = []
        for getitem_node in unbind_node.users.keys():
            getitem_indices.append(getitem_node.args[1])
        if not is_sorted_and_consecutive(getitem_indices) or len(  # type: ignore[arg-type]
            getitem_indices
        ) != len(
            unbind_node.meta["example_value"]
        ):
            return
        num_unbind = len(getitem_indices)
        split_sections = [1 for _ in range(num_unbind)]  # type: ignore[operator, arg-type]

        super().simplify(graph, unbind_node, split_sections)

    def get_simplified_split_ranges(
        self,
        split_sections: List[int],
        next_users: List[torch.fx.Node],
        user_inputs_list: List[List[Union[torch.fx.Node, _Range]]],
    ) -> Optional[List[_Range]]:
        simplified_split_ranges = super().get_simplified_split_ranges(
            split_sections, next_users, user_inputs_list
        )
        if not simplified_split_ranges or len(simplified_split_ranges) != 1:
            return None
        return simplified_split_ranges

    def get_transform_params(
        self,
        split_node: torch.fx.Node,
        next_users: List[torch.fx.Node],
        user_inputs_list: List[List[Union[torch.fx.Node, _Range]]],
    ) -> Optional[List[List[_TransformParam]]]:
        """
        Figure out what transforms are needed for each input to each cat node.

        Here are the rough transforms we apply:

        x -> unbind -> stack => x -> movedim

        x -> unbind -> cat => x -> movedim -> flatten

        When cat/stack nodes have additional args:

             addn ---|              addn -> unsqueeze ---|
        x -> unbind -> stack  =>           x -> movedim  -> cat

             addn ---|                            addn ---|
        x -> unbind -> cat  =>   x -> movedim -> flatten  -> cat

        (Note that the application of these depends on the dims as well)
        """
        split_dim = _get_dim(split_node)
        transform_params_list: List[List[_TransformParam]] = []
        for user_node, user_inputs in zip(next_users, user_inputs_list):
            cat_dim = get_arg_value(user_node, 1, "dim") or 0
            transform_params: List[_TransformParam] = []
            for user_input in user_inputs:
                if isinstance(user_input, tuple):
                    # User input is coming from unbind
                    movedim_params = (
                        (split_dim, cat_dim) if split_dim != cat_dim else None
                    )
                    flatten_params = None
                    if user_node.target == torch.cat:
                        flatten_params = (cat_dim, cat_dim + 1)
                    transform_params.append(
                        (None, movedim_params, None, flatten_params)
                    )
                elif (
                    user_node.target == torch.stack
                ):  # We need to unsqueeze inputs not coming through the unbind into the cat
                    transform_params.append((None, None, (cat_dim,), None))
                else:  # Non-unbind inputs
                    transform_params.append((None, None, None, None))
            transform_params_list.append(transform_params)
        return transform_params_list

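# Hedged sketch of the unbind rewrites above (assumed shapes): with x of shape
# (3, 4), torch.stack(torch.unbind(x, dim=1), dim=0) is equivalent to
# torch.movedim(x, 1, 0), and torch.cat(torch.unbind(x, dim=1), dim=0) to
# torch.movedim(x, 1, 0).flatten(0, 1).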

class GetItem(CallFunction):
    def __init__(self, arg, index, _users=1) -> None:
        super().__init__(operator.getitem, arg, index, _users=_users)

    def find_anchor_nodes(self, ctx: MatchContext, searched: Set[torch.fx.Node]):
        # We generally match GetItem with arg being an Arg(), so we never return the anchor
        # nodes, since the stored node in ctx.pattern_to_node is returned instead. Here we
        # override find_anchor_nodes to not use ctx.pattern_to_node
        for pattern in self.flat_args_kwargs[0]:
            if isinstance(pattern, PatternExpr):
                for other_node in pattern.find_anchor_nodes(ctx, searched):
                    if not isinstance(other_node, torch.fx.Node):
                        continue
                    for node in other_node.users:
                        if node not in searched:
                            if self._match_fns(node):
                                yield node
                                searched.add(node)


@register_graph_pattern(
    RepeatedExpr(
        CallFunction(
            torch.squeeze,
            GetItem(
                TorchSplit(
                    KeywordArg("split_input"),
                    KeywordArg("split_sizes"),
                ),
                Ignored(),
            ),
            KeywordArg("dim"),
            _users=MULTIPLE,
        ),
    ),
    pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
)
@register_graph_pattern(
    RepeatedExpr(
        CallFunction(
            torch.squeeze,
            GetItem(
                TorchSplit(
                    KeywordArg("split_input"),
                    KeywordArg("split_sizes"),
                ),
                Ignored(),
            ),
            dim=KeywordArg("dim"),
            _users=MULTIPLE,
        )
    ),
    pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
)
def merge_split_squeeze(
    match: Match, split_input: torch.fx.Node, split_sizes: List[int], dim: int
):
    graph = match.graph
    split = next(node for node in match.nodes if node.target == torch.split)
    if not all(s == 1 for s in split_sizes):
        return
    if isinstance(dim, Sequence):
        return
    next_users = find_next_users(split)
    if not all(node.target == torch.squeeze for node in next_users):
        return
    with graph.inserting_before(match.output_node()):
        unbind = graph.call_function(
            torch.unbind, args=(split_input,), kwargs={"dim": dim}
        )
        if is_node_meta_valid(split_input):
            unbind.meta["example_value"] = torch.unbind(
                split_input.meta["example_value"], dim=dim
            )
        for item_index, getitem_node in sorted(
            [
                (getitem_node.args[1], getitem_node)
                for getitem_node in split.users.keys()
            ]
        ):
            squeeze = next(iter(getitem_node.users.keys()))
            new_get_item = graph.call_function(
                operator.getitem, args=(unbind, item_index)
            )
            squeeze.replace_all_uses_with(new_get_item)
            new_get_item.meta.update(squeeze.meta)
            graph.erase_node(squeeze)
            graph.erase_node(getitem_node)
    graph.erase_node(split)
    counters["inductor"]["split_cat_pass"] += 1

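# Hedged example (assumed shapes): splitting into size-1 sections and
# squeezing each chunk,
#   parts = torch.split(x, [1, 1, 1], dim=1)  # x of shape (N, 3)
#   ys = [torch.squeeze(p, dim=1) for p in parts]
# is equivalent to `ys = torch.unbind(x, dim=1)`, which this pattern emits.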

getitem_unbind = ListOf(
    GetItem(
        CallFunction(
            torch.unbind,
            KeywordArg("unbind_input"),
            dim=KeywordArg("dim"),
            _users=MULTIPLE,
        ),
        Ignored(),
        _users=MULTIPLE,
    ),
    partial=True,
)


@register_graph_pattern(
    CallFunction([torch.stack, torch.cat], getitem_unbind, Ignored(), _users=MULTIPLE),
    pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"),
)
@register_graph_pattern(
    CallFunction(
        [torch.stack, torch.cat], getitem_unbind, dim=Ignored(), _users=MULTIPLE
    ),
    pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"),
)
@register_graph_pattern(
    CallFunction(
        [torch.stack, torch.cat], tensors=getitem_unbind, dim=Ignored(), _users=MULTIPLE
    ),
    pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"),
)
def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int):
    unbind_node = next(node for node in match.nodes if node.target == torch.unbind)
    UnbindCatRemover().remove_unbind(match.graph, unbind_node)


getitem_split = ListOf(
    CallFunction(
        operator.getitem,
        TorchSplit(
            Ignored(),
            KeywordArg("split_sections"),
        ),
        Ignored(),
        _users=MULTIPLE,
    ),
    partial=True,
)


reshape_getitem_split = ListOf(
    CallFunction(
        torch.reshape,
        CallFunction(
            operator.getitem,
            TorchSplit(
                Ignored(),
                KeywordArg("split_sections"),
            ),
            Ignored(),
            _users=MULTIPLE,
        ),
        Arg(),
        _users=MULTIPLE,
    ),
    partial=True,
)


@register_graph_pattern(
    CallFunction(
        [torch.stack, torch.cat],
        tensors=getitem_split,
        dim=Ignored(),
        _users=MULTIPLE,
    ),
    pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
)
@register_graph_pattern(
    CallFunction(
        [torch.stack, torch.cat],
        getitem_split,
        dim=Ignored(),
        _users=MULTIPLE,
    ),
    pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
)
@register_graph_pattern(
    CallFunction(
        [torch.stack, torch.cat],
        getitem_split,
        Ignored(),
        _users=MULTIPLE,
    ),
    pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
)
def simplify_split_cat(match: Match, split_sections: List[int], dim: int):
    if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
        return
    split_node = next(node for node in match.nodes if node.target == torch.split)
    SplitCatSimplifier().simplify(match.graph, split_node, split_sections)


# noqa: W605
# ############The pattern to be optimized is#########

#                 split_node(dim=1)
#       /     \         ...       /         \
# getitem    getitem          getitem     getitem   -> user=1
#    \       /                     \       /
#      cat (user=mul, dim=1)           cat(user=mul, dim=1)
#       |            \                   |          \

# ################After transformation#############

#                 split_node(dim=1)
#       /              ...                  \
#     getitem                             getitem
#     |    \                              |     \


def has_same_parent_node(node: torch.fx.Node):
    # all getitem input nodes of the node should come from the same parent
    prev_node = None
    for getitem in node.args[0]:  # type: ignore[union-attr]
        if getitem.target != operator.getitem:  # type: ignore[union-attr]
            return False
        if prev_node is None:
            prev_node = getitem.args[0]  # type: ignore[union-attr]
        else:
            if getitem.args[0] != prev_node:
                return False
    return True


1358def remove_zeros(split_sections: List[int]):
1359    """
1360    Remove zeros from the list and get the index mapping dict from getitem
1361    in split node to getitem in new split node
1362    """
1363    new_split_sections, index_mapping = [], {}
1364    idx = 0
1365    for i, section in enumerate(split_sections):
1366        if section > 0:
1367            new_split_sections.append(section)
1368            index_mapping[i] = idx
1369            idx += 1
1370
1371    return new_split_sections, index_mapping
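
# Worked example (assumed sections): remove_zeros([2, 0, 3, 0]) returns
# ([2, 3], {0: 0, 2: 1}); surviving sections keep their order, and the mapping
# sends old getitem indices to their positions in the new split.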
1372
1373
1374def is_sorted_and_consecutive(arr: List[int]) -> bool:
1375    # check if the array is sorted
1376    if arr == sorted(arr):
1377        # check if the differences between adjacent elements are all 1
1378        return all(x[1] - x[0] == 1 for x in zip(arr, arr[1:]))
1379    else:
1380        return False
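
# For instance, is_sorted_and_consecutive([3, 4, 5]) is True, while [3, 5, 6]
# (gap) and [4, 3, 5] (unsorted) are both rejected.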
1381
1382
1383def calculate_fused_tensor_size(split_node: torch.fx.Node, indices: List[int]) -> int:
1384    """
1385    Calculate the fused tensor size in the indices
1386    """
1387    fused_tensor_size = 0
1388    for i in range(len(split_node.args[1])):  # type: ignore[arg-type]
1389        if i in indices:
1390            fused_tensor_size += split_node.args[1][i]  # type: ignore[operator, assignment, index]
1391    return fused_tensor_size
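
# Sketch with assumed args: if split_node.args[1] == [2, 3, 4], then
# calculate_fused_tensor_size(split_node, [1, 2]) == 3 + 4 == 7, the size of
# the tensor obtained by concatenating getitems 1 and 2 along the split dim.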
1392
1393
1394@register_graph_pattern(
1395    CallFunction(
1396        torch.cat,
1397        getitem_split,
1398        dim=Ignored(),
1399        _users=MULTIPLE,
1400    ),
1401    pass_dict=construct_pattern_matcher_pass("merge_getitem_cat_pass"),
1402)
1403def merge_getitem_cat(match: Match, split_sections: List[int], dim: int):
1404    if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
1405        return
1406    graph = match.graph
1407    split_node = next(node for node in match.nodes if node.target == torch.split)
1408    split_input, split_size, split_dim = _get_split_args_default(split_node)
1409    # find the next users (i.e. users after the getitem); cat users whose
1410    # dim differs from the split dim are skipped inside the loop below
1411    next_users = find_next_users(split_node)
1412    # 'immutable_list' object does not support mutation. Create a new copy of it
1413    split_sections = list(split_sections)
1414    for cat_user in next_users:
1415        if cat_user.target == torch.cat:
1416            cat_dim = get_arg_value(cat_user, 1, "dim")
1417            # check that all getitems in the cat_user come from the same
1418            # split node, that every input of the cat is such a getitem,
1419            # and that each getitem has exactly one user
1420            if (
1421                split_dim != cat_dim
1422                or not has_same_parent_node(cat_user)
1423                or not all(len(arg.users) == 1 for arg in cat_user.args[0])  # type: ignore[union-attr]
1424            ):
1425                continue
1426            # collect the indices of the getitems to be concatenated/stacked
1427            indices = []
1428            for arg in cat_user.args[0]:  # type: ignore[union-attr]
1429                indices.append(arg.args[1])  # type: ignore[union-attr]
1430            # the getitems to be merged must be consecutive, otherwise the
1431            # resulting sliced tensor could be wrong
1432            if not is_sorted_and_consecutive(indices):
1433                continue
1434            # update the arg of cat user, only keep the first getitem
1435            cat_user.update_arg(0, cat_user.args[0][0])  # type: ignore[index]
1441            # update the split sections
1442            split_sections[indices[0]] = calculate_fused_tensor_size(
1443                split_node, indices
1444            )
1445            # pad the other merged indices with zeros to keep the list length
1446            for i in indices[1:]:
1447                split_sections[i] = 0
1448            # remove all unused indexes in the split_node
1449            new_split_sections, index_mapping = remove_zeros(split_sections)
1450            with graph.inserting_after(split_node):
1451                new_split_node = graph.call_function(
1452                    torch.split,
1453                    args=(split_input, split_sections),
1454                    kwargs={"dim": split_dim},
1455                )
1456                split_node.replace_all_uses_with(new_split_node)
1457                new_split_node.meta.update(split_node.meta)
1458                # remove all unused getitem nodes
1459                to_remove = [cat_user]
1460                # copy the user list first; the dict changes during iteration
1461                new_split_getitem_nodes = list(new_split_node.users.keys())
1462                for getitem_node in new_split_getitem_nodes:
1463                    if getitem_node.args[1] in indices[1:]:
1464                        to_remove.append(getitem_node)
1465                    # update meta data of getitem
1466                    elif getitem_node.args[1] == indices[0]:
1467                        cat_user.replace_all_uses_with(getitem_node)
1468                        getitem_node.meta.update(cat_user.meta)
1469                    else:
1470                        # update getitem index for new split node
1471                        getitem_node.update_arg(1, index_mapping[getitem_node.args[1]])
1472                graph.erase_node(split_node)
1473                for getitem_node in to_remove:
1474                    graph.erase_node(getitem_node)
1475                # update the split sections of new split node
1476                new_split_node.update_arg(1, new_split_sections)
1477                split_node = new_split_node
1478                split_sections = new_split_sections
1479
1480                counters["inductor"]["merge_getitem_cat_pass"] += 1
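
# End-to-end sketch of the rewrite above (hypothetical sizes):
#   a, b, c = torch.split(x, [2, 3, 4], dim=1); y = torch.cat([b, c], dim=1)
# fuses sections 1 and 2 into one section, yielding
#   a, y = torch.split(x, [2, 7], dim=1)
# with the cat node removed; y is numerically identical to the old cat output.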
1481
1482
1483# ############pattern to be optimized is#########
1484
1485#                 split_node(dim=1)  -> user=multiple
1486#       /     \         ...       /         \
1487# getitem    getitem          getitem     getitem   -> user=multiple
1488#    \       \                    /            \
1489#          other_op /cat(user=mul, dim=1)             other_op
1490#                      |
1491
1492# ################after transformation#############
1493
1494#                 split_node(dim=1)         -> user=multiple
1495#       /     \         ...       /         \
1496# getitem    getitem          getitem     getitem   -> user=multiple
1497#    \       \                    /           \
1498#                          other_op
1499
1500
1501@register_graph_pattern(
1502    CallFunction(
1503        torch.cat,
1504        getitem_split,
1505        dim=Ignored(),
1506        _users=MULTIPLE,
1507    ),
1508    pass_dict=construct_pattern_matcher_pass("mutate_cat_pass"),
1509)
1510def mutate_cat_node(match: Match, split_sections: List[int], dim: int):
1511    if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
1512        return
1513    graph = match.graph
1514    split_node = next(node for node in match.nodes if node.target == torch.split)
1515    split_input, split_size, split_dim = _get_split_args_default(split_node)
1516    # find the next users (i.e. users after the getitem); cat users whose
1517    # dim differs from the split dim are skipped inside the loop below
1518    next_users = find_next_users(split_node)
1519    for cat_user in next_users:
1520        if cat_user.target == torch.cat:
1521            cat_dim = get_arg_value(cat_user, 1, "dim") or 0
1522            # check that all getitems in the cat_user come from the same
1523            # split node, and that every input of the cat is such a getitem
1524            if split_dim != cat_dim or not has_same_parent_node(cat_user):
1525                continue
1526            # collect the indices of the getitems to be concatenated
1527            indices, idx_to_getitem = [], {}
1528            for getitem in cat_user.args[0]:  # type: ignore[union-attr]
1529                indices.append(getitem.args[1])  # type: ignore[union-attr]
1530                idx_to_getitem[getitem.args[1]] = getitem  # type: ignore[union-attr]
1531            # the getitems to be merged must be consecutive, otherwise the
1532            # resulting sliced tensor could be wrong
1533            if not is_sorted_and_consecutive(indices):
1534                continue
1535            # case 1: the cat uses all getitems from the split
1536            if len(split_sections) == len(cat_user.args[0]):  # type: ignore[arg-type]
1537                # replace the users of the cat node to be the input of the split node
1538                cat_user.replace_all_uses_with(split_node.args[0])  # type: ignore[arg-type]
1539                # remove the cat node
1540                graph.erase_node(cat_user)
1541                counters["inductor"]["mutate_cat_pass"] += 1
1542            # case 2: the cat uses some getitems from the split
1543            elif is_node_meta_valid(split_node.args[0]):  # type: ignore[arg-type]
1544                # check the split dim, and construct the slice tuple
1545                start_fused_size = calculate_fused_tensor_size(
1546                    split_node, list(range(indices[0]))
1547                )
1548                end_fused_size = start_fused_size + calculate_fused_tensor_size(
1549                    split_node, indices
1550                )
1551                slice_list = []
1552                for i in range(len(split_node.args[0].meta["example_value"].shape)):  # type: ignore[union-attr]
1553                    if i != split_dim:
1554                        slice_list.append(slice(None, None, None))
1555                    else:
1556                        slice_list.append(slice(start_fused_size, end_fused_size, None))
1557                with graph.inserting_after(split_node):
1558                    slice_node = graph.call_function(
1559                        operator.getitem,
1560                        args=(split_node.args[0], tuple(slice_list)),
1561                    )
1562                    cat_user.replace_all_uses_with(slice_node)
1563                    slice_node.meta.update(cat_user.meta)
1564
1565                # remove the cat node
1566                graph.erase_node(cat_user)
1567                counters["inductor"]["mutate_cat_pass"] += 1
1568
1569
1570@register_graph_pattern(
1571    CallFunctionVarArgs(torch.ops.aten.cat.default, users=MULTIPLE),
1572    pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"),
1573)
1574def normalize_cat_default_aten(match: Match, *args, **kwargs):
1575    cat_node = match.nodes[0]
1576    graph = match.graph
1577    tensors = get_arg_value(cat_node, 0, "tensors")
1578    cat_dim = get_arg_value(cat_node, 1, "dim")
1579    if cat_dim is None:
1580        cat_axis = cat_node.kwargs.get("axis")
1581        if cat_axis is not None:
1582            cat_dim = cat_axis
1583        else:
1584            cat_dim = 0
1585    if tensors is None or cat_dim is None:
1586        log.debug("couldn't find cat args")
1587        return
1588    assert isinstance(tensors, (list, tuple))
1589    for tensor in itertools.chain([cat_node], tensors):
1590        if "val" not in tensor.meta:
1591            log.debug("val absent for node: %s", tensor)
1592            return
1593
1594    ndim = cat_node.meta["val"].dim()
1595
1596    def is_empty_tensor(x: torch.fx.Node) -> bool:
1597        # special case where torch.ops.aten.cat.default supports cat'ing with an empty tensor
1598        x_shape = x.meta["val"].shape
1599        return len(x_shape) == 1 and x_shape[0] == 0
1600
1601    assert all(ndim == x.meta["val"].dim() or is_empty_tensor(x) for x in tensors)
1602
1603    if cat_dim < 0:  # Normalize cat dim
1604        cat_dim += ndim
1605
1606    with graph.inserting_after(cat_node):
1607        new_cat_node = graph.call_function(
1608            torch.ops.aten.cat.default,
1609            args=(tensors,),
1610            kwargs={"dim": cat_dim},
1611        )
1612    cat_node.replace_all_uses_with(new_cat_node)
1613    new_cat_node.meta.update(cat_node.meta)
1614    graph.erase_node(cat_node)
1615    counters["inductor"]["normalization_aten_pass"] += 1
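
# Normalization example (assumed 2-D inputs): a positional, negative-dim call
#   torch.ops.aten.cat.default([t1, t2], -1)
# is rewritten into the canonical keyword form with a non-negative dim:
#   torch.ops.aten.cat.default([t1, t2], dim=1)  # since -1 + ndim == 1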
1616
1617
1618@register_graph_pattern(
1619    CallFunction(
1620        torch.ops.aten.cat,
1621        ListOf(CallFunctionVarArgs(torch.ops.aten.unsqueeze)),
1622        _users=MULTIPLE,
1623    ),
1624    pass_dict=construct_pattern_matcher_pass("unbind_stack_aten_pass"),
1625)
1626def merge_unbind_stack_aten(match: Match, *args, **kwargs):
1627    node = match.nodes[-1]
1628    graph = match.graph
1629    # pyre-fixme[6]
1630    unsqueeze_nodes = list(node.args[0])  # type: ignore[arg-type]
1631    cat_dim = get_arg_value(node, 1, "dim")
1632    # check that the unsqueeze nodes take their input from select nodes
1633    if not all(
1634        get_arg_value(unsqueeze_node, 0, "input").target == torch.ops.aten.select
1635        for unsqueeze_node in unsqueeze_nodes
1636    ):
1637        return
1638    select_nodes = [
1639        get_arg_value(unsqueeze_node, 0, "input") for unsqueeze_node in unsqueeze_nodes
1640    ]
1641    parent_of_select_node = get_arg_value(select_nodes[0], 0, "input")
1642    # check that all select nodes target torch.ops.aten.select
1643    if not all(
1644        select_node.target == torch.ops.aten.select for select_node in select_nodes
1645    ):
1646        return
1647    # check the select nodes come from the same parent node
1648    if not all(
1649        get_arg_value(select_node, 0, "input") == parent_of_select_node
1650        for select_node in select_nodes
1651    ):
1652        return
1653    if len(unsqueeze_nodes) != len(select_nodes):
1654        return
1655    # check the select nodes have the same dim
1656    if not all(
1657        get_arg_value(select_node, 1, "dim") == cat_dim for select_node in select_nodes
1658    ):
1659        return
1660    # check the select nodes have consecutive indices starting from 0
1661    if get_arg_value(select_nodes[0], 2, "index") != 0 or not is_sorted_and_consecutive(
1662        [get_arg_value(select_node, 2, "index") for select_node in select_nodes]
1663    ):
1664        return
1665    # check that the only users of the select nodes' parent are the selects
1666    # feeding this cat; we approximate this by comparing user counts
1667    if len(parent_of_select_node.users.keys()) != len(node.args[0]):  # type: ignore[arg-type]
1668        return
1669    node.replace_all_uses_with(parent_of_select_node)
1670    graph.erase_node(node)
1671    for unsqueeze_node in unsqueeze_nodes:
1672        graph.erase_node(unsqueeze_node)
1673    for select_node in select_nodes:
1674        if len(select_node.users) == 0:
1675            graph.erase_node(select_node)
1676    counters["inductor"]["unbind_stack_aten_pass"] += 1
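
# The matched pattern is an identity: for x with x.size(cat_dim) == n,
#   torch.cat([x.select(cat_dim, i).unsqueeze(cat_dim) for i in range(n)],
#             dim=cat_dim)
# reproduces x, so the cat is replaced by the select nodes' common parent.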
1677
1678
1679def divide_into_consecutive_sublists(indices: List[int]) -> List[List[int]]:
1680    n = len(indices)
1681    if n <= 1:
1682        return [indices]
1683
1684    # Initialize the list of sublists
1685    sublists = []
1686
1687    # Iterate over the indices
1688    i = 0
1689    while i < n:
1690        # Initialize the current sublist
1691        sublist = [indices[i]]
1692
1693        # Iterate over the remaining indices
1694        j = i + 1
1695        while j < n and indices[j] == indices[j - 1] + 1:
1696            # Add the next index to the current sublist
1697            sublist.append(indices[j])
1698            j += 1
1699
1700        # Add the current sublist to the list of sublists
1701        sublists.append(sublist)
1702        # Move to the next index
1703        i = j
1704
1705    return sublists
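
# e.g. divide_into_consecutive_sublists([1, 2, 3, 5, 6, 9]) returns
# [[1, 2, 3], [5, 6], [9]]: each sublist is a maximal consecutive run.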
1706
1707
1708def update_args_from_split_getitem(
1709    graph: torch.fx.Graph,
1710    node: torch.fx.Node,
1711    getitem_indices: List[int],
1712    parents_seen: List[torch.fx.Node],
1713    new_cat_args: List[torch.fx.Node],
1714    new_cat_args_meta: List[torch.fx.Node],
1715    idx_to_getitems: Dict[int, torch.fx.Node],
1716    threshold_to_cat: int = 2,
1717):
1718    split_input, split_size, split_dim = _get_split_args_default(parents_seen[-1])
1719    # case 1: the number of getitems equals the split size, so eliminate the split
1720    if len(split_size) == len(getitem_indices) and is_sorted_and_consecutive(
1721        getitem_indices
1722    ):
1723        # we can merge the getitems from the previous parent
1724        new_cat_args.append(split_input)
1725        new_cat_args_meta.append(split_input.meta["example_value"])
1726    else:
1727        if len(getitem_indices) > 0:
1728            # the getitem indices may not all be consecutive, so divide them
1729            # into maximal consecutive groups and handle each group separately
1730            getitem_indices_sublists = divide_into_consecutive_sublists(getitem_indices)
1731            for sublist in getitem_indices_sublists:
1732                if len(sublist) >= threshold_to_cat:
1733                    # case 2: the group is smaller than the split size but at
1734                    # least as large as the threshold; slice the parent's input
1735                    start_fused_size = sum(split_size[: sublist[0]])
1736                    end_fused_size = sum(split_size[: sublist[-1] + 1])
1737                    slice_list = []
1738                    for i in range(len(split_input.meta["example_value"].shape)):  # type: ignore[union-attr]
1739                        if i != split_dim:
1740                            slice_list.append(slice(None, None, None))
1741                        else:
1742                            slice_list.append(
1743                                slice(start_fused_size, end_fused_size, None)
1744                            )
1745                    with graph.inserting_after(node):
1746                        slice_node = graph.call_function(
1747                            operator.getitem,
1748                            args=(split_input, tuple(slice_list)),
1749                        )
1750                        slice_node.meta["example_value"] = split_input.meta[
1751                            "example_value"
1752                        ][tuple(slice_list)]
1753                        new_cat_args.append(slice_node)
1754                        new_cat_args_meta.append(slice_node.meta["example_value"])
1755                else:
1756                    # case 3: the number of getitems is smaller than the threshold, no merge is done
1757                    # get the getitems based on the indexes
1758                    for i in sublist:
1759                        new_cat_args.append(idx_to_getitems[i])
1760                        new_cat_args_meta.append(
1761                            idx_to_getitems[i].meta["example_value"]
1762                        )
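
# The slicing in case 2 relies on this identity (assumed consecutive run i..j):
#   parts = torch.split(x, sizes, dim=d)
#   torch.cat(parts[i:j + 1], dim=d)
# equals slicing x along dim d from sum(sizes[:i]) to sum(sizes[:j + 1]), so a
# run of getitems can be replaced by a single getitem carrying a slice tuple.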
1763
1764
1765def reshape_cat_node(
1766    graph: torch.fx.Graph,
1767    cat_node: torch.fx.Node,
1768    unbind_input: torch.fx.Node,
1769    cat_dim: int,
1770    unbind_dim: int,
1771    cat_shape: torch.Size,
1772) -> torch.fx.Node:
1773    if cat_dim != unbind_dim:
1774        # construct the permute node args; the permuted tensor has the same
1775        # number of dims as unbind_input, i.e. one more than the cat output
1776        with graph.inserting_after(cat_node):
1777            permute_list = list(range(len(cat_shape) + 1))
1778            permute_list[unbind_dim], permute_list[cat_dim] = (
1779                permute_list[cat_dim],
1780                permute_list[unbind_dim],
1781            )
1782            permute_node = graph.call_function(
1783                torch.permute,
1784                args=(unbind_input, permute_list),
1785            )
1786            permute_node.meta["example_value"] = torch.permute(
1787                unbind_input.meta["example_value"], permute_list
1788            )  # type: ignore[arg-type]
1789    else:
1790        permute_node = unbind_input
1791    with graph.inserting_after(permute_node):
1792        reshape_node = graph.call_function(
1793            torch.reshape, args=(permute_node, tuple(cat_shape))
1794        )
1795        reshape_node.meta["example_value"] = torch.reshape(
1796            permute_node.meta["example_value"], tuple(cat_shape)
1797        )  # type: ignore[arg-type]
1798    return reshape_node
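
# Shape sketch (assumed sizes): for unbind_input of shape (4, 2, 5) with
# unbind_dim=0 and cat_dim=1, catting all four getitems along dim 1 gives
# shape (2, 20); reshape_cat_node emits permute(input, [1, 0, 2]) -> (2, 4, 5)
# followed by reshape to (2, 20), which matches the cat output elementwise.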
1799
1800
1801def update_args_from_unbind_getitem(
1802    graph: torch.fx.Graph,
1803    node: torch.fx.Node,  # cat or stack node
1804    getitem_indices: List[int],
1805    parents_seen: List[torch.fx.Node],
1806    new_cat_args: List[torch.fx.Node],
1807    new_cat_args_meta: List[torch.fx.Node],
1808    idx_to_getitems: Dict[int, torch.fx.Node],
1809    threshold_to_cat: int = 2,
1810):
1811    unbind_input = get_arg_value(parents_seen[-1], 0, "input")  # split or unbind input
1812    unbind_dim = get_arg_value(parents_seen[-1], 1, "dim")  # split or unbind dim
1813    cat_dim = get_arg_value(node, 1, "dim")  # cat or stack dim
1814    # case 1: the number of getitems equals the unbind output count, so eliminate the unbind
1815    size = list(unbind_input.meta["example_value"].shape)[unbind_dim]
1816    if size == len(getitem_indices):
1817        cat_shape = torch.cat(
1818            [idx_to_getitems[i].meta["example_value"] for i in getitem_indices],
1819            dim=cat_dim,
1820        ).shape
1821        # we can merge the getitems from the previous parent
1822        reshape_node = reshape_cat_node(
1823            graph, node, unbind_input, cat_dim, unbind_dim, cat_shape
1824        )
1825        new_cat_args.append(reshape_node)
1826        new_cat_args_meta.append(reshape_node.meta["example_value"])
1827    elif len(getitem_indices) >= threshold_to_cat and is_sorted_and_consecutive(
1828        getitem_indices
1829    ):
1830        # case 2: the number of getitems is smaller than the split size but larger than the threshold
1831        # we need to slice the input of parent
1832        cat_shape = torch.cat(
1833            [idx_to_getitems[i].meta["example_value"] for i in getitem_indices],
1834            dim=cat_dim,
1835        ).shape
1836        slice_list = []
1837        for i in range(len(cat_shape) + 1):
1838            if i != unbind_dim:
1839                slice_list.append(slice(None, None, None))  # start, end, step
1840            else:
1841                slice_list.append(
1842                    slice(getitem_indices[0], getitem_indices[-1] + 1, None)
1843                )
1844        with graph.inserting_after(node):
1845            slice_node = graph.call_function(
1846                operator.getitem,
1847                args=(unbind_input, tuple(slice_list)),
1848            )
1849            slice_node.meta["example_value"] = torch.narrow(
1850                unbind_input.meta["example_value"],
1851                unbind_dim,
1852                getitem_indices[0],
1853                getitem_indices[-1] - getitem_indices[0] + 1,
1854            )
1855            reshape_node = reshape_cat_node(
1856                graph, node, slice_node, cat_dim, unbind_dim, cat_shape
1857            )
1858            new_cat_args.append(reshape_node)
1859            new_cat_args_meta.append(reshape_node.meta["example_value"])
1860    else:
1861        # case 3: the number of getitems is smaller than the threshold, so no
1862        # merge is done; keep the original getitems by index
1863        for i in getitem_indices:
1864            new_cat_args.append(idx_to_getitems[i])
1865            new_cat_args_meta.append(idx_to_getitems[i].meta["example_value"])
1866
1867
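# construct_cat_args walks the cat/stack inputs left to right, buffering runs
# of getitems that share a parent split/unbind; whenever the parent changes or
# the input list ends, the buffered run is flushed through run_update_func,
# which decides between reusing the parent input, slicing it, or keeping the
# original getitems.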
1868def construct_cat_args(
1869    graph: torch.fx.Graph,
1870    cat_or_stack_node: torch.fx.Node,
1871    inputs: List[torch.fx.Node],
1872    split_or_unbind_node: torch.fx.Node,
1873    threshold_to_cat: int = 2,
1874    run_update_func: Callable = update_args_from_split_getitem,  # type: ignore[type-arg]
1875) -> Tuple[List[torch.fx.Node], List[torch.Tensor]]:
1876    new_cat_args, parents_seen, getitem_indices, idx_to_getitems = [], [], [], {}  # type: ignore[var-annotated]
1877    new_cat_args_meta = []  # type: ignore[var-annotated]
1878    for input in inputs:
1879        if input.target != operator.getitem:
1880            # flush the buffered getitems based on getitem_indices and parents_seen
1881            if len(parents_seen) > 0:
1882                run_update_func(  # type: ignore[arg-type, union-attr]
1883                    graph,
1884                    cat_or_stack_node,
1885                    getitem_indices,
1886                    parents_seen,
1887                    new_cat_args,
1888                    new_cat_args_meta,
1889                    idx_to_getitems,  # type: ignore[arg-type, union-attr]
1890                    threshold_to_cat,
1891                )
1892            new_cat_args.append(input)
1893            new_cat_args_meta.append(input.meta["example_value"])
1894            # reset the indices array
1895            getitem_indices, idx_to_getitems = [], {}
1896        else:
1897            # get the parent node of the getitem input
1898            parent, idx = input.args[0], input.args[1]  # type: ignore[union-attr]
1899            if parent.target != split_or_unbind_node.target:  # type: ignore[union-attr]
1900                new_cat_args.append(input)
1901                new_cat_args_meta.append(input.meta["example_value"])
1902                continue
1903            # cannot rely on parents_seen here since the first input could be a non-getitem node
1904            if len(parents_seen) == 0:
1905                parents_seen.append(parent)
1906                idx_to_getitems[idx] = input
1907                getitem_indices.append(idx)
1908                # case: we only have one getitem input, and it is in the last position
1909                if input == inputs[-1]:
1910                    new_cat_args.append(input)
1911                    new_cat_args_meta.append(input.meta["example_value"])
1912                continue
1913            # if it is the last input in the tensors, we also check if it can be optimized
1914            if parent != parents_seen[-1] or input == inputs[-1]:
1915                if input == inputs[-1]:
1916                    getitem_indices.append(idx)
1917                    idx_to_getitems[idx] = input
1918                run_update_func(  # type: ignore[arg-type, union-attr]
1919                    graph,
1920                    cat_or_stack_node,
1921                    getitem_indices,
1922                    parents_seen,
1923                    new_cat_args,
1924                    new_cat_args_meta,
1925                    idx_to_getitems,  # type: ignore[arg-type, union-attr]
1926                    threshold_to_cat,
1927                )
1928                # reset the indices array for the next parent
1929                # remember to add the last element since it is the first
1930                # item in this round of parent
1931                # add the parent to the list of seen parents
1932                parents_seen.append(parent)
1933                getitem_indices, idx_to_getitems = [idx], {idx: input}
1934            else:
1935                getitem_indices.append(idx)
1936                idx_to_getitems[idx] = input
1937    return new_cat_args, new_cat_args_meta
1938
1939
1940def remove_split_unbind_children(graph: torch.fx.Graph, inputs: List[torch.fx.Node]):
1941    nodes = set()
1942    for input in inputs:
1943        if input.target == operator.getitem:
1944            nodes.add(input.args[0])  # type: ignore[union-attr]
1945        if len(input.users.keys()) == 0:
1946            graph.erase_node(input)
1947    # remove the split/unbind node itself if it no longer has users
1948    for node in nodes:
1949        if len(node.users.keys()) == 0:  # type: ignore[union-attr]
1950            graph.erase_node(node)  # type: ignore[arg-type]
1951
1952
1953# ############pattern to be optimized is#########
1954
1955#               split_node(dim=1)  -> user=multiple
1956#       /           \         ...       /         \
1957# other inputs    getitem        getitem     getitem   -> user=multiple
1958#            \                    /            \
1959#                cat(user=mul, dim=1)             other_op
1960#                      |
1961
1962# ################after transformation#############
1963
1964#                 split_node(dim=1)     other inputs    -> user=multiple
1965#                           /           \
1966#                         cat (user=mul, dim=1, split_node)
1967
1968
1969@register_graph_pattern(
1970    CallFunctionVarArgs(torch.cat, users=MULTIPLE),
1971    pass_dict=construct_pattern_matcher_pass("split_cat_to_slices_pass"),
1972)
1973@register_graph_pattern(
1974    CallFunction(
1975        torch.cat,
1976        getitem_split,
1977        dim=Ignored(),
1978        _users=MULTIPLE,
1979    ),
1980    pass_dict=construct_pattern_matcher_pass("split_cat_to_slices_pass"),
1981)
1982def split_cat_to_slices(match: Match, split_sections: List[int], dim: int):
1983    if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
1984        return
1985    split_nodes = [node for node in match.nodes if node.target == torch.split]
1986    if not split_nodes:
1987        # no node with a target of torch.split; nothing to do
1988        return
1989    split_node = split_nodes[0]
1991    split_dim = get_arg_value(split_node, 2, "dim") or 0
1992    graph = match.graph
1993    threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[
1994        "split_cat_to_slices_pass"
1995    ].get("threshold_to_cat", 10)
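    # threshold_to_cat is read from the pass's options; a hypothetical config
    # enabling this pass with the default threshold could look like:
    #   torch._inductor.config.pre_grad_fusion_options = {
    #       "split_cat_to_slices_pass": {"threshold_to_cat": 10},
    #   }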
1996    # get the cat_node and check its inputs and meta data
1997    next_users = find_next_users(split_node)
1998    for cat_node in next_users:
1999        if cat_node.target != torch.cat or not is_node_meta_valid(cat_node):
2000            continue
2001        cat_inputs = get_arg_value(cat_node, 0, "tensors")  # type: ignore[union-attr]
2002        new_cat_args, _ = construct_cat_args(
2003            graph,
2004            cat_node,
2005            cat_inputs,
2006            split_node,
2007            threshold_to_cat,
2008            update_args_from_split_getitem,
2009        )
2010        # At least one node would be in the returned new_cat_args
2011        # case 1: if new cat args has length 1, we can remove the cat node
2012        if len(new_cat_args) == 1:
2013            cat_node.replace_all_uses_with(new_cat_args[0])
2014            # remove inputs of cat_node if they have no users
2015            cat_inputs = cat_node.args[0]  # type: ignore[union-attr]
2016            graph.erase_node(cat_node)
2017            remove_split_unbind_children(graph, cat_inputs)  # type: ignore[arg-type]
2018            counters["inductor"]["split_cat_to_slices_pass"] += 1
2019            continue
2020        if len(new_cat_args) > 1 and len(new_cat_args) < len(cat_inputs):
2021            new_args = (new_cat_args,)
2022            with graph.inserting_after(cat_node):
2023                new_cat_node = graph.call_function(
2024                    torch.cat,
2025                    args=new_args,
2026                    # split and cat have the same dim
2027                    kwargs={"dim": split_dim},
2028                )
2029                cat_node.replace_all_uses_with(new_cat_node)
2030                new_cat_node.meta.update(cat_node.meta)
2031                # remove the cat node
2032                graph.erase_node(cat_node)
2033                remove_split_unbind_children(graph, cat_inputs)
2034                counters["inductor"]["split_cat_to_slices_pass"] += 1
2035
2036
2037# ############pattern to be optimized is#########
2038
2039#               unbind(dim=0)  -> user=multiple
2040#       /           \         ...       /         \
2041# getitem    getitem        getitem     getitem   -> user=multiple
2042#            \                    /            \
2043#                cat(user=mul, dim=1)             other_op
2044#                      |
2045
2046# ################after transformation#############
2047
2048#                 input_of_unbind
2049#                           |    \
2050#                         slice
2051#                           |
2052#                          view
2053#                           |
2054
2055
2056@register_graph_pattern(
2057    CallFunction(
2058        torch.cat,
2059        getitem_unbind,
2060        dim=Ignored(),
2061        _users=MULTIPLE,
2062    ),
2063    pass_dict=construct_pattern_matcher_pass("unbind_cat_to_view_pass"),
2064)
2065def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int):
2066    unbind_node = next(node for node in match.nodes if node.target == torch.unbind)
2067    graph = match.graph
2068    # get the cat_node and check its inputs and meta data
2069    next_users = find_next_users(unbind_node)
2070    threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[
2071        "unbind_cat_to_view_pass"
2072    ].get("threshold_to_cat", 10)
2074    for cat_node in next_users:
2075        if cat_node.target != torch.cat or not is_node_meta_valid(cat_node):
2076            continue
2077        inputs = get_arg_value(cat_node, 0, "tensors")  # type: ignore[union-attr]
2078        new_cat_args, new_cat_args_meta = construct_cat_args(
2079            graph,
2080            cat_node,
2081            inputs,
2082            unbind_node,
2083            threshold_to_cat,
2084            update_args_from_unbind_getitem,
2085        )
2087        # At least one node would be in the returned new_cat_args
2088        # case 1: only one node in the new cat args, don't need to cat
2089        if len(new_cat_args) == 1:
2090            cat_node.replace_all_uses_with(new_cat_args[0])
2091            # remove inputs of cat_node if they have no users
2092            cat_inputs = cat_node.args[0]  # type: ignore[union-attr]
2093            graph.erase_node(cat_node)
2094            remove_split_unbind_children(graph, cat_inputs)  # type: ignore[arg-type]
2095            counters["inductor"]["unbind_cat_to_view_pass"] += 1
2096            continue
2097        if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs):
2098            # get the view shape
2099            cat_dim = get_arg_value(cat_node, 1, "dim")
2100            with graph.inserting_after(cat_node):
2101                new_cat_node = graph.call_function(
2102                    torch.cat,
2103                    args=(new_cat_args,),
2104                    kwargs={"dim": cat_dim},
2105                )
2106                new_cat_node.meta["example_value"] = torch.cat(new_cat_args_meta, dim=cat_dim)  # type: ignore[arg-type]
2107                cat_node.replace_all_uses_with(new_cat_node)
2108                new_cat_node.meta.update(cat_node.meta)
2109            # remove inputs of cat_node if they have no users
2110            cat_inputs = cat_node.args[0]  # type: ignore[union-attr]
2111            graph.erase_node(cat_node)
2112            remove_split_unbind_children(graph, cat_inputs)  # type: ignore[arg-type]
2113            counters["inductor"]["unbind_cat_to_view_pass"] += 1
2114
2115
2116def reshape_cat_node_to_stack(
2117    graph: torch.fx.Graph,
2118    cat_node: torch.fx.Node,
2119    stack_node: torch.fx.Node,
2120    split_or_unbind_dim: int,
2121) -> None:
2122    # reshape the cat node to the stack node shape
2123    stack_shape = stack_node.meta["example_value"].shape
2124    stack_dim = _get_dim(stack_node)
2125    if stack_dim != split_or_unbind_dim:
2126        # case 1: the stack dim is not the same as the split dim
2127        # first reshape the cat output with the two dims swapped, then permute it
2128        reshape_list = list(stack_shape)
2129        reshape_list[stack_dim], reshape_list[split_or_unbind_dim] = (
2130            reshape_list[split_or_unbind_dim],
2131            reshape_list[stack_dim],
2132        )
2133        reshape_node = graph.call_function(
2134            torch.reshape,
2135            args=(cat_node, tuple(reshape_list)),
2136        )
2137        reshape_node.meta["example_value"] = torch.reshape(
2138            cat_node.meta["example_value"], tuple(reshape_list)
2139        )
2140        permute_list = list(range(len(stack_shape)))
2141        permute_list[stack_dim], permute_list[split_or_unbind_dim] = (
2142            permute_list[split_or_unbind_dim],
2143            permute_list[stack_dim],
2144        )
2145        permute_node = graph.call_function(
2146            torch.permute,
2147            args=(reshape_node, permute_list),
2148        )
2149        permute_node.meta["example_value"] = torch.permute(
2150            reshape_node.meta["example_value"], permute_list
2151        )
2152    else:
2153        # case 2: the stack dim is the same as the split dim
2154        # we can directly view the cat output into the stack shape
2155        permute_node = cat_node
2156    reshape_node = graph.call_function(
2157        torch.Tensor.view,
2158        args=(permute_node, *stack_shape),  # type: ignore[arg-type]
2159    )
2160    stack_node.replace_all_uses_with(reshape_node)
2161    reshape_node.meta.update(stack_node.meta)
2162    stack_inputs = stack_node.args[0]  # type: ignore[union-attr]
2163    # remove stack node
2164    graph.erase_node(stack_node)
2165    # check the input of stack node, and remove nodes that have no users
2166    remove_split_unbind_children(graph, stack_inputs)  # type: ignore[arg-type]
2167
2168
2169def convert_reshape_cat_arg_to_stack(
2170    graph: torch.fx.Graph,
2171    cat_node: torch.fx.Node,
2172    stack_node: torch.fx.Node,
2173    stack_node_shape: torch.Size,
2174    stack_dim: int,
2175    split_dim: int,
2176) -> torch.fx.Node:
2177    # reshape the cat node to the stack node shape
2178    cat_shape = cat_node.meta["example_value"].shape
2179    if stack_dim != split_dim:
2180        permute_list = list(range(len(cat_shape)))
2181        permute_list[stack_dim], permute_list[split_dim] = (
2182            permute_list[split_dim],
2183            permute_list[stack_dim],
2184        )
2185        permute_node = graph.call_function(
2186            torch.permute,
2187            args=(cat_node, permute_list),
2188        )
2189        permute_node.meta["example_value"] = torch.permute(
2190            cat_node.meta["example_value"], permute_list
2191        )
2192    else:
2193        permute_node = cat_node
2194    reshape_node = graph.call_function(
2195        torch.Tensor.view,
2196        args=(permute_node, tuple(stack_node_shape)),  # type: ignore[arg-type]
2197    )
2198    reshape_node.meta["example_value"] = torch.Tensor.view(
2199        permute_node.meta["example_value"], tuple(stack_node_shape)  # type: ignore[arg-type]
2200    )
2201    return reshape_node
2202
2203
2204# ############pattern to be optimized is#########
2205#    |           |
2206#   split       split   (dim=1)
2207#   /     \      /   \
2208# getitem  ...        getitem      other ops
2209#        \      |       /            /
2210#       stack(user=mul, dim=1 or 2) -> can be different dim
2211#          |
2212
2213# ################after transformation#############
2214
2215#       /           \         ...       /         \
2216# getitem    getitem        getitem     getitem   -> user=multiple
2217#       \      /
2218#       cat(user=mul, dim=1) cat_other_opts
2219#          \                  /
2220#                  cat
2221#                   |
2222#                  view
2223#                   |
2224
2225
2226@register_graph_pattern(
2227    CallFunction(
2228        torch.stack,
2229        getitem_split,
2230        dim=Ignored(),
2231        _users=MULTIPLE,
2232    ),
2233    pass_dict=construct_pattern_matcher_pass("split_stack_to_cats_pass"),
2234)
2235def split_stack_to_cats(match: Match, split_sections: List[int], dim: int):
2236    if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
2237        return
2238    split_node = next(node for node in match.nodes if node.target == torch.split)
2239    split_dim = get_arg_value(split_node, 2, "dim") or 0
2240    graph = match.graph
2241    threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[
2242        "split_stack_to_cats_pass"
2243    ].get("threshold_to_cat", 10)
2244    # get the stack_node and check its inputs and meta data
2245    next_users = find_next_users(split_node)
2246    for stack_node in next_users:
2247        if stack_node.target != torch.stack or not is_node_meta_valid(stack_node):
2248            continue
2249        inputs = get_arg_value(stack_node, 0, "tensors")  # type: ignore[union-attr]
2250        new_cat_args, new_cat_args_meta = construct_cat_args(
2251            graph,
2252            stack_node,
2253            inputs,
2254            split_node,
2255            threshold_to_cat,
2256            update_args_from_split_getitem,
2257        )
2258        # At least one node would be in the returned new_cat_args
2259        # case 1: only one node in the new cat args, don't need to cat
2260        if len(new_cat_args) == 1:
2261            reshape_cat_node_to_stack(graph, new_cat_args[0], stack_node, split_dim)
2262            counters["inductor"]["split_stack_to_cats_pass"] += 1
2263            continue
2264        if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs):
2265            with graph.inserting_after(stack_node):
2266                cat_node = graph.call_function(
2267                    torch.cat,
2268                    args=(new_cat_args,),
2269                    kwargs={"dim": split_dim},
2270                )
2271                cat_node.meta["example_value"] = torch.cat(  # type: ignore[arg-type]
2272                    new_cat_args_meta, dim=split_dim
2273                )
2274                reshape_cat_node_to_stack(graph, cat_node, stack_node, split_dim)
2275                counters["inductor"]["split_stack_to_cats_pass"] += 1
2276
2277
2278# ############pattern to be optimized is#########
2279
2280#               unbind(dim=1)  -> user=multiple
2281#                  \         ...       /         \
2282# others    getitem        getitem     getitem   -> user=multiple
2283#  \          \                    /            \
2284#                stack(user=mul, dim=1)             other_op
2285#                      |
2286
2287# ################after transformation#############
2288
2289#                 input_of_unbind
2290#                           |    \
2291#                         slice
2292#                           |
2293#                          view   others
2294#                           |    /
2295#                          stack
2296#                           |
2297
2298
2299@register_graph_pattern(
2300    CallFunction(
2301        torch.stack,
2302        getitem_unbind,
2303        dim=Ignored(),
2304        _users=MULTIPLE,
2305    ),
2306    pass_dict=construct_pattern_matcher_pass("unbind_stack_to_slices_pass"),
2307)
2308def unbind_stack_to_slices(match: Match, unbind_input: torch.fx.Node, dim: int):
2309    unbind_node = next(node for node in match.nodes if node.target == torch.unbind)
2310    graph = match.graph
2311    # get the stack_node and check its inputs and meta data
2312    next_users = find_next_users(unbind_node)
2313    threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[
2314        "unbind_stack_to_slices_pass"
2315    ].get("threshold_to_cat", 10)
2317    for stack_node in next_users:
2318        if stack_node.target != torch.stack or not is_node_meta_valid(stack_node):
2319            continue
2320        inputs = get_arg_value(stack_node, 0, "tensors")  # type: ignore[union-attr]
2321        new_cat_args, new_cat_args_meta = construct_cat_args(
2322            graph,
2323            stack_node,
2324            inputs,
2325            unbind_node,
2326            threshold_to_cat,
2327            update_args_from_unbind_getitem,
2328        )
2329        unbind_dim = get_arg_value(unbind_node, 1, "dim") or 0
2330        # At least one node would be in the returned new_cat_args
2331        # case 1: only one node in the new cat args, don't need to cat
2332        if len(new_cat_args) == 1:
2333            reshape_cat_node_to_stack(graph, new_cat_args[0], stack_node, unbind_dim)
2334            counters["inductor"]["unbind_stack_to_slices_pass"] += 1
2335            continue
2336        if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs):
2337            # get the view shape
2338            cat_dim = get_arg_value(stack_node, 1, "dim")
2339            with graph.inserting_after(stack_node):
2340                new_cat_node = graph.call_function(
2341                    torch.cat,
2342                    args=(new_cat_args,),
2343                    kwargs={"dim": cat_dim},
2344                )
2345                new_cat_node.meta["example_value"] = torch.cat(
2346                    new_cat_args_meta, dim=cat_dim
2347                )
2348                reshape_cat_node_to_stack(graph, new_cat_node, stack_node, unbind_dim)
2349            counters["inductor"]["unbind_stack_to_slices_pass"] += 1
2350
2351
2352# ############pattern to be optimized is#########
2353#                   input
2354#                     |
2355#               split(dim=1)  -> user=multiple
2356#                  \         \
2357# others    getitem        getitem
2358#  \          \               /
2359#  reshape     reshape      reshape     other_op
2360#  \          \             /         /
2361#                stack(user=mul, dim=0)
2362#                      |
2363
2364# ################after transformation#############
2365#                          input
2366#                           |
2367#                         permute
2368#                           |
2369#                         reshape   others
2370#                           |    /
2371#                          cat (dim=0)
2372#                           |
2373
2374
2375def get_view_shape_list(cat_arg: torch.fx.Node, stack_dim: int) -> List[int]:
2376    # cat_arg must be the split input
2377    view_shape_list = []
2378    for user in cat_arg.users.keys():
2379        if user.target == torch.split:
2380            for getitem in user.users.keys():
2381                if getitem.target == operator.getitem:
2382                    reshape_user = [
2383                        user
2384                        for user in getitem.users.keys()
2385                        if user.target == torch.reshape
2386                    ]
2387                    if len(reshape_user) > 0:
2388                        view_shape_list = list(
2389                            reshape_user[0]
2390                            .meta["example_value"]
2391                            .unsqueeze(stack_dim)
2392                            .shape
2393                        )
2394                        view_shape_list[stack_dim] = -1
2395                        return view_shape_list
2396    return view_shape_list
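
# Sketch (assumed shapes): if a reshape user of some getitem produces a tensor
# of shape (6, 4) and stack_dim == 0, unsqueeze(0) gives (1, 6, 4), so the
# returned list is [-1, 6, 4]; the -1 lets the later reshape absorb however
# many stacked slices the split input contributes.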
2397
2398
2399@register_graph_pattern(
2400    CallFunction(
2401        torch.stack,
2402        reshape_getitem_split,
2403        dim=Ignored(),
2404        _users=MULTIPLE,
2405    ),
2406    pass_dict=construct_pattern_matcher_pass("move_reshape_out_of_split_stack_pass"),
2407)
2408def move_reshape_out_of_split_stack(match: Match, *args, **kwargs):
2409    split_node = next(node for node in match.nodes if node.target == torch.split)
2410    split_dim = _get_dim(split_node)
2411    split_users = list(split_node.users.keys())
2412    stack_nodes = [node for node in match.nodes if node.target == torch.stack]
2413    graph = match.graph
2414    threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[
2415        "move_reshape_out_of_split_stack_pass"
2416    ].get("threshold_to_cat", 10)
2417    for stack_node in stack_nodes:
2418        if not is_node_meta_valid(stack_node):
2419            log.debug("example value absent for node: %s", stack_node)
2420            continue
2421        stack_dim = _get_dim(stack_node)
2422        stack_inputs = get_arg_value(stack_node, 0, "tensors")  # type: ignore[union-attr]
2423        inputs = []
2424        for stack_input in stack_inputs:
2425            if stack_input.target != torch.reshape:
2426                inputs.append(stack_input)
2427            else:
2428                inputs.append(stack_input.args[0])  # type: ignore[union-attr]
2429        new_cat_args, new_cat_args_meta = construct_cat_args(
2430            graph,
2431            stack_node,
2432            inputs,
2433            split_node,
2434            threshold_to_cat,
2435            update_args_from_split_getitem,
2436        )
2437        # At least one node would be in the returned new_cat_args
2438        # case 1: only one node in the new cat args, don't need to cat
2439        if len(new_cat_args) == 1:
2440            reshape_node = convert_reshape_cat_arg_to_stack(
2441                graph,
2442                new_cat_args[0],
2443                stack_node,
2444                stack_node.meta["example_value"].shape,
2445                stack_dim,
2446                split_dim,
2447            )
2448            stack_node.replace_all_uses_with(reshape_node)
2449            # remove stack node
2450            graph.erase_node(stack_node)
2451            # check the input of stack node, and remove nodes that have no users
2452            remove_split_unbind_children(graph, stack_inputs)  # type: ignore[arg-type]
2453            remove_split_unbind_children(graph, split_users)  # type: ignore[arg-type]
2454            counters["inductor"]["move_reshape_out_of_split_stack_pass"] += 1
2455            continue
2456        if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs):
2457            # decompose the cat args into multiple stack nodes: stack the nodes
2458            # that appear in the original stack inputs, reshape the rest, then cat everything
2459            stack_node_input, stack_node_input_meta, cat_inputs = [], [], []  # type: ignore[var-annotated]
2460            for cat_arg in new_cat_args:
2461                if cat_arg not in stack_inputs:
2462                    if len(stack_node_input) > 0:
2463                        with graph.inserting_after(stack_node):
2464                            decomposed_stack_node = graph.call_function(
2465                                torch.stack,
2466                                args=(stack_node_input,),
2467                                kwargs={"dim": stack_dim},
2468                            )
2469                            decomposed_stack_node.meta["example_value"] = torch.stack(
2470                                stack_node_input_meta, dim=stack_dim
2471                            )
2472                            cat_inputs.append(decomposed_stack_node)
2473                    # cat_arg must be the split input
2474                    view_shape_list = get_view_shape_list(cat_arg, stack_dim)
2475                    stack_node_shape = torch.reshape(cat_arg.meta["example_value"], tuple(view_shape_list)).shape  # type: ignore[union-attr]
2476                    cat_inputs.append(
2477                        convert_reshape_cat_arg_to_stack(
2478                            graph,
2479                            cat_arg,
2480                            stack_node,
2481                            stack_node_shape,
2482                            stack_dim,
2483                            split_dim,
2484                        )
2485                    )
2486                    stack_node_input, stack_node_input_meta = [], []
2487                else:
2488                    stack_node_input.append(cat_arg)
2489                    stack_node_input_meta.append(cat_arg.meta["example_value"])
2490
2491            if len(stack_node_input) > 0:
2492                with graph.inserting_after(stack_node):
2493                    decomposed_stack_node = graph.call_function(
2494                        torch.stack,
2495                        args=(stack_node_input,),
2496                        kwargs={"dim": stack_dim},
2497                    )
2498                    decomposed_stack_node.meta["example_value"] = torch.stack(
2499                        stack_node_input_meta, dim=stack_dim
2500                    )
2501                    cat_inputs.append(decomposed_stack_node)
2502
2503            with graph.inserting_after(stack_node):
2504                cat_node = graph.call_function(
2505                    torch.cat,
2506                    args=(cat_inputs,),
2507                    kwargs={"dim": stack_dim},
2508                )
2509                stack_node.replace_all_uses_with(cat_node)
2510                cat_node.meta.update(stack_node.meta)
2511                graph.erase_node(stack_node)
2512                remove_split_unbind_children(graph, stack_inputs)  # type: ignore[arg-type]
2513                remove_split_unbind_children(graph, split_users)  # type: ignore[arg-type]
2514            counters["inductor"]["move_reshape_out_of_split_stack_pass"] += 1
2515