# mypy: allow-untyped-defs
# pyre-strict
from __future__ import annotations

import heapq
import operator
import sys
from collections import defaultdict
from typing import Dict, List, Set, TYPE_CHECKING

import torch

from . import config, ir
from .dependencies import WeakDep
from .utils import (
    contains_collective,
    contains_wait,
    find_recursive_deps_of_node,
    find_recursive_users_of_node,
    is_collective,
    is_fallback_op,
    is_wait,
)


overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")

if TYPE_CHECKING:
    from .scheduler import BaseSchedulerNode


def sink_waits(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
    """
    Greedily schedules waits as late as possible.
    """
    return _schedule_for_comm(
        snodes, raise_comms=False, sink_waits=True, reorder_for_overlap=False
    )


def raise_comms(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
    """
    Greedily schedules comms as early as possible.
    """
    return _schedule_for_comm(
        snodes, raise_comms=True, sink_waits=False, reorder_for_overlap=False
    )


def reorder_compute_for_overlap(
    snodes: List[BaseSchedulerNode],
) -> List[BaseSchedulerNode]:
    """
    This achieves the following overall scheduling procedure:
        Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes
            that are required for comm N + 1 but do not depend on comm N, so that they run
            concurrently with comm N.
        Step 2: If those compute nodes are sufficient to overlap comm N, we're done.
            Otherwise, we look elsewhere for compute that overlaps with comm N,
            prioritizing compute nodes that are needed sooner.
        Step 3: We schedule the compute nodes that depend on comm N and are required for comm N + 1.
        Step 4: We schedule comm N + 1.
        Repeat this for subsequent comm nodes.
    """
    return _schedule_for_comm(
        snodes, raise_comms=True, sink_waits=True, reorder_for_overlap=True
    )

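# Illustrative trace (hypothetical node names; not executed): with collectives comm0
# and comm1, compute c0 feeding comm0, compute c1/c2 feeding comm1 but independent of
# comm0, and compute c3 that consumes comm0's result and feeds comm1, the procedure
# above yields roughly:
#
#     c0,                      # produces comm0's input
#     comm0, c1, c2,           # Steps 1-2: comm0 overlapped with compute for comm1
#     wait(comm0), c3,         # Step 3: compute that depends on comm0 and feeds comm1
#     comm1, ...               # Step 4: next collective, and so on
#
# Step 2 would pull in additional unrelated compute only if c1 and c2 were too cheap
# to hide comm0's estimated runtime.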

def _schedule_for_comm(
    snodes: List[BaseSchedulerNode],
    raise_comms: bool,
    sink_waits: bool,
    reorder_for_overlap: bool,
) -> List[BaseSchedulerNode]:
    """
    Schedule `snodes` for various comm optimization objectives.

    Args:
        snodes: the nodes to be scheduled.
        raise_comms: whether to greedily schedule collectives as early as possible
        sink_waits: whether to greedily schedule waits as late as possible
        reorder_for_overlap: whether to reorder compute nodes to
            optimize for compute/communication overlapping.

    Returns:
        The new schedule order.

    Some notes on the synergy between different options:
        - `raise_comms` provides more overlapping opportunities for `reorder_for_overlap`.
        - When both `raise_comms` and `sink_waits` are `True`, `raise_comms` is prioritized.
    """
    # We assign each node a tuple of scores (score_0, score_1, score_2),
    # decreasing in importance, with a lower value indicating a higher ranking:
    #
    # - score_0: the lowest comm_idx among the comm nodes that the node blocks.
    #   If a node doesn't block any comm nodes, its score_0 is set to
    #   sys.maxsize. This score ensures that comm nodes get scheduled as early as
    #   possible.
    # - score_1: 1 if the node is a wait node, 0 otherwise. This score ensures
    #   that wait nodes are deferred as late as possible.
    # - score_2: the index of the node in the original topological order. This
    #   score provides stability in case of ties.
    #
    # When only raise_comms is True, only score_0 and score_2 are considered.
    # When only sink_waits is True, only score_1 and score_2 are considered.
    # When neither is True, the original order is yielded.
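    #
    # Illustrative ranking (hypothetical values; not executed): the Runnable heap
    # below orders nodes by lexicographic comparison of these tuples, so with four
    # simultaneously ready nodes scoring
    #
    #     (0, 0, 5)            # blocks comm 0      -> scheduled first
    #     (sys.maxsize, 0, 1)  # ordinary compute   -> next, by original order
    #     (sys.maxsize, 0, 3)  # ordinary compute
    #     (sys.maxsize, 1, 2)  # wait node          -> deferred to last
    #
    # the comm-blocking node wins, ties fall back to original order, and the wait
    # node is pushed to the end.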
    buf_name_to_snode = {}
    name_to_fused_node = {}
    scores_0, scores_1, scores_2 = {}, {}, {}
    for idx, snode in enumerate(snodes):
        for buf_name in snode.get_buffer_names():
            buf_name_to_snode[buf_name] = snode

        for op_name in snode.get_operation_names():
            name_to_fused_node[op_name] = snode
        name_to_fused_node[snode.get_name()] = snode

        node_name = snode.get_name()
        scores_0[node_name] = sys.maxsize
        scores_1[node_name] = 0
        scores_2[node_name] = idx

    comm_idx = 0
    for snode in snodes:
        if raise_comms and contains_collective(snode):
            scores_0[snode.get_name()] = comm_idx
            for anc in snode.ancestors:
                anc_fused_name = name_to_fused_node[anc].get_name()
                scores_0[anc_fused_name] = min(scores_0[anc_fused_name], comm_idx)
            comm_idx += 1
        elif sink_waits and contains_wait(snode):
            scores_1[snode.get_name()] = 1

    class Runnable:
        def __init__(self, snode) -> None:
            self.snode = snode
            name = next(iter(snode.get_operation_names()))
            fused_name = name_to_fused_node[name].get_name()
            self.score = (
                scores_0[fused_name],
                scores_1[fused_name],
                scores_2[fused_name],
            )

        def __lt__(self, other):
            return self.score < other.score

    unmet_deps: Dict[BaseSchedulerNode, Set[str]] = {
        snode: {dep.name for dep in snode.unmet_dependencies} for snode in snodes
    }

    ready: List[Runnable] = []
    buffer_users: Dict[str, Set[BaseSchedulerNode]] = defaultdict(set)
    snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes}

    for snode, deps in unmet_deps.items():
        if len(deps) == 0:
            heapq.heappush(ready, Runnable(snode))
        for dep in deps:
            buffer_users[dep].add(snode)

    scheduled = []

    def schedule(snode):
        """
        Schedules `snode` and puts all unblocked nodes onto the ready queue.
        """
        scheduled.append(snode)
        for buf_name in snode.get_buffer_names():
            for user in buffer_users[buf_name]:
                unmet_deps[user].remove(buf_name)
                if len(unmet_deps[user]) == 0:
                    heapq.heappush(ready, Runnable(user))

    def get_overlapping_candidate():
        """
        Returns the next node in the ready queue that is neither a collective
        nor a wait.
        """
        candidates = [
            x
            for x in ready
            if not contains_collective(x.snode) and not contains_wait(x.snode)
        ]
        if len(candidates) == 0:
            return None
        return min(candidates, key=lambda x: x.score)

    def schedule_collective_for_overlap(snode):
        """
        Schedules collective node `snode`, along with one or more compute nodes
        to overlap with it. The strategy is described in the docstring of
        `reorder_compute_for_overlap`.
        """
        assert contains_collective(snode)
        schedule(snode)

        collective_cost = snode_to_cost[snode]
        while (
            collective_cost > 0
            and (candidate := get_overlapping_candidate()) is not None
        ):
            ready.remove(candidate)
            schedule(candidate.snode)
            collective_cost -= snode_to_cost[candidate.snode]
        heapq.heapify(ready)
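
    # Illustrative cost accounting (hypothetical ns values): if the popped collective
    # has an estimated runtime of 10_000 and the ready non-collective, non-wait
    # candidates cost 4_000, 3_000, and 5_000 (in score order), the loop above
    # schedules them one by one while the remaining cost stays positive
    # (10_000 -> 6_000 -> 3_000 -> -2_000) and then stops, leaving any remaining
    # candidates on the heap.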

    while len(ready):
        snode = heapq.heappop(ready).snode
        if reorder_for_overlap and contains_collective(snode):
            schedule_collective_for_overlap(snode)
        else:
            schedule(snode)

    for snode, deps in unmet_deps.items():
        assert len(deps) == 0, (
            "Detected unscheduled nodes. "
            f"Nodes with unmet dependencies: {unmet_deps}"
        )
    return scheduled


def decide_global_ordering_of_comms(
    nodes: List[BaseSchedulerNode], name_to_buf, name_to_fused_node
) -> List[BaseSchedulerNode]:
    """
    Decide the global ordering of comms by enforcing the ordering present in the input graph
    (which might not match the ordering of the eager-mode program).
    TODO: Come up with a better approach
    """
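    # Illustrative effect (hypothetical node names): if the collectives appear in the
    # graph as [comm0, comm1, comm2], the loop below adds a WeakDep from comm1 onto
    # every buffer of comm0 and from comm2 onto every buffer of comm1, so subsequent
    # passes cannot reorder the collectives relative to one another.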
    # If FSDP2 is used, we apply FSDP-specific passes.
    if any(
        is_fallback_op(
            x.node,
            {
                torch.ops.fsdp.all_gather_copy_in.default,
                torch.ops.fsdp.chunk_cat.default,
            },
        )
        for x in nodes
    ):
        nodes = enforce_comm_ordering_for_fsdp(nodes, name_to_buf, name_to_fused_node)

    comm_nodes = [n for n in nodes if contains_collective(n)]

    for i in range(1, len(comm_nodes)):
        # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm
        mutating_buf = next(iter(comm_nodes[i].get_buffer_names()))
        for buf in comm_nodes[i - 1].get_buffer_names():
            comm_nodes[i].add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf))

    return nodes


def estimate_op_runtime(snode: BaseSchedulerNode) -> float:
    """
    Returns estimated op runtime in nanoseconds (ns)
    """
    if config.estimate_op_runtime == "default":
        runtime = snode.get_estimated_runtime()
    else:
        assert callable(config.estimate_op_runtime)
        runtime = config.estimate_op_runtime(snode)
    return runtime

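# Illustrative override (assumes a user-supplied cost model; the lambda below is a
# hypothetical placeholder):
#
#     import torch._inductor.config as inductor_config
#
#     # Treat every op as costing 1_000 ns instead of using the default estimate.
#     inductor_config.estimate_op_runtime = lambda snode: 1_000.0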

def node_summary(snode):
    detail = ""
    if isinstance(snode.node, ir.ExternKernelOut):
        detail = f" ({snode.node.python_kernel_name})"
    out_tensor_info = ""
    if (
        hasattr(snode.node, "layout")
        and hasattr(snode.node.layout, "size")
        and hasattr(snode.node.layout, "stride")
    ):
        out_tensor_info = (
            f" (size={snode.node.layout.size}, stride={snode.node.layout.stride})"
        )
    node_name = ""
    if hasattr(snode.node, "name"):
        node_name = snode.node.name
    return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name})"

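# Illustrative summary string (hypothetical kernel and buffer names), assembled from
# the pieces above:
#
#     ExternKernelOut (extern_kernels.mm) (size=[1024, 1024], stride=[1024, 1]) (buf3)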

def visualize_overlap(order):
    total_est_runtime: float = 0.0
    cur_comm_node = None
    for snode in order:
        if cur_comm_node is None:
            if contains_collective(snode):
                total_est_runtime += estimate_op_runtime(snode)
                cur_comm_node = snode.node
            elif is_wait(snode.node):
                raise AssertionError(
                    "Wait is not expected when there is no collective running"
                )
            else:  # exposed compute op
                total_est_runtime += estimate_op_runtime(snode)
            overlap_log.debug(f"{node_summary(snode)}")  # noqa: G004
        else:  # cur_comm_node is not None
            if contains_collective(snode):
                raise AssertionError(
                    "Found two collectives running at the same time. "
                    "`visualize_overlap` needs to be updated to handle this case"
                )
            elif is_wait(snode.node):  # end of this comm op
                overlap_log.debug(f"{node_summary(snode)}")  # noqa: G004
                cur_comm_node = None
            else:  # overlapped compute op
                overlap_log.debug(f"| {node_summary(snode)}")  # noqa: G004
    overlap_log.debug(
        f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}"  # noqa: G004
    )

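# Illustrative log shape (hypothetical schedule): exposed ops are printed flush left,
# compute overlapped with an in-flight collective is prefixed with "| ", and a wait
# marks the end of that collective:
#
#     <compute node summary>        # exposed compute
#     <collective node summary>     # collective starts
#     | <compute node summary>      # compute overlapped with the collective
#     <wait node summary>           # collective ends
#     Est. runtime (ms): ...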

def reorder_compute_and_comm_for_overlap(
    snodes: List[BaseSchedulerNode],
) -> List[BaseSchedulerNode]:
    order = snodes

    for p in config.reorder_for_compute_comm_overlap_passes:
        if isinstance(p, str) and p in globals():
            p = globals()[p]  # it is a builtin pass
        if torch.distributed.get_rank() == 0:
            overlap_log.debug(
                f"==== Visualize overlap before reordering pass {p} ===="  # noqa: G004
            )
            try:
                visualize_overlap(order)
            except Exception as e:
                overlap_log.debug(str(e))
        order = p(order)  # type: ignore[operator]
        if torch.distributed.get_rank() == 0:
            overlap_log.debug(
                f"==== Visualize overlap after reordering pass {p} ===="  # noqa: G004
            )
            try:
                visualize_overlap(order)
            except Exception as e:
                overlap_log.debug(str(e))
    return order

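# Illustrative configuration sketch (not the default; assumes the standard inductor
# config knobs for this feature, and the string names refer to the passes defined in
# this module):
#
#     import torch._inductor.config as inductor_config
#
#     inductor_config.reorder_for_compute_comm_overlap = True
#     inductor_config.reorder_for_compute_comm_overlap_passes = [
#         "raise_comms",
#         "sink_waits",
#         "reorder_compute_for_overlap",
#     ]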

def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None:
    try:
        import torch.distributed._composable.fsdp._fsdp_collectives

        assert torch.distributed.is_available()
        # Assert existence of these ops
        assert (
            torch.ops._c10d_functional.all_gather_into_tensor
            and torch.ops._c10d_functional.all_gather_into_tensor_out
        )
    except (ImportError, AttributeError, AssertionError):
        return

    from .pattern_matcher import (
        CallFunction,
        KeywordArg,
        Match,
        PatternMatcherPass,
        register_graph_pattern,
    )

    """
    all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...);
    getitem = all_gather_copy_in[0];
    (getitem_1 = all_gather_copy_in[1];)  # optional

    all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem, ...);

    ->

    all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...);
    getitem = all_gather_copy_in[0];
    getitem_1 = all_gather_copy_in[1];

    all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor_out.default(getitem, ..., out=getitem_1);
    """

    def remove_unused_getitem(g):
        # Remove `getitem_X = all_gather_copy_in[1]` which is never used.
        node_list = list(g.nodes)
        for n in node_list:
            if (
                n.target == operator.getitem
                and n.args[0].target is torch.ops.fsdp.all_gather_copy_in.default
                and n.args[1] == 1
            ):
                g.erase_node(n)

    graph_pass = PatternMatcherPass()

    @register_graph_pattern(
        CallFunction(
            torch.ops._c10d_functional.all_gather_into_tensor.default,
            CallFunction(
                operator.getitem,
                CallFunction(
                    torch.ops.fsdp.all_gather_copy_in.default,
                    KeywordArg("all_gather_inputs"),
                    KeywordArg("inp_split_sizes"),
                    KeywordArg("all_gather_input_numel"),
                    KeywordArg("world_size"),
                    KeywordArg("rank"),
                    KeywordArg("dtype"),
                    KeywordArg("device"),
                ),
                KeywordArg("item_idx"),
            ),
            KeywordArg("group_size"),
            KeywordArg("group_name"),
        ),
        pass_dict=graph_pass,
        extra_check=lambda match: match.kwargs["item_idx"] == 0,
    )
    def reinplace_all_gather(match: Match, *args, **kwargs):
        def repl(
            *args,
        ):
            copy_in_args = args[:-2]
            group_size = args[-2]
            group_name = args[-1]
            all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(
                *copy_in_args
            )
            getitem = all_gather_copy_in[0]
            getitem_1 = all_gather_copy_in[1]
            all_gather_into_tensor = (
                torch.ops._c10d_functional.all_gather_into_tensor_out.default(
                    getitem, group_size, group_name, out=getitem_1
                )
            )
            return all_gather_into_tensor

        match.replace_by_example(
            repl,
            [
                kwargs["all_gather_inputs"],
                kwargs["inp_split_sizes"],
                kwargs["all_gather_input_numel"],
                kwargs["world_size"],
                kwargs["rank"],
                kwargs["dtype"],
                kwargs["device"],
                kwargs["group_size"],
                kwargs["group_name"],
            ],
        )

    remove_unused_getitem(graph)
    graph_pass.apply(graph)  # type: ignore[arg-type]


def get_op_idx(snode):
    assert not isinstance(
        snode,
        (
            torch._inductor.scheduler.FusedSchedulerNode,
            torch._inductor.scheduler.GroupedSchedulerNode,
        ),
    )
    return int(snode.get_name()[2:])

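# Note on get_op_idx: base scheduler nodes are named "op<N>" (e.g. a hypothetical
# "op12"), so stripping the two-character "op" prefix recovers the original
# operation index used for ordering below.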

def enforce_comm_ordering_for_fsdp(
    snodes: List[torch._inductor.scheduler.BaseSchedulerNode],
    name_to_buf: Dict[str, torch._inductor.scheduler.SchedulerBuffer],
    name_to_fused_node: Dict[str, BaseSchedulerNode],
) -> List[torch._inductor.scheduler.BaseSchedulerNode]:
    from . import scheduler

    new_order: list[BaseSchedulerNode] = []
    scheduled = set()
    ag_exists = False
    rs_exists = False
    ag_grouped_node_to_wait_grouped_node = {}
    rs_grouped_node_to_wait_grouped_node = {}
    snode_name_to_final_snode = {}

    def _create_group_node(snodes_to_group):
        group_node = scheduler.GroupedSchedulerNode.create(snodes_to_group)
        for snode in snodes_to_group:
            snode_name_to_final_snode[snode.get_name()] = group_node
        snode_name_to_final_snode[group_node.get_name()] = group_node
        return group_node

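    # Illustrative grouping (hypothetical node names): for one FSDP all-gather the
    # loop below produces two GroupedSchedulerNodes, roughly
    #
    #     ag_group_node      = [cast, copy_in, getitem, all_gather]
    #     ag_wait_group_node = [wait_tensor, copy_out, set_]
    #
    # and likewise a (copy-in + comm) group and a wait group for each reduce-scatter.
    # Each grouped node is scheduled as a unit, which is what lets the WeakDeps added
    # at the end of this function serialize consecutive collectives.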
    # Create grouped nodes for specific sets of ops
    for snode in snodes:
        # Case 1: Handle AllGather
        if is_collective(
            snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor_out.default
        ) and any(
            is_fallback_op(
                name_to_fused_node[x].node, torch.ops.fsdp.all_gather_copy_in.default
            )
            for x in snode.ancestors
        ):
            ag_exists = True
            ag_snode = snode
            ag_related_snode_set: set[scheduler.BaseSchedulerNode] = set()

            # Find the "cast + copy_in + getitem + all_gather" code block
            find_recursive_deps_of_node(
                ag_snode,
                ag_related_snode_set,
                name_to_buf,
                name_to_fused_node,
            )

            # Find the "all_gather + all_gather_wait_tensor + copy_out + set_" code block
            allowed_ops = {
                torch.ops._c10d_functional.all_gather_into_tensor_out.default,
                torch.ops._c10d_functional.wait_tensor.default,
                torch.ops.fsdp.split_with_sizes_copy.default,
                torch.ops.aten.set_.source_Tensor,
            }
            find_recursive_users_of_node(
                ag_snode,
                ag_related_snode_set,
                name_to_buf,
                name_to_fused_node,
                criteria_cb=lambda x: not (
                    isinstance(x, scheduler.NopKernelSchedulerNode)
                    or (
                        isinstance(x, scheduler.ExternKernelSchedulerNode)
                        and x.node.op_overload in allowed_ops  # type: ignore[union-attr]
                    )
                ),
            )

            # sort nodes by original operation order
            ag_related_snodes = sorted(
                ag_related_snode_set, key=lambda x: get_op_idx(x)
            )

            # In the "reuse layer" case, some ops in the 2nd all-gather code block could also
            # depend on ops in the 1st all-gather code block, and we don't want to group them together.
            end_idx_of_current_ag_block = len(ag_related_snodes)
            copy_out_count = 0
            for i in range(len(ag_related_snodes)):
                cur_snode = ag_related_snodes[i]
                if is_fallback_op(
                    cur_snode.node, torch.ops.fsdp.split_with_sizes_copy.default
                ):
                    copy_out_count += 1
                if copy_out_count > 1:
                    end_idx_of_current_ag_block = i
                    break

            ag_related_snodes = ag_related_snodes[:end_idx_of_current_ag_block]

            # Group "cast + copy_in + getitem + all_gather" into one GroupedSchedulerNode
            wait_node_idx = None
            for i in range(len(ag_related_snodes) - 1):
                if isinstance(ag_related_snodes[i + 1].node, ir._WaitKernel):
                    wait_node_idx = i + 1
                    break
            assert wait_node_idx is not None
            ag_group_node = _create_group_node(ag_related_snodes[:wait_node_idx])

            # Group "all_gather_wait_tensor + copy_out + set_" into one GroupedSchedulerNode
            ag_wait_group_node = _create_group_node(ag_related_snodes[wait_node_idx:])

            ag_grouped_node_to_wait_grouped_node[ag_group_node] = ag_wait_group_node

        # Case 2: Handle ReduceScatter
        elif is_fallback_op(snode.node, torch.ops.fsdp.chunk_cat.default):
            rs_exists = True
            rs_snode = snode

            # Find the "reduce_scatter copy-in + reduce_scatter comm + reduce_scatter wait" code block
            rs_related_snode_set: set[scheduler.BaseSchedulerNode] = set()
            find_recursive_users_of_node(
                rs_snode,
                rs_related_snode_set,
                name_to_buf,
                name_to_fused_node,
            )

            # sort nodes by original operation order
            rs_related_snodes = sorted(
                rs_related_snode_set, key=lambda x: get_op_idx(x)
            )

            # Group "reduce_scatter copy-in + reduce_scatter comm" into one GroupedSchedulerNode
            wait_node_idx = None
            for i in range(len(rs_related_snodes) - 1):
                if isinstance(rs_related_snodes[i + 1].node, ir._WaitKernel):
                    wait_node_idx = i + 1
                    break
            assert wait_node_idx is not None
            rs_group_node = _create_group_node(rs_related_snodes[:wait_node_idx])

            # Group "reduce_scatter wait + related output nodes" into one GroupedSchedulerNode
            rs_wait_group_node = _create_group_node(rs_related_snodes[wait_node_idx:])

            rs_grouped_node_to_wait_grouped_node[rs_group_node] = rs_wait_group_node

    assert len(snode_name_to_final_snode) > 0
    if ag_exists:
        assert len(ag_grouped_node_to_wait_grouped_node) > 0
    if rs_exists:
        assert len(rs_grouped_node_to_wait_grouped_node) > 0

    # Build the new node schedule, taking GroupedSchedulerNode into account
    for snode in snodes:
        if snode.get_name() in snode_name_to_final_snode:
            snode = snode_name_to_final_snode[snode.get_name()]
        if snode in scheduled:
            continue
        new_order.append(snode)
        scheduled.add(snode)

    # Enforce AllGather ordering: previous AllGather's "wait then copy_out" group node must run
    # before next AllGather's "copy_in then AG" group node
    prev_ag_wait = None
    for ag_group_node, wait_group_node in ag_grouped_node_to_wait_grouped_node.items():
        if prev_ag_wait is not None:
            mutating_buf = next(iter(ag_group_node.get_buffer_names()))
            for o in prev_ag_wait.get_outputs():
                ag_group_node.add_fake_dep(
                    WeakDep(o.get_name(), mutating_buf=mutating_buf)
                )
        prev_ag_wait = wait_group_node

    # Enforce ReduceScatter ordering: previous ReduceScatter's "wait" group node must run
    # before next ReduceScatter's "copy_in then RS" group node
    prev_rs_wait = None
    for rs_group_node, wait_group_node in rs_grouped_node_to_wait_grouped_node.items():
        if prev_rs_wait is not None:
            mutating_buf = next(iter(rs_group_node.get_buffer_names()))
            for o in prev_rs_wait.get_outputs():
                rs_group_node.add_fake_dep(
                    WeakDep(o.get_name(), mutating_buf=mutating_buf)
                )
        prev_rs_wait = wait_group_node

    return new_order  # type: ignore[return-value]