# xref: /aosp_15_r20/external/pytorch/torch/_inductor/fx_passes/micro_pipeline_tp.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import operator
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, cast, Dict, List, Optional, Set

import torch

from .. import config, inductor_prims
from ..pattern_matcher import (
    CallFunction,
    Ignored,
    KeywordArg,
    ListOf,
    Match,
    MULTIPLE,
    PatternExpr,
    PatternMatcherPass,
)


aten = torch.ops.aten
patterns = PatternMatcherPass()


def _is_backward(graph: torch.fx.Graph) -> bool:
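    # AOTAutograd names the placeholders of forward graphs "primals_*", while
    # backward graphs also take "tangents_*" inputs, so any placeholder with a
    # different prefix indicates a backward graph.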
    placeholders = []
    for node in graph.nodes:
        if node.op != "placeholder":
            break
        placeholders.append(node)
    return not all(node.name.startswith("primal") for node in placeholders)


def _compute_mm_arithmetic_intensity(M: int, N: int, K: int) -> float:
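    # Roughly FLOPs per element of memory traffic for an (M, K) @ (K, N)
    # matmul: M * N * K multiply-accumulates over M*K + K*N + M*N elements
    # read or written (constant factors and dtype size are ignored).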
    return M * N * K / (M * K + N * K + M * N)


def _filter_nodes_by_target(nodes: List[torch.fx.Node], target) -> List[torch.fx.Node]:
    return [x for x in nodes if x.target == target]


def _find_ancestors(node: torch.fx.Node) -> Set[torch.fx.Node]:
    ancestors = set()
    ancestors.add(node)
    cur_nodes = [node]
    while len(cur_nodes) > 0:
        new_nodes = []
        for node in cur_nodes:
            for inp in node.all_input_nodes:
                if inp not in ancestors:
                    ancestors.add(inp)
                    new_nodes.append(inp)
        cur_nodes = new_nodes
    return {node for node in ancestors if node.op != "placeholder"}


def _get_tensor(node: torch.fx.Node) -> torch.Tensor:
    val = node.meta["val"]
    assert isinstance(val, torch.Tensor)
    return val


@dataclass
class _AllGatherMatch:
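    """
    A matched all-gather: `shard_node` is the local shard fed into the
    collective, `ag_node` is the all_gather_into_tensor call, and `res_node`
    is the node whose value downstream consumers of the all-gather use.
    """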
    match: Match
    shard_node: torch.fx.Node
    ag_node: torch.fx.Node
    res_node: torch.fx.Node
    gather_dim: int
    group_name: str

    def replace_with(self, new_node: torch.fx.Node) -> None:
        self.res_node.replace_all_uses_with(new_node)

    def erase(self) -> None:
        for node in reversed(self.match.nodes):
            if len(node.users) == 0:
                node.graph.erase_node(node)


def find_all_gather_patterns(graph: torch.fx.Graph):
    c10d = torch.ops._c10d_functional

    def make_zero_dim_all_gather_pattern(shard):
        return CallFunction(
            c10d.wait_tensor.default,
            CallFunction(
                c10d.all_gather_into_tensor.default,
                shard,
                Ignored(),
                KeywordArg("group_name"),
            ),
        )

    # Matches funcol.all_gather_tensor with gather_dim == 0
    zero_dim_all_gather_pattern = make_zero_dim_all_gather_pattern(KeywordArg("shard"))

    def make_all_gather_split_pattern(shard):
        return CallFunction(
            operator.getitem,
            CallFunction(
                aten.split.Tensor,
                make_zero_dim_all_gather_pattern(shard),
                Ignored(),
                _users=MULTIPLE,
            ),
            Ignored(),
        )

    def make_cat_pattern(splits):
        return CallFunction(
            aten.cat.default,
            ListOf(splits),
            KeywordArg("gather_dim"),
        )

    # Matches funcol.all_gather_tensor with gather_dim > 0
    non_zero_dim_all_gather_pattern = make_cat_pattern(
        make_all_gather_split_pattern(KeywordArg("shard")),
    )
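    # A sketch of the decomposed IR this pattern corresponds to (variable
    # names below are illustrative only):
    #
    #   ag = all_gather_into_tensor(shard, group_size, group_name)  # gathers along dim 0
    #   w = wait_tensor(ag)
    #   chunks = aten.split.Tensor(w, split_size)
    #   out = aten.cat([chunks[0], chunks[1], ...], gather_dim)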

    # Match a zero-dim all-gather in which the data is transferred as uint8 and
    # viewed back as the original dtype.
    zero_dim_type_erased_all_gather_pattern = CallFunction(
        aten.view.dtype,
        make_zero_dim_all_gather_pattern(
            KeywordArg("shard"),
        ),
        Ignored(),
    )

    # Match a non-zero dim all-gather in which the data is transferred as uint8
    # and viewed back as the original dtype.
    non_zero_dim_type_erased_all_gather_pattern = CallFunction(
        aten.view.dtype,
        make_cat_pattern(
            CallFunction(
                aten.view.dtype,
                make_all_gather_split_pattern(
                    KeywordArg("shard"),
                ),
                Ignored(),
            ),
        ),
        Ignored(),
    )
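    # The type-erased variants typically arise when the all-gather transports
    # the data as raw bytes (e.g. for low-precision dtypes such as float8) and
    # the result is viewed back to the original dtype afterwards.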

    # If two patterns with the same res_node_target share a suffix, the longer
    # pattern should appear first in the list.
    # E.g. suppose we have (1) A -> B -> C -> D and (2) B -> C -> D, then (1)
    # should appear before (2) in the list.
    res_node_target_to_patterns = {
        aten.cat.default: [
            (non_zero_dim_all_gather_pattern, 0),
        ],
        aten.view.dtype: [
            (non_zero_dim_type_erased_all_gather_pattern, 0),
            (zero_dim_type_erased_all_gather_pattern, 0),
        ],
        c10d.wait_tensor.default: [
            (zero_dim_all_gather_pattern, 0),
        ],
    }

    # Match in reverse to ensure longer patterns are prioritized
    all_gathers = []
    visited_ag_nodes = set()
    for node in reversed(graph.nodes):
        for target, patterns in res_node_target_to_patterns.items():
            if node.target != target:
                continue
            for pattern, ag_node_idx in patterns:
                match = pattern.match(node)
                if not match:
                    continue

                assert isinstance(match, Match)
                ag_node = match.nodes[ag_node_idx]
                assert ag_node.target == c10d.all_gather_into_tensor.default

                if ag_node in visited_ag_nodes:
                    continue
                visited_ag_nodes.add(ag_node)

                ag_match = _AllGatherMatch(
                    match=match,
                    shard_node=match.kwargs["shard"],
                    ag_node=ag_node,
                    res_node=node,
                    gather_dim=match.kwargs.get("gather_dim", 0),
                    group_name=match.kwargs["group_name"],
                )
                all_gathers.append(ag_match)

    return list(reversed(all_gathers))


@dataclass
class _ReduceScatterMatch:
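    """
    A matched reduce-scatter: `input_node` is the tensor being reduced,
    `rs_node` is the reduce_scatter_tensor call, and `res_node` is the node
    whose value downstream consumers of the reduce-scatter use.
    """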
    match: Match
    input_node: torch.fx.Node
    rs_node: torch.fx.Node
    res_node: torch.fx.Node
    reduce_op: str
    scatter_dim: int
    group_name: str

    def replace_with(self, new_node: torch.fx.Node) -> None:
        self.res_node.replace_all_uses_with(new_node)

    def erase(self) -> None:
        for node in reversed(self.match.nodes):
            if len(node.users) == 0:
                node.graph.erase_node(node)


def find_reduce_scatter_patterns(graph: torch.fx.Graph):
    c10d = torch.ops._c10d_functional

    def reduce_scatter_template(inp: PatternExpr):
        return CallFunction(
            c10d.wait_tensor.default,
            CallFunction(
                c10d.reduce_scatter_tensor.default,
                inp,
                KeywordArg("reduce_op"),
                Ignored(),
                KeywordArg("group_name"),
            ),
        )

    # Matches funcol.reduce_scatter_tensor with scatter_dim == 0
    zero_dim_reduce_scatter_pattern = reduce_scatter_template(KeywordArg("input"))

    # Matches funcol.reduce_scatter_tensor with scatter_dim > 0
    non_zero_dim_reduce_scatter_pattern = reduce_scatter_template(
        CallFunction(
            aten.cat.default,
            ListOf(
                CallFunction(
                    operator.getitem,
                    CallFunction(
                        aten.split.Tensor,
                        KeywordArg("input"),
                        Ignored(),
                        KeywordArg("scatter_dim"),
                        _users=MULTIPLE,
                    ),
                    Ignored(),
                )
            ),
        ),
    )
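    # A sketch of the decomposed IR this pattern corresponds to (variable
    # names below are illustrative only):
    #
    #   chunks = aten.split.Tensor(input, split_size, scatter_dim)
    #   flat = aten.cat([chunks[0], chunks[1], ...])  # concatenates along dim 0
    #   rs = reduce_scatter_tensor(flat, reduce_op, group_size, group_name)
    #   out = wait_tensor(rs)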

    reduce_scatters = []
    for node in reversed(graph.nodes):
        if node.target == c10d.wait_tensor.default:
            if match := non_zero_dim_reduce_scatter_pattern.match(node):
                assert isinstance(match, Match)
                reduce_scatters.append(
                    _ReduceScatterMatch(
                        match=match,
                        input_node=match.kwargs["input"],
                        rs_node=match.nodes[-2],
                        res_node=node,
                        reduce_op=match.kwargs["reduce_op"],
                        scatter_dim=match.kwargs["scatter_dim"],
                        group_name=match.kwargs["group_name"],
                    )
                )
            elif match := zero_dim_reduce_scatter_pattern.match(node):
                assert isinstance(match, Match)
                reduce_scatters.append(
                    _ReduceScatterMatch(
                        match=match,
                        input_node=match.kwargs["input"],
                        rs_node=match.nodes[0],
                        res_node=node,
                        reduce_op=match.kwargs["reduce_op"],
                        scatter_dim=0,
                        group_name=match.kwargs["group_name"],
                    )
                )
    return list(reversed(reduce_scatters))


@dataclass
class _Matmul:
    nodes: List[torch.fx.Node]
    arg_ancestor_nodes: Set[torch.fx.Node] = field(init=False)
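    # Ancestors of the non-A operands (B and, for scaled matmuls, the scales).
    # Populated in __post_init__ and used both to decide whether a matmul is
    # fusible and to determine which nodes must be hoisted above the fused op.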
    A_node: torch.fx.Node
    B_node: torch.fx.Node

    def __post_init__(self):
        assert len(self.nodes) in (1, 3)
        if len(self.nodes) == 1:
            assert self.nodes[0].target in (aten.mm.default, aten._scaled_mm.default)
        else:
            assert self.nodes[0].target == aten.reshape.default
            assert self.nodes[1].target in (aten.mm.default, aten._scaled_mm.default)
            assert self.nodes[2].target == aten.reshape.default
        self.arg_ancestor_nodes = _find_ancestors(self.B_node)

    def replace_with(self, new_node: torch.fx.Node) -> None:
        """
        Replace the matmul with the new node.
        """
        graph = new_node.graph

        # For 2D-matmuls, we simply replace the mm node with `new_node`.
        if len(self.nodes) == 1:
            mm_node = self.nodes[0]
            assert mm_node.target in (aten.mm.default, aten._scaled_mm.default)
            mm_node.replace_all_uses_with(new_node)
            graph.erase_node(mm_node)
            return

        # An ND-matmul is a reshape -> mm -> reshape sequence. We first replace
        # the second reshape node with `new_node`. Then, we ensure that the
        # original mm node in the sequence ends up with zero users by replacing
        # it with a reverse reshape of `new_node`.
        graph = new_node.graph
        assert len(self.nodes) == 3
        mm_node = self.nodes[1]
        output_reshape_node = self.nodes[2]

        assert mm_node.target in (aten.mm.default, aten._scaled_mm.default)
        assert output_reshape_node.target == aten.reshape.default

        output_reshape_node.replace_all_uses_with(new_node)
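        # `output_reshape_node` is always one of mm_node's users; any
        # additional users get redirected to a reshape of `new_node` back to
        # the original mm output shape.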
        if len(mm_node.users) > 1:
            with graph.inserting_after(new_node):
                new_mm_node = graph.call_function(
                    aten.reshape.default,
                    args=(new_node, list(_get_tensor(mm_node).shape)),
                )
            mm_node.replace_all_uses_with(new_mm_node)

    def erase(self) -> None:
        for node in reversed(self.nodes):
            if len(node.users) == 0:
                node.graph.erase_node(node)

    @classmethod
    def from_match(cls, match: List[torch.fx.Node]) -> "_Matmul":
        assert len(match) in (1, 3)
        assert match[0].target in (
            aten.mm.default,
            aten.reshape.default,
        )
        mm_node = match[0] if len(match) == 1 else match[1]
        return _Matmul(
            nodes=match,
            A_node=cast(torch.fx.Node, match[0].args[0]),
            B_node=cast(torch.fx.Node, mm_node.args[1]),
        )


@dataclass
class _ScaledMatmul(_Matmul):
    A_scale_node: torch.fx.Node
    B_scale_node: torch.fx.Node
    bias_node: Optional[torch.fx.Node]
    result_scale_node: Optional[torch.fx.Node]
    out_dtype: Optional[torch.dtype]
    use_fast_accum: bool

    def __post_init__(self):
        super().__post_init__()
        self.arg_ancestor_nodes |= _find_ancestors(self.A_scale_node)
        self.arg_ancestor_nodes |= _find_ancestors(self.B_scale_node)

    @classmethod
    def from_match(cls, match: List[torch.fx.Node]) -> "_ScaledMatmul":
        assert len(match) in (1, 3)
        assert match[0].target in (
            aten._scaled_mm.default,
            aten.reshape.default,
        )
        mm_node = match[0] if len(match) == 1 else match[1]

        def get_arg(node: torch.fx.Node, idx: int, default: Any) -> Any:
            if idx >= len(node.args):
                return default
            return node.args[idx]

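        # The indices below follow aten._scaled_mm's positional argument
        # order: (self, mat2, scale_a, scale_b, bias, scale_result, out_dtype,
        # use_fast_accum).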
        return _ScaledMatmul(
            nodes=match,
            A_node=cast(torch.fx.Node, match[0].args[0]),
            B_node=cast(torch.fx.Node, mm_node.args[1]),
            A_scale_node=cast(torch.fx.Node, mm_node.args[2]),
            B_scale_node=cast(torch.fx.Node, mm_node.args[3]),
            bias_node=get_arg(mm_node, 4, None),
            result_scale_node=get_arg(mm_node, 5, None),
            out_dtype=get_arg(mm_node, 6, None),
            use_fast_accum=get_arg(mm_node, 7, False),
        )


def _find_reshape_mm_reshape(node: torch.fx.Node) -> List[_Matmul]:
    if node.target != aten.reshape.default:
        return []

    matches = []
    for mm_node in node.users:
        if mm_node.target not in (aten.mm.default, aten._scaled_mm.default):
            continue
        for reshape_node in mm_node.users:
            if reshape_node.target != aten.reshape.default:
                continue

            # Since the reshape -> mm -> reshape pattern would be subsumed into
            # the fused op, we only match the patterns where the shape of the
            # second reshape matches the mm result produced by the fused op.
            matmul_input_node = cast(torch.fx.Node, node.args[0])
            B_node = cast(torch.fx.Node, mm_node.args[1])
            matmul_out_shape = torch.Size(
                [
                    *_get_tensor(matmul_input_node).shape[:-1],
                    _get_tensor(B_node).shape[-1],
                ]
            )
            if _get_tensor(reshape_node).shape != matmul_out_shape:
                continue
            matches.append([node, mm_node, reshape_node])
            # If for some rare reason mm_node is being reshaped by two
            # different reshape nodes, we only include mm_node once in the
            # parsing result.
            break

    matmuls = []
    for match in matches:
        mm_node = match[1]
        if mm_node.target == aten.mm.default:
            matmul = _Matmul.from_match(match)
            matmuls.append(matmul)
        elif mm_node.target == aten._scaled_mm.default:
            matmul = _ScaledMatmul.from_match(match)
            matmuls.append(matmul)
        else:
            raise AssertionError(
                "Expect the node's target to be either aten.mm.default or "
                f"aten._scaled_mm.default. Got {mm_node.target}."
            )
    return matmuls


def _find_consumer_matmuls(node: torch.fx.Node) -> List[_Matmul]:
    """
    Find the matmuls that use `node` as the lhs argument.
    """
    matmuls = []
    for user in node.users:
        # ND matmuls
        if user.target == aten.reshape.default:
            matmuls.extend(_find_reshape_mm_reshape(user))
        # 2D matmuls
        elif user.target == aten.mm.default:
            matmul = _Matmul.from_match(match=[user])
            matmuls.append(matmul)
        elif user.target == aten._scaled_mm.default:
            matmul = _ScaledMatmul.from_match([user])
            matmuls.append(matmul)
    return matmuls


def _insert_fused_all_gather_matmul(
    graph: torch.fx.Graph,
    matmuls: List[_Matmul],
    shard_node: torch.fx.Node,
    gather_dim: int,
    group_name: str,
) -> torch.fx.Node:
    mm_types = set(map(type, matmuls))
    assert len(mm_types) == 1
    mm_type = next(iter(mm_types))
    if mm_type == _Matmul:
        B_nodes = [matmul.B_node for matmul in matmuls]
        return graph.call_function(
            torch.ops.symm_mem.fused_all_gather_matmul.default,
            args=(shard_node, B_nodes, gather_dim, group_name),
        )
    elif mm_type == _ScaledMatmul:
        scaled_matmuls = cast(List[_ScaledMatmul], matmuls)
        return graph.call_function(
            torch.ops.symm_mem.fused_all_gather_scaled_matmul.default,
            args=(
                shard_node,
                [matmul.B_node for matmul in scaled_matmuls],
                scaled_matmuls[0].A_scale_node,
                [matmul.B_scale_node for matmul in scaled_matmuls],
                gather_dim,
                group_name,
                [matmul.bias_node for matmul in scaled_matmuls],
                [matmul.result_scale_node for matmul in scaled_matmuls],
                [matmul.out_dtype for matmul in scaled_matmuls],
                [matmul.use_fast_accum for matmul in scaled_matmuls],
            ),
        )
    else:
        raise AssertionError(f"Unexpected matmul match type: {mm_type}")


def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
    """
    Fuses the pattern

        A = all_gather_tensor(A_shard, gather_dim, group_name)
        C_0 = torch.matmul(A, B_0)
        C_1 = torch.matmul(A, B_1)
        C_2 = torch.matmul(A, B_2)
        ...

    into

        A, Cs = torch.ops.symm_mem.fused_all_gather_matmul(
            A_shard, [B_0, B_1, B_2, ...], gather_dim, group_name,
        )
    """
    if (
        not torch.distributed.is_available()
        or not torch.distributed.is_nccl_available()
    ):
        return

    c10d = torch.ops._c10d_functional
    from torch.distributed._symmetric_memory import (
        is_symm_mem_enabled_for_group,
        restride_A_shard_for_fused_all_gather_matmul,
    )

    shard_node, ag_node, ag_res_node, gather_dim, group_name = (
        all_gather.shard_node,
        all_gather.ag_node,
        all_gather.res_node,
        all_gather.gather_dim,
        all_gather.group_name,
    )

    if not is_symm_mem_enabled_for_group(group_name):
        return

    if gather_dim >= len(_get_tensor(shard_node).shape) - 1:
        # Decomposing the matmul on the K dimension is not supported
        return

    # Find consumer matmuls
    matmuls = _find_consumer_matmuls(ag_res_node)

    # The matmuls are only fusible if non-A args don't depend on the all-gather
    # result node
    matmuls = [
        matmul
        for matmul in matmuls
        if all_gather.res_node not in matmul.arg_ancestor_nodes
    ]

    if len(matmuls) == 0 or len(set(map(type, matmuls))) != 1:
        return

    # Fuse the all_gather_tensor with the eligible matmuls
    graph = ag_node.graph
    with graph.inserting_before(ag_node):
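        # Restride the A shard (only possible when fake-tensor metadata is
        # available) so the fused kernel sees the layout chosen by
        # restride_A_shard_for_fused_all_gather_matmul.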
565        if "val" in shard_node.meta:
566            restrided = restride_A_shard_for_fused_all_gather_matmul(
567                _get_tensor(shard_node),
568                gather_dim,
569            )
570            shard_node = graph.call_function(
571                inductor_prims.force_stride_order,
572                args=(shard_node, restrided.stride()),
573            )
574
575        fused_node = _insert_fused_all_gather_matmul(
576            graph, matmuls, shard_node, gather_dim, group_name
577        )
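        # The fused op returns (all-gathered A, [matmul outputs]); rewire the
        # original all-gather result and each matmul to the corresponding
        # getitem.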
        new_ag_node = graph.call_function(
            operator.getitem,
            args=(fused_node, 0),
        )
        new_out_nodes = graph.call_function(
            operator.getitem,
            args=(fused_node, 1),
        )
        for idx, matmul in enumerate(matmuls):
            new_out_node = graph.call_function(
                operator.getitem,
                args=(new_out_nodes, idx),
            )
            matmul.replace_with(new_out_node)
            matmul.erase()
        all_gather.replace_with(new_ag_node)
        all_gather.erase()

    # Raise the ancestors of non-A args that end up after fused_node (i.e.
    # those originally ordered between ag_res_node and the matmuls) so that
    # they come before fused_node.
    order = {node: idx for idx, node in enumerate(graph.nodes)}
    nodes_to_raise = sorted(
        {x for matmul in matmuls for x in matmul.arg_ancestor_nodes},
        key=lambda x: order[x],
    )
    for node in nodes_to_raise:
        if order[node] > order[fused_node]:
            fused_node.prepend(node)


def _find_producer_matmul(node: torch.fx.Node) -> Optional[_Matmul]:
    if node.target == aten.mm.default:
        return _Matmul.from_match(match=[node])
    elif node.target == aten._scaled_mm.default:
        return _ScaledMatmul.from_match(match=[node])
    elif node.target == aten.reshape.default:
        reshape_node_1 = node

        mm_node = reshape_node_1.args[0]
        assert isinstance(mm_node, torch.fx.Node)
        if mm_node.target not in (aten.mm.default, aten._scaled_mm.default):
            return None

        reshape_node_0 = mm_node.args[0]
        assert isinstance(reshape_node_0, torch.fx.Node)
        if reshape_node_0.target != aten.reshape.default:
            return None

        if mm_node.target == aten.mm.default:
            return _Matmul.from_match(match=[reshape_node_0, mm_node, reshape_node_1])
        elif mm_node.target == aten._scaled_mm.default:
            return _ScaledMatmul.from_match(
                match=[reshape_node_0, mm_node, reshape_node_1]
            )
    return None


def _insert_fused_matmul_reduce_scatter(
    graph: torch.fx.Graph,
    matmul: _Matmul,
    reduce_op: str,
    scatter_dim: int,
    group_name: str,
) -> torch.fx.Node:
    if type(matmul) == _Matmul:
        return graph.call_function(
            torch.ops.symm_mem.fused_matmul_reduce_scatter.default,
            args=(matmul.A_node, matmul.B_node, reduce_op, scatter_dim, group_name),
        )
    elif type(matmul) == _ScaledMatmul:
        return graph.call_function(
            torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default,
            args=(
                matmul.A_node,
                matmul.B_node,
                matmul.A_scale_node,
                matmul.B_scale_node,
                reduce_op,
                scatter_dim,
                group_name,
                matmul.bias_node,
                matmul.result_scale_node,
                matmul.out_dtype,
                matmul.use_fast_accum,
            ),
        )
    else:
        raise AssertionError(f"Unexpected matmul match type: {type(matmul)}")


def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
    """
    Fuses the pattern

        reduce_scatter_tensor(A @ B, scatter_dim, group_name)

    into

        torch.ops.symm_mem.fused_matmul_reduce_scatter(
            A, B, scatter_dim, group_name,
        )
    """
    if (
        not torch.distributed.is_available()
        or not torch.distributed.is_nccl_available()
    ):
        return

    c10d = torch.ops._c10d_functional
    from torch.distributed._symmetric_memory import (
        is_symm_mem_enabled_for_group,
        restride_A_for_fused_matmul_reduce_scatter,
    )

    input_node, rs_node, rs_res_node, reduce_op, scatter_dim, group_name = (
        reduce_scatter.input_node,
        reduce_scatter.rs_node,
        reduce_scatter.res_node,
        reduce_scatter.reduce_op,
        reduce_scatter.scatter_dim,
        reduce_scatter.group_name,
    )

    if not is_symm_mem_enabled_for_group(group_name):
        return

    # Currently fused_matmul_reduce_scatter doesn't return the matmul result,
    # so we can't apply the fusion if the matmul result is used by multiple
    # users. This is not a fundamental limitation of the fused op and can be
    # addressed if needed.
    if len(input_node.users) != 1:
        return

    matmul = _find_producer_matmul(input_node)
    if matmul is None:
        return

    if rs_res_node in matmul.arg_ancestor_nodes:
        return

    graph = rs_res_node.graph
    with graph.inserting_before(rs_res_node):
        if "val" in matmul.A_node.meta:
            restrided = restride_A_for_fused_matmul_reduce_scatter(
                _get_tensor(matmul.A_node),
                scatter_dim,
            )
            matmul.A_node = graph.call_function(
                inductor_prims.force_stride_order,
                args=(matmul.A_node, restrided.stride()),
            )

        fused_node = _insert_fused_matmul_reduce_scatter(
            graph,
            matmul,
            reduce_op,
            scatter_dim,
            group_name,
        )
        reduce_scatter.replace_with(fused_node)
        reduce_scatter.erase()
        matmul.erase()

    order = {node: idx for idx, node in enumerate(graph.nodes)}
    nodes_to_raise = sorted(
        matmul.arg_ancestor_nodes,
        key=lambda x: order[x],
    )
    for node in nodes_to_raise:
        if order[node] > order[fused_node]:
            fused_node.prepend(node)


def _get_node_to_ancestors(
    graph: torch.fx.Graph,
) -> Dict[torch.fx.Node, Set[torch.fx.Node]]:
    """
    Compute the ancestors for all nodes in a graph.
    """
    node_to_ancestors = defaultdict(set)
    for node in graph.nodes:
        node_to_ancestors[node] = set(node.all_input_nodes)
        for dep in node.all_input_nodes:
            node_to_ancestors[node] |= node_to_ancestors[dep]

    return node_to_ancestors


def _get_collective_to_overlappable_nodes(
    graph: torch.fx.Graph,
) -> Dict[torch.fx.Node, List[torch.fx.Node]]:
    """
    For each collective in the graph, find nodes that are neither ancestors nor
    descendants of the collective.
    """

    def is_collective(node) -> bool:
        # Only consider all-gather and reduce-scatter in the context of
        # micro-pipeline TP.
        return node.target in [
            torch.ops._c10d_functional.all_gather_into_tensor.default,
            torch.ops._c10d_functional.reduce_scatter_tensor.default,
        ]

    node_to_ancestors = _get_node_to_ancestors(graph)
    collective_to_overlappable_nodes = defaultdict(list)
    for node in graph.nodes:
        if not is_collective(node):
            continue
        for x in graph.nodes:
            if (
                node not in node_to_ancestors[x]
                and x not in node_to_ancestors[node]
                and x.op == "call_function"
            ):
                collective_to_overlappable_nodes[node].append(x)

    return collective_to_overlappable_nodes


def _get_unexposed_collectives(graph: torch.fx.Graph) -> List[torch.fx.Node]:
    """
    Find all unexposed collectives in the graph.

    Because we don't have runtime estimates, this function makes a rough
    estimation based on the following strong/hand-wavy assumptions:

    - Only a predefined set of "compute intensive" operations can hide a collective.
806    - Any "compute intensive" operation can hide exactly one collective.
807    """
808
809    def _is_compute_intensive(node: torch.fx.Node) -> bool:
810        return node.target in [torch.ops.aten.mm.default]
811
812    collective_to_overlapping_candidates = defaultdict(list)
813    available_nodes = set()
814    collective_to_overlappable_nodes = _get_collective_to_overlappable_nodes(graph)
815    for collective, overlappable_nodes in collective_to_overlappable_nodes.items():
816        candidates = [x for x in overlappable_nodes if _is_compute_intensive(x)]
817        collective_to_overlapping_candidates[collective] = candidates
818        available_nodes |= set(candidates)
819
820    unexposed_collectives = []
821    for (
822        collective,
823        overlapping_candidates,
824    ) in collective_to_overlapping_candidates.items():
825        # Each collective consumes exactly one overlapping candidate
826        for x in overlapping_candidates:
827            if x in available_nodes:
828                unexposed_collectives.append(collective)
829                available_nodes.remove(x)
830                break
831    return unexposed_collectives
832
833
834def micro_pipeline_tp_pass(graph: torch.fx.Graph):
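    """
    Find all-gather and reduce-scatter patterns in the graph and fuse each of
    them with its consumer/producer matmuls, skipping collectives that can
    already be hidden by simple compute/comm overlap when that reordering pass
    is enabled.
    """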
    all_gathers = find_all_gather_patterns(graph)
    reduce_scatters = find_reduce_scatter_patterns(graph)

    # When a collective can be hidden through either simple overlapping or
    # micro-pipeline TP, we prefer simple overlapping to avoid the overhead
    # associated with decomposition. If reorder_for_compute_comm_overlap is
    # enabled, we identify collectives that can be hidden through simple
    # overlapping and exclude them from micro-pipeline TP candidates.
    if config.reorder_for_compute_comm_overlap:
        unexposed_collectives = _get_unexposed_collectives(graph)
        all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives]
        reduce_scatters = [
            x for x in reduce_scatters if x.rs_node not in unexposed_collectives
        ]

    for all_gather in all_gathers:
        fuse_all_gather_matmul(all_gather)

    for reduce_scatter in reduce_scatters:
        fuse_matmul_reduce_scatter(reduce_scatter)
