xref: /aosp_15_r20/external/pytorch/torch/_inductor/scheduler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: disallow-untyped-defs
2from __future__ import annotations
3
4import collections
5import dataclasses
6import functools
7import itertools
8import logging
9import math
10import operator
11import os
12import pprint
13import textwrap
14import traceback
15import typing
16from typing import (
17    Any,
18    Callable,
19    Counter,
20    DefaultDict,
21    Dict,
22    Generic,
23    List,
24    Optional,
25    Sequence,
26    Set,
27    Tuple,
28    TypeVar,
29    Union,
30)
31
32import sympy
33
34import torch
35import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
36from torch._dynamo.utils import counters, dynamo_timed
37from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
38from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
39from torch.utils._ordered_set import OrderedSet
40from torch.utils._sympy.symbol import free_symbol_is_type, SymT
41from torch.utils._triton import has_triton
42
43from . import comms, config, dependencies, ir, metrics
44from .codecache import write_text
45from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel
46from .comm_analysis import estimate_nccl_collective_runtime
47from .dependencies import Dep, MemoryDep, StarDep, WeakDep
48from .ir import ComputedBuffer, MultiOutput, MultiOutputLayout
49from .loop_body import LoopBody
50from .runtime.runtime_utils import green_text, red_text
51from .sizevars import SimplifyIndexing
52from .utils import (
53    cache_on_self,
54    cmp,
55    device_need_guard,
56    get_device_tflops,
57    get_dtype_size,
58    get_gpu_dram_gbps,
59    IndentedBuffer,
60    is_collective,
61    is_gpu,
62    is_wait,
63    sympy_product,
64)
65from .virtualized import V
66
67
68log = logging.getLogger(__name__)
69fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
70loop_ordering_log = torch._logging.getArtifactLogger(__name__, "loop_ordering")
71
72
73@dataclasses.dataclass
74class SchedulerBuffer:
75    scheduler: Scheduler
76    node: ir.Buffer
77    defining_op: BaseSchedulerNode
78    users: List[NodeUser] = dataclasses.field(default_factory=list)
79
80    def __hash__(self) -> int:
81        return hash(self.node.name)
82
83    def debug_str(self) -> str:
84        result = IndentedBuffer()
85        name = self.get_name()
86        result.writeline(f"{name}: {type(self.node).__name__}")
87        result.writeline(f"{name}.layout = {self.node.layout}")
88        if self.get_aliases():
89            result.writeline(f"{name}.aliases = {pformat(self.get_aliases())}")
90        if self.get_mutations():
91            result.writeline(f"{name}.mutations = {pformat(self.get_mutations())}")
92
93        if len(self.users) <= 1:
94            result.writeline(f"{name}.users = {self.users}")
95        else:
96            result.writeline(f"{name}.users = [")
97            with result.indent(1):
98                for user in self.users:
99                    result.writeline(f"{user},")
100            result.writeline("]")
101        return result.getrawvalue()
102
103    def get_name(self) -> str:
104        return self.node.get_name()
105
106    def allocate(self) -> None:
107        assert self.node is not None
108        if not self.node.should_allocate():
109            return
110
111        if self.node.get_inputs_that_alias_output() or self.node.get_mutation_names():
112            V.graph.wrapper_code.codegen_allocation(self.node)
113            return
114
115        # hacky check for whether V.kernel is a real kernel or a NullHandler
116        if (
117            hasattr(V.kernel, "args")
118            and self.get_name() in V.kernel.inplace_update_buffers
119        ):
120            V.graph.wrapper_code.codegen_inplace_reuse(
121                self.scheduler.name_to_buf[
122                    V.kernel.inplace_update_buffers[self.get_name()]
123                ].node,
124                self.node,
125            )
126        else:
127            V.graph.wrapper_code.codegen_allocation(self.node)
128
129    def can_free(self) -> bool:
130        # There's no real allocated buffer, no need to free it
131        assert self.node is not None
132        if isinstance(self.node.layout, ir.NoneLayout):
133            return False
134        for use in self.users:
135            if isinstance(use.node, OutputNode):
136                return False
137        return True
138
139    def set_users(self, users: List[NodeUser]) -> None:
140        # deduplicate
141        result: Dict[int, NodeUser] = {}
142        for use in users:
143            if id(use.node) in result:
144                result[id(use.node)] = use.merge(result[id(use.node)])
145            else:
146                result[id(use.node)] = use
147        self.users = list(result.values())
148
149    def get_aliases(self) -> Sequence[str]:
150        assert self.node is not None
151        return self.node.get_inputs_that_alias_output()
152
153    def get_mutations(self) -> List[str]:
154        assert self.node is not None
155        return self.node.get_mutation_names()
156
157
158class BaseSchedulerNode:
159    group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]]
160    read_writes: dependencies.ReadWrites
161    unmet_dependencies: OrderedSet[Dep]
162    # .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
163    # e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node
164    # in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3.
165    # For non-"grouped" nodes (i.e. regular SchedulerNode),
166    # .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`.
167    min_order: int
168    max_order: int
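    # Illustrative example (a sketch): assuming op_1 and op_3 are SchedulerNodes
    # sitting at positions 1 and 3 of `scheduler.nodes`, then after fusion
    #
    #   fused = FusedSchedulerNode.fuse(op_1, op_3)
    #   assert fused.min_order == 1 and fused.max_order == 3
    #
    # while a plain SchedulerNode at position X has min_order == max_order == X.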
169
170    def __init__(self, scheduler: Scheduler) -> None:
171        self.scheduler: Scheduler = scheduler
172
173    def _init_from_node(self, node: ir.Operation) -> None:
174        self.node: Optional[ir.Operation] = node
175        self.ancestors: OrderedSet[str] = OrderedSet()
176        self.last_usage: OrderedSet[
177            str
178        ] = OrderedSet()  # buffers that won't be used after this kernel
179        self.written = False
180        self.outputs: List[SchedulerBuffer] = [
181            SchedulerBuffer(
182                scheduler=self.scheduler,
183                node=output,
184                defining_op=self,
185            )
186            for output in node.get_outputs()
187        ]
188        self.outputs_by_name: Dict[str, SchedulerBuffer] = {
189            buf.get_name(): buf for buf in self.outputs
190        }
191
192    def __repr__(self) -> str:
193        return f"{type(self).__name__}(name={self.get_name()!r})"
194
195    def debug_str(self) -> str:
196        """Longer form printout for trace logs"""
197        name = self.get_name()
198        buf = IndentedBuffer()
199        buf.splice(
200            f"""\
201{name}: {type(self).__name__}({type(getattr(self, 'node', None)).__name__})
202{name}.writes = {pformat(self.read_writes.writes)}
203{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}
204{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}
205{name}.outputs = [
206        """
207        )
208        with buf.indent():
209            for out in self.get_outputs():
210                buf.splice(out.debug_str())
211        buf.writeline("]")
212
213        try:
214            buf.splice(self.debug_str_extra())
215        except Exception:
216            log.warning("Ignoring error in debug_str()", exc_info=True)
217
218        return buf.getrawvalue().rstrip()
219
220    def debug_str_extra(self) -> str:
221        return ""
222
223    def debug_str_short(self) -> str:
224        maybe_data = getattr(self.node, "data", None)
225        data_str = ""
226        if isinstance(maybe_data, torch._inductor.ir.Pointwise):
227            data_str = ", " + maybe_data.str_helper(
228                [maybe_data.get_size()], shorten=False, multiline=False
229            )
230        elif isinstance(maybe_data, torch._inductor.ir.Reduction):
231            data_str = ", " + maybe_data.str_helper(
232                [maybe_data.get_reduction_size(), maybe_data.get_reduction_type()],
233                shorten=False,
234                multiline=False,
235            )
236        return f"{self}{data_str}"
237
238    def log_details(self) -> None:
239        log.info(
240            "%s: unmet_dependencies = %s, writes = %s",
241            self,
242            self.unmet_dependencies,
243            self.read_writes.writes,
244        )
245
246    def reorder_loops_by_dep_pair(
247        self, self_dep: MemoryDep, other_dep: MemoryDep
248    ) -> None:
249        return
250
251    def update_mutated_names(self, renames: Dict[str, str]) -> None:
252        self.set_read_writes(self.read_writes.rename(renames))
253
254    def add_fake_dep(self, dep: Dep) -> None:
255        self.set_read_writes(self.read_writes.with_read(dep))
256
257    def has_aliasing_or_mutation(self) -> bool:
258        return any(
259            buf.get_aliases() or buf.get_mutations() for buf in self.get_outputs()
260        )
261
262    def set_read_writes(self, rw: dependencies.ReadWrites) -> None:
263        self.read_writes = rw
264        self.unmet_dependencies = self.read_writes.reads
265        self.prune_deps()
266
267    def set_last_usage(
268        self, future_used_buffers: OrderedSet[str], mutation_real_name: Dict[str, str]
269    ) -> None:
270        used_buffers = self.used_or_aliased_buffer_names()
271        used_buffers = OrderedSet([mutation_real_name.get(k, k) for k in used_buffers])
272        self.last_usage = used_buffers - future_used_buffers
273
274    def mark_run(self) -> None:
275        for buf in self.outputs:
276            buf.allocate()
277
278    def used_buffer_names(self) -> OrderedSet[str]:
279        return OrderedSet(
280            dep.name
281            for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes)
282        )
283
284    def used_or_aliased_buffer_names(self) -> OrderedSet[str]:
285        used_names: OrderedSet[str] = OrderedSet()
286
287        deps = [
288            dep.name
289            for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes)
290        ]
291        while len(deps) > 0:
292            dep = deps.pop()
293            used_names.add(dep)
294            if V.graph.name_to_buffer.get(dep):
295                for alias in V.graph.name_to_buffer[dep].get_inputs_that_alias_output():
296                    if alias not in used_names:
297                        deps.append(alias)
298        return used_names
299
300    def prune_deps(self) -> None:
301        self.unmet_dependencies = OrderedSet(
302            dep
303            for dep in self.unmet_dependencies
304            if dep.name not in self.scheduler.available_buffer_names
305        )
306
307    def prune_weak_deps(self) -> None:
308        # Prune weak dependencies on operations that have been removed
309        def should_prune(dep: Dep) -> bool:
310            if not isinstance(dep, WeakDep):
311                return False
312            op = self.scheduler.name_to_buf[dep.name].defining_op
313            return op.get_name() in V.graph.removed_operations
314
315        to_remove = OrderedSet(
316            dep for dep in self.read_writes.reads if should_prune(dep)
317        )
318        self.set_read_writes(self.read_writes.remove_reads(to_remove))
319
320    def prune_redundant_deps(
321        self, name_to_fused_node: Dict[str, BaseSchedulerNode]
322    ) -> None:
323        _prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf)
324
325    def get_name(self) -> str:
326        assert self.node is not None
327        return self.node.get_operation_name()
328
329    def get_first_name(self) -> str:
330        return self.get_name()
331
332    def get_operation_names(self) -> OrderedSet[str]:
333        return OrderedSet(node.get_name() for node in self.get_nodes())
334
335    def get_buffer_names(self) -> OrderedSet[str]:
336        return OrderedSet(out.get_name() for out in self.outputs)
337
338    def get_nodes(self) -> Sequence[BaseSchedulerNode]:
339        return [self]
340
341    def get_outputs(self) -> Sequence[SchedulerBuffer]:
342        return self.outputs
343
344    def get_output(self, buf_name: str) -> SchedulerBuffer:
345        return self.outputs_by_name[buf_name]
346
347    def get_device(self) -> torch.device:
348        assert self.node is not None
349        return self.node.get_device()
350
351    def is_reduction(self) -> bool:
352        return False
353
354    def is_split_scan(self) -> bool:
355        return False
356
357    def is_template(self) -> bool:
358        return False
359
360    def is_extern(self) -> bool:
361        return False
362
363    def is_foreach(self) -> bool:
364        return False
365
366    def can_inplace(self, read_dep: dependencies.Dep) -> bool:
367        return False
368
369    def has_side_effects(self) -> bool:
370        return False
371
372    def decide_inplace_update(self) -> None:
373        """
374        Decide if there should be inplace updates for the node
375        and record the decision in the active kernel.
376        """
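        # A minimal sketch of the reuse criterion applied below, assuming a
        # hypothetical helper `remaining_users(buf)` that filters out users whose
        # defining ops have already been codegen-ed (completed_operations):
        #
        #   def could_reuse_inplace(input_buf, out_buf):  # hypothetical helper
        #       users = remaining_users(input_buf)
        #       return (
        #           len(users) == 1
        #           and users[0].can_inplace
        #           and buffer_reuse_key(input_buf.node) == buffer_reuse_key(out_buf.node)
        #       )
        #
        # The real logic below additionally requires that the single remaining user
        # is this node, and excludes aliased/mutated buffers, removed buffers, and
        # multi-output/fallback layouts.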
377        from .codegen.wrapper import buffer_reuse_key
378
379        if not (
380            isinstance(self, (SchedulerNode,))
381            and config.inplace_buffers
382            and V.graph.has_feature(self.get_device(), BackendFeature.INPLACE_BUFFERS)
383            and (
384                not isinstance(V.kernel, torch._inductor.codegen.simd.SIMDKernel)
385                or getattr(V.kernel, "mutations", None) is not None
386            )
387            # hacky check for whether V.kernel is a real kernel or a NullHandler
388            and hasattr(V.kernel, "args")
389        ):
390            return
391
392        ordered_reads = sorted(self.read_writes.reads, key=lambda x: x.name)
393
394        for buf in self.get_outputs():
395            buf_node = buf.node
396            assert buf_node is not None
397            if (
398                not buf_node.should_allocate()
399                or buf_node.get_inputs_that_alias_output()
400                or buf_node.get_mutation_names()
401                or buf.get_name() in V.graph.removed_buffers
402            ):
403                continue
404
405            for read in ordered_reads:
406                input_buf: Optional[SchedulerBuffer] = self.scheduler.name_to_buf.get(
407                    read.name
408                )
409                if (
410                    input_buf
411                    and V.graph.wrapper_code.can_reuse(input_buf, self)
412                    and not isinstance(input_buf.defining_op, NopKernelSchedulerNode)
413                ):
414                    assert input_buf.users is not None
415                    remaining_uses = [
416                        x
417                        for x in input_buf.users
418                        if x.node.get_name() not in self.scheduler.completed_operations
419                    ]
420                    if (
421                        len(remaining_uses) == 1
422                        and remaining_uses[0].can_inplace
423                        and remaining_uses[0].node is self
424                        and input_buf.node is not None
425                        and not isinstance(
426                            input_buf.node.get_layout(),
427                            (
428                                ir.MultiOutputLayout,
429                                ir.MutationLayoutSHOULDREMOVE,
430                            ),
431                        )
432                        and not (
433                            isinstance(
434                                input_buf.defining_op.node,
435                                (ir.FallbackKernel, ir.MultiOutput),
436                            )
437                            and len(input_buf.node.get_inputs_that_alias_output()) > 0
438                        )
439                        and buffer_reuse_key(input_buf.node)
440                        == buffer_reuse_key(buf.node)
441                    ):
442                        # if there isn't a triton kernel, then we don't need to call triton-specific things.
443                        # but TODO this might be a convenient place to signal to the Collective kernels to do the update in place
444                        # (and, can we make "kernel" a less generic name?)
445                        V.kernel.args.make_inplace(input_buf.get_name(), buf.get_name())
446                        # mutations not tracked in cpp kernels
447                        if isinstance(
448                            V.kernel, torch._inductor.codegen.simd.SIMDKernel
449                        ):
450                            V.kernel.mutations.add(input_buf.get_name())
451                            V.kernel.mutations.add(buf.get_name())
452
453                        # update last usage of reused node
454                        self.last_usage.discard(input_buf.get_name())
455
456                        V.kernel.inplace_update_buffers[
457                            buf.get_name()
458                        ] = input_buf.get_name()
459                        break
460
461    def codegen_originating_info(
462        self, buffer: IndentedBuffer, only_once: bool = True
463    ) -> None:
464        if not config.comment_origin:
465            return
466
467        if only_once and self.written:
468            return
469        assert self.node is not None
470        origins = self.node.get_origins()
471        out_lines = []
472
473        for o in origins:
474            if o.op == "output":
475                # These are boring and samey
476                continue
477
478            out_lines.append("")
479            # TODO(voz): Should the pragma be constant somewhere?
480            out_lines.append("#pragma CMT ORIGIN:")
481            op_info_str = f"#pragma CMT {o.op} {o.target}"
482            if "seq_nr" in o.meta:
483                op_info_str = op_info_str + f" seq_nr:{o.meta['seq_nr']}"
484            out_lines.append(op_info_str)
485            if "stack_trace" in o.meta:
486                stack_trace = f"{o.meta['stack_trace']}"
487                stack_trace_last_line = stack_trace.split("|")[-1]
488                out_lines.append(
489                    "#pragma CMT "
490                    + stack_trace_last_line.replace("{", "{{")
491                    .replace("}", "}}")
492                    .replace("\n", "\\")
493                )
494                out_lines.append("#pragma CMT END ORIGIN")
495                out_lines.append("")
496
497        if len(out_lines) == 0:
498            return
499
500        # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
501        # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
502        buffer.writelines(out_lines)
503        self.written = True
504
505    def get_read_write_buffers_sizes(self) -> int:
506        """
507        Counting the number of bytes accessed for a kernel is
508        surprisingly tricky. In particular, there is a distinction
509        between 'theoretical' memory accesses and practical memory
510        accesses. For example, a layernorm kernel may actually access an
511        input 3 times, but in theory it only needs to access its input
512        once (and may be optimized to do so through, say, persistent
513        reductions).
514
515        Another example is that even though a buffer is passed in, we may
516        not access the entire buffer. This may occur if we are accessing
517        a slice of the buffer. Another tricky case is indirect
518        indexing, where the number of bytes accessed depends on the
519        values of the input.
520
521        What this function aims to compute is the memory accesses for
522        worst-case inputs under best-case optimization. That is, for each
523        buffer we compute the number of potential accesses in two ways and take the minimum:
524
525        1. Numel in ranges multiplied by the number of deps the buffer has
526        2. The buffer size (see the commented example below)
527        """
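        # A small worked example of the min() rule above (a sketch with made-up
        # numbers): suppose this node iterates over 1_000_000 elements and reads a
        # float32 buffer `buf0` of 250_000 elements through 2 distinct deps.
        #
        #   buf_accessed_elems = 2 * 1_000_000      # deps * node numel
        #   buf_elems          = 250_000            # full buffer size
        #   bytes_for_buf0     = 4 * min(buf_accessed_elems, buf_elems)  # = 1_000_000 bytes
        #
        # i.e. we never charge more than the whole buffer, and never more than
        # what the access pattern could possibly touch.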
528        if isinstance(self, NopKernelSchedulerNode):
529            return 0
530        if isinstance(self, ExternKernelSchedulerNode) and isinstance(
531            self.node, MultiOutput
532        ):
533            # todo: Calculate this - it's kinda annoying.
534            return 0
535
536        def try_size_hint(s: sympy.Expr) -> int:
537            return V.graph.sizevars.size_hint(s, fallback=0)
538
539        if isinstance(self, SchedulerNode):
540            node_numel = try_size_hint(
541                sympy_product(self.get_ranges()[0])
542                * sympy_product(self.get_ranges()[1]),
543            )
544        else:
545            node_numel = int(1e9)
546        buf_accesses = collections.defaultdict(list)
547        for dep in self.read_writes.reads | self.read_writes.writes:
548            buf_accesses[dep.name].append(dep)
549
550        reads = OrderedSet(dep.name for dep in self.read_writes.reads)
551        writes = OrderedSet(dep.name for dep in self.read_writes.writes)
552
553        def is_materialized(buf: str, snodes: Sequence[BaseSchedulerNode]) -> bool:
554            users = self.scheduler.name_to_buf[buf].users
555            buf_uses = OrderedSet(user.node for user in users)
556            return len(buf_uses - OrderedSet(snodes)) > 0
557
558        if isinstance(self, FusedSchedulerNode):
559            removed_buffers = OrderedSet(
560                dep for dep in writes if not is_materialized(dep, self.snodes)
561            )
562            writes = writes - removed_buffers
563            reads = reads - removed_buffers
564        node_bytes = 0
565
566        for buf_name in reads | writes:
567            buf_accessed_elems = sum(node_numel for dep in buf_accesses[buf_name])
568            buf: Union[ir.Buffer, ir.TensorBox]
569            if buf_name in V.graph.name_to_buffer:
570                buf = V.graph.name_to_buffer[buf_name]
571            elif buf_name in V.graph.graph_inputs:
572                buf = V.graph.graph_inputs[buf_name]
573            else:
574                continue
575
576            def get_buf_bytes(buf: Optional[Union[ir.Buffer, ir.TensorBox]]) -> int:
577                if not buf:
578                    return 0
579                # Kind of a lazy way to get the MultiOutput nodes corresponding to
580                # a MultiOutputLayout
581                if isinstance(buf.layout, MultiOutputLayout):
582                    users = self.scheduler.name_to_buf[buf.get_name()].users
583                    tot = 0
584                    for user in users:
585                        assert isinstance(user.node, BaseSchedulerNode)
586                        if isinstance(user.node.node, MultiOutput):
587                            for sched_buf in user.node.get_outputs():
588                                tot += get_buf_bytes(sched_buf.node)
589                        else:
590                            # Buf is a MultiOutputLayout but not all of its
591                            # users are MultiOutputs...
592                            # TODO: Figure out what's going on
593                            return 0
594                    return tot
595                elif isinstance(buf.layout, ir.NoneLayout):
596                    return sum(
597                        get_buf_bytes(V.graph.get_buffer(mut_name))
598                        for mut_name in buf.get_mutation_names()
599                    )
600                else:
601                    buf_elems = try_size_hint(sympy_product(buf.get_size()))
602                    return get_dtype_size(buf.get_dtype()) * min(
603                        buf_accessed_elems, buf_elems
604                    )
605
606            node_bytes += get_buf_bytes(buf)
607
608        return node_bytes
609
610    def get_estimated_runtime(self) -> float:
611        """
612        Returns estimated op runtime in nanoseconds (ns)
613        """
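        # The estimate below is a simple roofline-style model (a sketch with
        # made-up numbers): take the max of compute time and memory-transfer time.
        #
        #   counted_flops = 2 * M * N * K            # e.g. for a matmul
        #   compute_time  = counted_flops / gpu_flops * 1e9          # ns
        #   transfer_time = counted_bytes / gpu_memory_bandwidth     # ns (bytes / GB-per-s)
        #   estimate      = max(compute_time, transfer_time)
        #
        # Extern kernels with a known op use FlopCounterMode for counted_flops;
        # fused/pointwise nodes fall back to the pure memory-bandwidth term.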
614        buf = self.get_nodes()[0].get_outputs()[0]
615        layout = buf.node.get_layout()
616        dtype = buf.node.get_dtype()
617
618        if layout.device is not None and not is_gpu(layout.device.type):
619            # default to no reordering based on runtime
620            return 0
621
622        # Collective kernels
623        if is_collective(self.node):
624            assert isinstance(self.node, ir.IRNode)
625            try:
626                return estimate_nccl_collective_runtime(self.node)
627            except ValueError as e:
628                # We don't know how to estimate runtime for this collective,
629                # falling back to 0
630                log.info(e)
631                return 0
632
633        elif is_wait(self.node):
634            # ir.Wait is only used for collective ops.
635            # The time needed for the collective op is already estimated and considered
636            # when we are processing the collective op IR node, so ir.Wait takes 0 time
637            # since it doesn't take extra time to get the result after the collective is completed.
638            return 0
639
640        try:
641            gpu_memory_bandwidth = get_gpu_dram_gbps()
642            gpu_flops = get_device_tflops(dtype) * 10**12
643        except Exception:
644            return 0
645
646        if isinstance(self, ExternKernelSchedulerNode):
647            assert isinstance(self.node, ir.ExternKernel), f"{type(self.node)=}"
648            op = kernel_name_to_op.get(
649                getattr(self.node, "python_kernel_name", ""), None
650            )
651
652            # if there is a resolved op, dry-run using fake mode and record flop count
653            if op is not None:
654                from torch._subclasses.fake_tensor import FakeTensorMode
655                from torch.utils.flop_counter import FlopCounterMode
656
657                if any(
658                    len(free_unbacked_symbols(n.get_numel())) > 0
659                    for n in self.node.inputs
660                ):
661                    # Tensor has unbacked symints, we don't know how to estimate
662                    # runtime for that today
663                    return 0
664
665                with FakeTensorMode() as fake_mode, FlopCounterMode(
666                    display=False
667                ) as flop_counter_mode, V.set_current_node(
668                    self.node.fx_node
669                ), V.set_fake_mode(
670                    fake_mode
671                ):
672                    from .ir import ir_node_to_tensor
673
674                    fake_inputs = [
675                        ir_node_to_tensor(input, guard_shape=False)
676                        for input in self.node.inputs
677                    ]
678                    cls = self.node.__class__
679                    cls.process_kernel(op, *fake_inputs, **self.node.kwargs)
680
681                    # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship
682                    factor = 1.0
683                    counted_flops = flop_counter_mode.get_total_flops()
684                    counted_bytes = self.get_read_write_buffers_sizes()
685                    compute_time = (factor * counted_flops / gpu_flops) * 1e9
686                    transfer_time = counted_bytes / gpu_memory_bandwidth
687
688                    # Return estimated runtime in nanoseconds
689                    return max(compute_time, transfer_time)
690
691        elif isinstance(self, FusedSchedulerNode) or isinstance(
692            self.node, ComputedBuffer
693        ):
694            # Return estimated runtime in nanoseconds (bytes / gbps)
695            return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth
696
697        return 0
698
699    def get_template_node(self) -> Optional[ir.TemplateBuffer]:
700        return None
701
702
703class WhyNoFuse:
704    # TODO when we drop support for Python < 3.10, we can use
705    # @dataclass(slots=True) instead of manually specifying __slots__.
706    __slots__ = ["node1", "node2", "reason", "args"]
707    reason: str
708    args: Tuple[Any, ...]
709
710    def __init__(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> None:
711        self.node1 = node1
712        self.node2 = node2
713
714    def __call__(self, reason: str, *args: Any) -> None:
715        self.reason = reason
716        self.args = args
717        fusion_log.debug(self)
718
719    def __str__(self) -> str:
720        return f"cannot fuse {self.node1.get_name()} with {self.node2.get_name()}: " + (
721            self.reason % self.args
722        )
723
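# Usage sketch for WhyNoFuse (node/buffer names below are illustrative): the
# instance is called at the point where a fusion check fails, and formats the
# reason into the "fusion" artifact log, e.g.
#
#   why = WhyNoFuse(node1, node2)
#   why("exceeds max fusion size: %d > %d", 1024, 512)
#   # -> "cannot fuse op3 with op7: exceeds max fusion size: 1024 > 512"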
724
725def pformat(obj: Any) -> str:
726    if isinstance(obj, OrderedSet):
727        # pformat has trouble with sets of sympy exprs
728        obj = sorted(obj, key=str)
729    result = pprint.pformat(obj, indent=4)
730    if "\n" in result:
731        return f"\n{textwrap.indent(result, ' ' * 4)}"
732    return result
733
734
735class OutputNode:
736    def __init__(self, dep: StarDep) -> None:
737        self.unmet_dependencies = OrderedSet([dep])
738
739    def is_reduction(self) -> bool:
740        return False
741
742    def get_inputs_that_alias_output(self) -> Sequence[str]:
743        return ()
744
745    def get_name(self) -> str:
746        return "OUTPUT"
747
748    __repr__ = get_name
749
750
751def _prune_redundant_deps(
752    node: BaseSchedulerNode,
753    name_to_fused_node: Dict[str, BaseSchedulerNode],
754    name_to_buf: Dict[str, SchedulerBuffer],
755) -> None:
756    """
757    Prunes weakdeps intended for mutation ordering
758    on an upstream fused node if, after fusion, there is another dependency
759    on the fused upstream node, making the weakdep redundant.
760
761    In essence this enforces an ordering on fusions. As fusions occur, weakdeps are
762    incrementally removed, enabling other fusions and ensuring they are fused in order.
763    """
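    # A worked (hypothetical) example of the pruning above: say op A defines buf_a
    # and buf_b, and node B carries WeakDep("buf_a") for mutation ordering plus a
    # real MemoryDep("buf_b"). Once A belongs to a fused node, the MemoryDep bumps
    # name_to_dep_count for that fused node, so should_prune() marks the WeakDep
    # redundant and it is removed, unblocking further fusions with B.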
764    name_to_dep_count: Counter[str] = collections.Counter()
765
766    for dep in node.unmet_dependencies:
767        if not isinstance(dep, WeakDep):
768            op = name_to_buf[dep.name].defining_op
769            name_to_dep_count[name_to_fused_node[op.get_name()].get_name()] += 1
770
771    def should_prune(dep: Dep) -> bool:
772        if isinstance(dep, WeakDep):
773            op_name = name_to_buf[dep.name].defining_op.get_name()
774            is_redundant = name_to_dep_count[name_to_fused_node[op_name].get_name()] > 0
775            # These can occur because fused nodes always gather deps from their snodes.
776            # E.g. if B has a weakdep on A and B gets fused with C, then any time
777            # BC is fused further, the weakdep will reappear
778            is_self_dep = name_to_fused_node[op_name] == node
779            return is_redundant or is_self_dep
780        else:
781            return False
782
783    deps_to_prune = OrderedSet(
784        dep for dep in node.unmet_dependencies if should_prune(dep)
785    )
786
787    if deps_to_prune:
788        node.unmet_dependencies = node.unmet_dependencies - deps_to_prune
789        node.set_read_writes(node.read_writes.remove_reads(deps_to_prune))
790
791
792# TODO(xmfan): reuse an existing mapping for this if it exists, or formalize this into ir.py:ExternKernel
793kernel_name_to_op = {
794    "extern_kernels.convolution": torch.ops.aten.convolution,
795    "extern_kernels.mm": torch.ops.aten.mm,
796    "extern_kernels.bmm": torch.ops.aten.bmm,
797    "extern_kernels.addmm": torch.ops.aten.addmm,
798}
799
800
801class ExternKernelSchedulerNode(BaseSchedulerNode):
802    def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None:
803        super().__init__(scheduler)
804        self._init_from_node(node)
805        self.set_read_writes(node.get_read_writes())
806
807    def debug_str_extra(self) -> str:
808        return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}"
809
810    def is_extern(self) -> bool:
811        return True
812
813    def has_side_effects(self) -> bool:
814        assert self.node is not None
815        return hasattr(self.node, "has_side_effects") and self.node.has_side_effects()
816
817
818class NopKernelSchedulerNode(BaseSchedulerNode):
819    def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None:
820        super().__init__(scheduler)
821        self._init_from_node(node)
822        self.set_read_writes(node.get_read_writes())
823
824
825class SchedulerNode(BaseSchedulerNode):
826    def __init__(
827        self,
828        scheduler: Scheduler,
829        node: Union[ir.ComputedBuffer, ir.TemplateBuffer],
830    ) -> None:
831        super().__init__(scheduler)
832        self._init_from_node(node)
833        self._compute_attrs()
834
835    def _compute_attrs(
836        self,
837        extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None,
838        recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
839    ) -> None:
840        assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer))
841        self._sizes, self._body = self.node.simplify_and_reorder(
842            extra_indexing_constraints=extra_indexing_constraints,
843            recompute_sizes_body_func=recompute_sizes_body_func,
844        )
845
846        group_fn = self.scheduler.get_backend(self.node.get_device()).group_fn
847        self.group = (self.node.get_device(), group_fn(self._sizes))
848
849        # Don't normalize since normalization will merge loops which
850        # makes it hard to decide new loop orders.
851        should_normalize = (
852            not config.loop_ordering_after_fusion
853            or self.node.get_device().type != "cuda"
854        )
855
856        if isinstance(self.node, ir.TemplateBuffer):
857            self.set_read_writes(
858                self.node.extract_read_writes(normalize=should_normalize)
859            )
860        else:
861            self.set_read_writes(
862                dependencies.extract_read_writes(
863                    self._body, *self._sizes, normalize=should_normalize
864                )
865            )
866
867    def recompute_size_and_body(
868        self,
869        extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None,
870        recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
871    ) -> None:
872        self._compute_attrs(
873            extra_indexing_constraints=extra_indexing_constraints,
874            recompute_sizes_body_func=recompute_sizes_body_func,
875        )
876
877    def refresh_dependencies(self, normalize: bool) -> None:
878        # Fake dependencies are added manually. They cannot be recovered by
879        # extract_read_writes, so collect them here and re-apply them manually.
880        fake_deps = {
881            dep for dep in self.read_writes.reads if isinstance(dep, (WeakDep, StarDep))
882        }
883
884        # don't normalize since the loop order may need to be further changed
885        # later
886        self.set_read_writes(
887            dependencies.extract_read_writes(
888                self._body, *self._sizes, normalize=normalize
889            ).with_read(fake_deps)
890        )
891
892    def apply_new_loop_order(self, new_order: Sequence[int]) -> None:
893        self._body = self._body.reorder_iter_loops(
894            new_order,
895        )
896        self._sizes = self._body.sizes
897
898        self.refresh_dependencies(normalize=False)
899
900    def reorder_loops_by_dep_pair(
901        self, self_dep: MemoryDep, other_dep: MemoryDep
902    ) -> None:
903        new_order = None
904        self_sizes = self._sizes[0]
905        if len(self_sizes) == self_dep.num_vars == other_dep.num_vars:
906            new_order = self_dep.decide_loop_order_to_match(other_dep)
907
908        if new_order:
909            metrics.num_loop_reordering += 1
910            loop_ordering_log.debug(
911                "Reorder loops for %s with order %s", self.get_name(), new_order
912            )
913            self.apply_new_loop_order(new_order)
914        else:
915            loop_ordering_log.debug(
916                "Not reordering %s because we cannot decide a suitable loop order",
917                self.get_name(),
918            )
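
    # A hedged sketch of the decision above: when this node and its fusion peer
    # touch the same buffer over the same iteration space but in different loop
    # orders (e.g. one walks it row-major, the other column-major),
    # self_dep.decide_loop_order_to_match(other_dep) may return a permutation such
    # as [1, 0]; apply_new_loop_order() then rewrites self._body so both nodes
    # iterate the buffer in the same order, which makes them fusable.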
919
920    def debug_str_extra(self) -> str:
921        name = self.get_name()
922        lines = [
923            f"{name}.group.device = {self.group[0]}",
924            f"{name}.group.iteration = {self.group[1]}",
925            f"{name}.sizes = {self._sizes}",
926        ]
927        for dep in self.read_writes.reads_and_writes():
928            if not isinstance(dep, WeakDep):
929                buf_name = dep.name
930                buf = V.graph.get_buffer(buf_name)
931                lines.append(f"{buf_name}_layout = {pformat(buf.layout)}")
932        if isinstance(self._body, LoopBody):
933            lines.append(f"class {name}_loop_body:")
934            lines.append(textwrap.indent(self._body.debug_str(), "    "))
935
936        assert self.node is not None
937        if ir.is_triton(self.node.get_device()):
938            lines.extend(debug_triton_code(self))
939
940        return "\n".join(lines)
941
942    def get_ranges(self) -> Sequence[Sequence[sympy.Expr]]:
943        return self._sizes
944
945    def is_reduction(self) -> bool:
946        assert isinstance(
947            self.node, (ir.ComputedBuffer, ir.TemplateBuffer)
948        ), f"{type(self.node)=}"
949        return bool(self.node.get_reduction_type())
950
951    def is_split_scan(self) -> bool:
952        assert isinstance(
953            self.node, (ir.ComputedBuffer, ir.TemplateBuffer)
954        ), f"{type(self.node)=}"
955        return isinstance(self.node, ir.ComputedBuffer) and isinstance(
956            self.node.data, ir.SplitScan
957        )
958
959    def is_template(self) -> bool:
960        return isinstance(self.node, ir.TemplateBuffer)
961
962    def get_template_node(self) -> Optional[ir.TemplateBuffer]:
963        return self.node if isinstance(self.node, ir.TemplateBuffer) else None
964
965    def run(self, *index_vars: Sequence[sympy.Expr]) -> None:
966        self.decide_inplace_update()
967        self.mark_run()
968        self.codegen(index_vars)
969
970    def ranges_from_index_vars(
971        self, index_vars: Sequence[Sequence[sympy.Expr]]
972    ) -> Dict[sympy.Expr, sympy.Expr]:
973        sizes = self._sizes
974        assert sum(map(len, sizes)) == sum(map(len, index_vars))
975        var_ranges = dict(
976            zip(
977                itertools.chain.from_iterable(index_vars),
978                itertools.chain.from_iterable(sizes),
979            )
980        )
981        return var_ranges
982
983    def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None:
984        var_ranges = self.ranges_from_index_vars(index_vars)
985        try:
986            with V.set_ops_handler(
987                SimplifyIndexing(V.get_ops_handler(), var_ranges)
988            ), V.kernel.set_current_node(self):
989                self._body(*index_vars)
990        except Exception:
991            log.fatal("Error in codegen for %s", self.node)
992            raise
993
994    @cache_on_self
995    def pointwise_read_writes(self) -> dependencies.ReadWrites:
996        """
997        Get the memory dependencies in the non-reduction axis.
998        """
999        sizes, reduction_sizes = self._sizes
1000        return dependencies.extract_read_writes(
1001            self._body, sizes, hidden_args=[[sympy.Integer(0)] * len(reduction_sizes)]
1002        )
1003
1004    def can_inplace(self, read_dep: dependencies.Dep) -> bool:
1005        if self.is_template():
1006            return False
1007        if any(out.get_aliases() for out in self.get_outputs()):
1008            return False
1009        if len(self.read_writes.writes) == 1 and isinstance(
1010            read_dep, dependencies.MemoryDep
1011        ):
1012            write_dep = next(iter(self.read_writes.writes))
1013            assert isinstance(write_dep, dependencies.MemoryDep), f"{type(write_dep)=}"
1014            return read_dep.index == write_dep.index and read_dep.size == write_dep.size
1015        return False
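
        # Sketch of the check above: an in-place update is only considered safe when
        # the read and the (single) write touch exactly the same elements in the same
        # order, e.g. a pointwise `out[i] = f(inp[i])` where both deps have index `i`
        # and identical sizes; a transposed or broadcasted read would fail the
        # `read_dep.index == write_dep.index` test and keep the buffers separate.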
1016
1017    @cache_on_self
1018    def _get_atomic_add_buffers(self) -> OrderedSet[str]:
1019        buffers_store_as_atomic_add: OrderedSet[str] = OrderedSet()
1020        if isinstance(self._body, LoopBody):
1021            for node in self._body.get_nodes():
1022                if (
1023                    node.op == "call_method"
1024                    and node.target == "store"
1025                    and (
1026                        ("mode" in node.kwargs and node.kwargs["mode"] == "atomic_add")
1027                        or (len(node.args) == 5 and node.args[4] == "atomic_add")
1028                    )
1029                ):
1030                    buffers_store_as_atomic_add.add(
1031                        node.kwargs["name"]
1032                        if "name" in node.kwargs
1033                        else (node.args[1] if len(node.args) >= 2 else "")
1034                    )
1035        return buffers_store_as_atomic_add
1036
1037
1038def refresh_group_node_dependencies(group_snode: BaseSchedulerNode) -> None:
1039    snodes = group_snode.snodes  # type: ignore[attr-defined]
1040    group_snode.set_read_writes(
1041        dependencies.ReadWrites.merge_list([x.read_writes for x in snodes])
1042    )
1043
1044    group_snode.unmet_dependencies = (
1045        OrderedSet(
1046            dep
1047            for dep in OrderedSet.union(*[x.unmet_dependencies for x in snodes])
1048            if dep.name not in group_snode.get_buffer_names()
1049        )
1050        - group_snode.read_writes.writes
1051    )
1052
1053
1054def init_group_node(
1055    group_snode: BaseSchedulerNode,
1056    scheduler: Scheduler,
1057    snodes: List[BaseSchedulerNode],
1058) -> None:
1059    assert isinstance(group_snode, (FusedSchedulerNode, GroupedSchedulerNode))
1060    group_snode.snodes = snodes
1061    group_snode.scheduler = scheduler
1062    group_snode.node = None
1063    group_snode.ancestors = OrderedSet.union(
1064        *[x.ancestors for x in snodes if x.ancestors is not None]
1065    )
1066
1067    refresh_group_node_dependencies(group_snode)
1068
1069    group_snode.min_order = min(x.min_order for x in group_snode.snodes)
1070    group_snode.max_order = max(x.max_order for x in group_snode.snodes)
1071    group_snode.outputs_by_name = {
1072        buf.get_name(): buf for buf in group_snode.get_outputs()
1073    }
1074
1075
1076class FusedSchedulerNode(BaseSchedulerNode):
1077    """
1078    This is a "fake" scheduler node that represents a group of scheduler nodes
1079    that are meant to be fused together. The way it does this is by maintaining
1080    its unmet dependencies as the union of those of its constituent nodes.
1081    """
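    # A brief sketch of the grouping (names are illustrative): fusing op1 and op2,
    # where op2 reads op1's output buf1, yields a node named "op1_op2" whose
    # read/write sets are merged and whose unmet_dependencies drop the now-internal
    # dep on buf1 (see refresh_group_node_dependencies), e.g.
    #
    #   fused = FusedSchedulerNode.fuse(op1_snode, op2_snode)
    #   fused.get_name()          # "op1_op2"
    #   fused.get_buffer_names()  # e.g. {"buf1", "buf2"}, outputs of both snodes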
1082
1083    snodes: List[BaseSchedulerNode]
1084
1085    @classmethod
1086    def fuse(
1087        cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode
1088    ) -> FusedSchedulerNode:
1089        assert node1.scheduler is node2.scheduler
1090        assert isinstance(node1, (SchedulerNode, FusedSchedulerNode))
1091        assert isinstance(node2, (SchedulerNode, FusedSchedulerNode))
1092        nodes = list(itertools.chain(node1.get_nodes(), node2.get_nodes()))
1093        return cls(node1.scheduler, nodes)
1094
1095    def reorder_loops_by_dep_pair(
1096        self, self_dep: MemoryDep, other_dep: MemoryDep
1097    ) -> None:
1098        if self.is_template():
1099            # We cannot really reorder loops for a Triton template
1100            return
1101        self_sizes = None
1102        for snode in self.snodes:
1103            assert isinstance(snode, SchedulerNode)
1104            if self_sizes is not None and self_sizes != snode._sizes[0]:
1105                loop_ordering_log.debug(
1106                    "Cannot reorder fused node due to different sizes"
1107                )
1108                return
1109            self_sizes = snode._sizes[0]
1110        new_order = None
1111
1112        assert self_sizes is not None
1113        if len(self_sizes) == self_dep.num_vars == other_dep.num_vars:
1114            new_order = self_dep.decide_loop_order_to_match(other_dep)
1115
1116        if not new_order:
1117            loop_ordering_log.debug(
1118                "Not reordering fused node %s because we cannot decide a suitable loop order",
1119                self.get_name(),
1120            )
1121            return
1122        metrics.num_loop_reordering += 1
1123        loop_ordering_log.debug(
1124            "Reorder loops for fused node %s with order %s", self.get_name(), new_order
1125        )
1126        for snode in self.snodes:
1127            assert isinstance(snode, SchedulerNode)
1128            snode.apply_new_loop_order(new_order)  # type: ignore[arg-type]
1129
1130        refresh_group_node_dependencies(self)
1131
1132    def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None:
1133        super().__init__(scheduler)
1134        init_group_node(self, scheduler, snodes)
1135        self.users: List[NodeUser] = []
1136        self.group = max(snodes, key=lambda x: int(x.is_reduction())).group
1137
1138    @cache_on_self
1139    def get_name(self) -> str:
1140        return "_".join([x.get_name() for x in self.snodes])
1141
1142    def get_first_name(self) -> str:
1143        return self.snodes[0].get_name()
1144
1145    @cache_on_self
1146    def get_buffer_names(self) -> OrderedSet[str]:
1147        return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes])
1148
1149    def get_outputs(self) -> List[SchedulerBuffer]:
1150        result: List[SchedulerBuffer] = []
1151        for node in self.snodes:
1152            result.extend(node.get_outputs())
1153        return result
1154
1155    def debug_str_extra(self) -> str:
1156        lines = [
1157            f"{self.get_name()}.snodes[{i}] =\n{node.debug_str()}"
1158            for i, node in enumerate(self.snodes)
1159        ]
1160        node = self.snodes[0].node
1161        if node is not None:
1162            device = node.get_device()
1163            if ir.is_triton(device):
1164                lines.extend(debug_triton_code(self))
1165
1166        return textwrap.indent("\n".join(lines).rstrip(), "    ")
1167
1168    def debug_str_short(self) -> str:
1169        snodes_str = [node.debug_str_short() for node in self.snodes]
1170        return f"{self}, snodes: {snodes_str}"
1171
1172    def set_last_usage(
1173        self, future_used_buffers: OrderedSet[str], mutation_real_name: Dict[str, str]
1174    ) -> None:
1175        # Set self.last_usage using the global information
1176        # This will be used for inter-kernel optimisations
1177        super().set_last_usage(future_used_buffers, mutation_real_name)
1178        # Set self.last_usage on the snodes
1179        # This will be used for optimisations within the kernel
1180        future_used_buffers: OrderedSet[str] = OrderedSet()
1181        for node in reversed(self.snodes):
1182            node.set_last_usage(future_used_buffers, mutation_real_name)
1183            future_used_buffers.update(node.last_usage)
1184
1185    @cache_on_self
1186    def used_buffer_names(self) -> OrderedSet[str]:
1187        return OrderedSet.union(*[x.used_buffer_names() for x in self.snodes])
1188
1189    @cache_on_self
1190    def used_or_aliased_buffer_names(self) -> OrderedSet[str]:
1191        return OrderedSet.union(
1192            *[x.used_or_aliased_buffer_names() for x in self.snodes]
1193        )
1194
1195    def get_nodes(self) -> Sequence[BaseSchedulerNode]:
1196        return self.snodes
1197
1198    def __repr__(self) -> str:
1199        return f"{type(self).__name__}(nodes={self.get_name()})"
1200
1201    @cache_on_self
1202    def is_reduction(self) -> bool:
1203        return any(x.is_reduction() for x in self.snodes)
1204
1205    @cache_on_self
1206    def is_split_scan(self) -> bool:
1207        return any(x.is_split_scan() for x in self.snodes)
1208
1209    @cache_on_self
1210    def is_template(self) -> bool:
1211        return any(x.is_template() for x in self.snodes)
1212
1213    @cache_on_self
1214    def get_template_node(self) -> Optional[ir.TemplateBuffer]:
1215        for node in self.snodes:
1216            if node.is_template():
1217                return node.get_template_node()
1218        return None
1219
1220    def get_device(self) -> torch.device:
1221        return self.group[0]
1222
1223    @cache_on_self
1224    def has_aliasing_or_mutation(self) -> bool:
1225        return any(x.has_aliasing_or_mutation() for x in self.snodes)
1226
1227    # None of these need to be implemented, as a FusedSchedulerNode is just an
1228    # abstraction for scheduling purposes
1229    def update_mutated_names(self, renames: Dict[str, str]) -> None:
1230        raise NotImplementedError
1231
1232    def add_fake_dep(self, name: Dep) -> None:
1233        raise NotImplementedError
1234
1235    def can_inplace(self, read_dep: dependencies.Dep) -> bool:
1236        raise NotImplementedError
1237
1238    def debug_str(self) -> str:
1239        """Longer form printout for trace logs"""
1240        name = self.get_name()
1241        node_typestr = ",".join(type(n).__name__ for n in self.snodes)
1242        buf = IndentedBuffer()
1243        buf.splice(
1244            f"""\
1245{name}: {type(self).__name__}({node_typestr})
1246{name}.writes = {pformat(self.read_writes.writes)}
1247{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}
1248{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}
1249{name}.outputs = [
1250            """
1251        )
1252        with buf.indent():
1253            for out in self.get_outputs():
1254                buf.splice(out.debug_str())
1255        buf.writeline("]")
1256
1257        try:
1258            buf.splice(self.debug_str_extra())
1259        except Exception:
1260            log.warning("Ignoring error in debug_str()", exc_info=True)
1261
1262        return buf.getrawvalue().rstrip()
1263
1264
1265class ForeachKernelSchedulerNode(FusedSchedulerNode):
1266    """
1267    This is a scheduler node that consists of a set of scheduler nodes that
1268    have no data dependencies among them and can be executed in parallel.
1269    """
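    # For intuition (a hedged example): a horizontally-fused
    #   torch._foreach_add([a0, a1, a2], [b0, b1, b2])
    # becomes one ForeachKernelSchedulerNode whose snodes are the three
    # independent pointwise adds; since the subnodes share no data dependencies,
    # they can be emitted together as a single combo/foreach kernel.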
1270
1271    def get_consumer_subnode_for(
1272        self, producer: BaseSchedulerNode
1273    ) -> Optional[BaseSchedulerNode]:
1274        for buf in producer.get_outputs():
1275            if buf.get_name() in self.read_to_node:
1276                return self.read_to_node[buf.get_name()]
1277
1278        return None
1279
1280    def get_producer_subnode_for(
1281        self, consumer: BaseSchedulerNode
1282    ) -> Optional[BaseSchedulerNode]:
1283        producers = set()
1284        for rd in consumer.read_writes.reads:
1285            if rd.name not in self.scheduler.name_to_buf:
1286                continue
1287
1288            node_name = self.scheduler.name_to_buf[rd.name].defining_op.get_name()
1289            if node_name in self.name_to_node:
1290                producers.add(self.name_to_node[node_name])
1291
1292        # Don't permit fusion if there are multiple subnodes
1293        # that this consumer reads from
1294        if len(producers) == 1:
1295            return next(iter(producers))
1296        else:
1297            return None
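
    # Sketch of the rule above (illustrative names): if `consumer` reads buf_i
    # produced by subnode_i and buf_j produced by subnode_j (i != j), `producers`
    # has two entries and we return None, so ForeachKernelSchedulerNode.can_fuse
    # rejects the fusion; only a consumer that reads from exactly one subnode of
    # this foreach node is considered for fusion with that subnode.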
1298
1299    @classmethod
1300    def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool:
1301        why = WhyNoFuse(producer, consumer)
1302        if producer.is_foreach() and consumer.is_foreach():
1303            producer = typing.cast(ForeachKernelSchedulerNode, producer)
1304            consumer = typing.cast(ForeachKernelSchedulerNode, consumer)
1305            foreach_match = len(producer.snodes) == len(consumer.snodes)
1306            if not foreach_match:
1307                why("foreach nodes do not have the same length")
1308            return foreach_match and all(
1309                producer.scheduler.can_fuse(l, r)
1310                for l, r in zip(producer.snodes, consumer.snodes)
1311            )
1312        elif consumer.is_foreach():
1313            if producer.is_reduction():
1314                why(
1315                    "candidate producer is a reduction, foreach ops cannot be fused with reductions currently"
1316                )
1317                return False
1318
1319            consumer = typing.cast(ForeachKernelSchedulerNode, consumer)
1320            consumer_subnode = consumer.get_consumer_subnode_for(producer)
1321            if consumer_subnode is not None:
1322                return consumer.scheduler.can_fuse(producer, consumer_subnode)
1323
1324            why("candidate producer is not dep of any foreach consumer")
1325            return False
1326
1327        elif producer.is_foreach():
1328            if consumer.is_reduction():
1329                why(
1330                    "candidate consumer is a reduction, foreach ops cannot be fused with reductions currently"
1331                )
1332                return False
1333
1334            producer = typing.cast(ForeachKernelSchedulerNode, producer)
1335            producer_subnode = producer.get_producer_subnode_for(consumer)
1336            if producer_subnode is not None:
1337                return producer.scheduler.can_fuse(producer_subnode, consumer)
1338
1339            why("candidate consumer has no dep in any foreach producer")
1340            return False
1341
1342        raise AssertionError(
1343            "At least one node passed to ForeachKernelSchedulerNode.can_fuse should be a foreach node"
1344        )
1345
1346    @classmethod
1347    def fuse(
1348        cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode
1349    ) -> ForeachKernelSchedulerNode:
1350        assert producer.is_foreach() or consumer.is_foreach()
1351        if producer.is_foreach():
1352            producer = typing.cast(ForeachKernelSchedulerNode, producer)
1353            use_custom_partition_algo = producer.use_custom_partition_algo
1354            enable_autotune = producer.enable_autotune
1355        else:
1356            consumer = typing.cast(ForeachKernelSchedulerNode, consumer)
1357            use_custom_partition_algo = consumer.use_custom_partition_algo
1358            enable_autotune = consumer.enable_autotune
1359        prev_node_1 = None
1360        prev_node_2 = None
1361        fused_nodes: List[BaseSchedulerNode]
1362        if producer.is_foreach() and consumer.is_foreach():
1363            producer = typing.cast(ForeachKernelSchedulerNode, producer)
1364            consumer = typing.cast(ForeachKernelSchedulerNode, consumer)
1365            fused_nodes = [
1366                FusedSchedulerNode.fuse(l, r)
1367                for l, r in zip(producer.snodes, consumer.snodes)
1368            ]
1369        elif producer.is_foreach():
1370            producer = typing.cast(ForeachKernelSchedulerNode, producer)
1371            producer_subnode = producer.get_producer_subnode_for(consumer)
1372            fused_nodes = []
1373            prev_node_1 = producer
1374            prev_node_2 = None
1375            for node in producer.snodes:
1376                if node is producer_subnode:
1377                    new_node = FusedSchedulerNode.fuse(node, consumer)
1378                    prev_node_2 = new_node
1379                    fused_nodes.append(new_node)
1380                else:
1381                    fused_nodes.append(node)
1382
1383        elif consumer.is_foreach():
1384            consumer = typing.cast(ForeachKernelSchedulerNode, consumer)
1385            consumer_subnode = consumer.get_consumer_subnode_for(producer)
1386            fused_nodes = []
1387            prev_node_1 = consumer
1388            prev_node_2 = None
1389
1390            for node in consumer.snodes:
1391                if node is consumer_subnode:
1392                    new_node = FusedSchedulerNode.fuse(producer, node)
1393                    prev_node_2 = new_node
1394                    fused_nodes.append(new_node)
1395                else:
1396                    fused_nodes.append(node)
1397        else:
1398            raise AssertionError(
1399                "At least one node passed to ForeachKernelSchedulerNode.fuse should be a foreach node"
1400            )
1401
1402        return cls(
1403            producer.scheduler,
1404            fused_nodes,
1405            use_custom_partition_algo=use_custom_partition_algo,
1406            prev_node_1=prev_node_1,
1407            prev_node_2=prev_node_2,
1408            enable_autotune=enable_autotune,
1409        )
1410
1411    def __init__(
1412        self,
1413        scheduler: Scheduler,
1414        snodes: List[BaseSchedulerNode],
1415        use_custom_partition_algo: bool,
1416        prev_node_1: Optional[BaseSchedulerNode] = None,
1417        prev_node_2: Optional[BaseSchedulerNode] = None,
1418        enable_autotune: bool = False,
1419    ) -> None:
1420        self.read_to_node = {}
1421        self.name_to_node = {}
1422
1423        if prev_node_1 is None or prev_node_2 is None:
1424            super().__init__(scheduler, snodes)
1425
1426            for node in snodes:
1427                for read in node.read_writes.reads:
1428                    self.read_to_node[read.name] = node
1429
1430                for name in node.get_operation_names():
1431                    self.name_to_node[name] = node
1432        else:
1433            self.scheduler = scheduler
1434            self.snodes = snodes
1435            self.node = None
1436            self.users: List[NodeUser] = []
1437
1438            self.set_read_writes(
1439                dependencies.ReadWrites.merge_list(
1440                    [prev_node_1.read_writes, prev_node_2.read_writes]
1441                )
1442            )
1443
1444            self.unmet_dependencies = (
1445                OrderedSet(
1446                    dep
1447                    for dep in OrderedSet.union(
1448                        prev_node_1.unmet_dependencies, prev_node_2.unmet_dependencies
1449                    )
1450                    if dep.name not in self.get_buffer_names()
1451                )
1452                - self.read_writes.writes
1453            )
1454
1455            self.min_order = min([prev_node_1.min_order, prev_node_2.min_order])
1456            self.max_order = max([prev_node_1.max_order, prev_node_2.max_order])
1457
1458            if prev_node_1.is_foreach():
1459                assert isinstance(prev_node_1, ForeachKernelSchedulerNode)
1460                foreach_node, other_node = prev_node_1, prev_node_2
1461            else:
1462                assert isinstance(prev_node_2, ForeachKernelSchedulerNode)
1463                foreach_node, other_node = prev_node_2, prev_node_1
1464
1465            self.ancestors = foreach_node.ancestors
1466            self.ancestors.update(other_node.ancestors)
1467
1468            self.name_to_node = foreach_node.name_to_node
1469            for name in other_node.get_operation_names():
1470                self.name_to_node[name] = other_node
1471
1472        self.use_custom_partition_algo = use_custom_partition_algo
1473        self.group = (snodes[0].get_device(), ((sympy.Expr("combo_kernel"),),))
1474        self.origins: OrderedSet[torch.fx.Node] = OrderedSet()
1475        self.enable_autotune = enable_autotune
1476
1477    @classmethod
1478    def combinable_nodes(
1479        cls, nodes: List[BaseSchedulerNode]
1480    ) -> List[BaseSchedulerNode]:
1481        extern = [x for x in nodes if isinstance(x, ExternKernelSchedulerNode)]
1482        if extern:
1483            log.debug(
1484                "ComboKernels: %d external nodes are filtered: %s",
1485                len(extern),
1486                [node.node.get_origins() for node in extern if node.node is not None],
1487            )
1488        filtered_nodes = [
1489            x
1490            for x in nodes
1491            if not isinstance(x, (NopKernelSchedulerNode, ExternKernelSchedulerNode))
1492        ]
1493        foreach_nodes = [
1494            x for x in filtered_nodes if isinstance(x, ForeachKernelSchedulerNode)
1495        ]
1496        if foreach_nodes:
1497            log.debug("ComboKernels: %d foreach nodes are filtered", len(foreach_nodes))
1498        filtered_nodes = [
1499            x for x in filtered_nodes if not isinstance(x, ForeachKernelSchedulerNode)
1500        ]
1501        template_nodes = [x for x in filtered_nodes if x.is_template()]
1502        if template_nodes:
1503            log.debug(
1504                "ComboKernels: %d template nodes are filtered", len(template_nodes)
1505            )
1506        filtered_nodes = [x for x in filtered_nodes if x not in template_nodes]
1507        return filtered_nodes
1508
1509    @staticmethod
1510    def _default_group_nodes_for_combo_kernels(
1511        scheduler: Scheduler,
1512    ) -> List[List[BaseSchedulerNode]]:
1513        """
1514        Returns a list of lists of nodes that are to be grouped together.
1515        """
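        # Illustrative note: each topological level returned by
        # _topological_sort_nodes() contains mutually independent nodes and is
        # simply chunked into groups of at most 8.  For example, a level with
        # 10 independent nodes yields groups of sizes 8 and 2.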
1516        sorted_nodes = scheduler._topological_sort_nodes()
1517        grouped_nodes = []
1518        max_num_nodes = 8
1519        for nodes in sorted_nodes:
1520            grouped_nodes.extend(
1521                [
1522                    nodes[i : i + max_num_nodes]
1523                    for i in range(0, len(nodes), max_num_nodes)
1524                ]
1525            )
1526
1527        return grouped_nodes
1528
1529    group_algorithm_for_combo_kernels: Callable[
1530        [Scheduler], List[List[BaseSchedulerNode]]
1531    ] = _default_group_nodes_for_combo_kernels
1532
1533    @staticmethod
1534    def set_group_algorithm_for_combo_kernels(
1535        custom_group_algorithm: Callable[[Scheduler], List[List[BaseSchedulerNode]]]
1536    ) -> None:
1537        ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels = (
1538            custom_group_algorithm
1539        )
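
    # Illustrative (hypothetical) use of the grouping hook above: register a
    # custom algorithm that puts every node in its own group, which effectively
    # disables combo-kernel grouping since single-node groups are skipped.
    #
    #   def one_node_per_group(scheduler: Scheduler) -> List[List[BaseSchedulerNode]]:
    #       return [[node] for node in scheduler.nodes]
    #
    #   ForeachKernelSchedulerNode.set_group_algorithm_for_combo_kernels(one_node_per_group)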
1540
1541    @staticmethod
1542    def group_nodes_for_combo_kernels(
1543        scheduler: Scheduler,
1544    ) -> List[List[BaseSchedulerNode]]:
1545        return ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels(scheduler)
1546
1547    def mark_run(self) -> None:
1548        raise NotImplementedError
1549
1550    def codegen(self) -> None:
1551        assert isinstance(self.node, ir.ComputedBuffer), f"{type(self.node)=}"
1552        self.node.get_store_function()(self.node.make_loader()())
1553
1554    def is_foreach(self) -> bool:
1555        return True
1556
1557    def get_subkernel_nodes(self) -> List[BaseSchedulerNode]:
1558        """Returns a list of nodes which comprise the combo kernel.
1559        These nodes may be vertically fused."""
1560        return list(self.snodes)
1561
1562    def get_nodes(self) -> Sequence[BaseSchedulerNode]:
1563        """Returns all nodes contained in this kernel, unpacking fused nodes
1564        into their constituent scheduler nodes."""
1565        return list(itertools.chain.from_iterable(x.get_nodes() for x in self.snodes))
1566
1567    def get_first_name(self) -> str:
1568        return self.snodes[0].get_first_name()
1569
1570    def prune_redundant_deps(
1571        self, name_to_fused_node: Dict[str, BaseSchedulerNode]
1572    ) -> None:
1573        _prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf)
1574
1575        for node in self.snodes:
1576            node.prune_redundant_deps(name_to_fused_node)
1577
1578
1579class GroupedSchedulerNode(BaseSchedulerNode):
1580    """
1581    This is a "fake" scheduler node that represents a group of scheduler nodes
1582    that are meant to be *grouped* together (it does not allow another node to be scheduled
1583    in between its constituent nodes, nor does it allow another node to fuse into any of its constituent nodes).
1584    The way it does this is by maintaining its unmet dependencies as the union of its constituent nodes' unmet dependencies.
1585    Fusion will still happen among the nodes within each GroupedSchedulerNode.
1586    At codegen time, this scheduler node will be unpacked and codegen is called on each constituent node.
1587    """
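    # Illustrative sketch (hypothetical snode_a/snode_b): group two scheduler
    # nodes so that nothing can be scheduled between them, then unpack the
    # group (which also fuses within it) once grouping is no longer needed:
    #
    #   grouped = GroupedSchedulerNode.create([snode_a, snode_b])
    #   ...
    #   nodes = grouped.unpack()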
1588
1589    snodes: List[BaseSchedulerNode]
1590
1591    @classmethod
1592    def create(cls, snodes: List[BaseSchedulerNode]) -> GroupedSchedulerNode:
1593        scheduler = snodes[0].scheduler
1594        assert all(node.scheduler is scheduler for node in snodes)
1595        grouped_snode = cls(scheduler, snodes)  # type: ignore[arg-type]
1596        for snode in snodes:
1597            scheduler.name_to_fused_node[snode.get_name()] = grouped_snode
1598        scheduler.name_to_fused_node[grouped_snode.get_name()] = grouped_snode
1599        return grouped_snode
1600
1601    def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None:
1602        super().__init__(scheduler)
1603        init_group_node(self, scheduler, snodes)
1604
1605    def unpack(self) -> List[BaseSchedulerNode]:
1606        """
1607        Do fusion among nodes within this GroupedSchedulerNode,
1608        and then unpack this GroupedSchedulerNode into regular nodes.
1609        """
1610        for snode in self.snodes:
1611            self.scheduler.name_to_fused_node[snode.get_name()] = snode
1612        del self.scheduler.name_to_fused_node[self.get_name()]
1613        return self.scheduler.fuse_nodes(self.snodes)
1614
1615    def add_fake_dep(self, fake_dep: Dep) -> None:
1616        self.set_read_writes(self.read_writes.with_read(fake_dep))
1617        self.unmet_dependencies.add(fake_dep)
1618
1619    @cache_on_self
1620    def get_name(self) -> str:
1621        return "_".join([x.get_name() for x in self.snodes])
1622
1623    def get_first_name(self) -> str:
1624        return self.snodes[0].get_name()
1625
1626    @cache_on_self
1627    def get_buffer_names(self) -> OrderedSet[str]:
1628        return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes])
1629
1630    def get_outputs(self) -> List[SchedulerBuffer]:
1631        result: List[SchedulerBuffer] = []
1632        for node in self.snodes:
1633            result.extend(node.get_outputs())
1634        return result
1635
1636    def get_nodes(self) -> Sequence[BaseSchedulerNode]:
1637        return self.snodes
1638
1639    @classmethod
1640    def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool:
1641        # GroupedSchedulerNode cannot be fused with another node
1642        return False
1643
1644
1645def pick_loop_order(
1646    stride_lengths: List[List[int]],
1647    sizes: List[sympy.Expr],
1648    priority_idx: Tuple[int, ...] = (),
1649) -> List[int]:
1650    """
1651    A heuristic to decide loop iteration orders.  This has not been well
1652    tuned and may be something we should autotune.
1653    """
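    # Illustrative example (assuming config.pick_loop_orders is enabled): for a
    # single buffer with stride_lengths=[[4, 1]] and sizes=[3, 4], index 1 (the
    # stride-1 dimension) compares as smaller and sorts first, so the returned
    # order is [1, 0].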
1654
1655    @functools.cmp_to_key
1656    def index_cmp(a: int, b: int) -> int:
1657        if sizes[a] == 1 or sizes[b] == 1:
1658            # 1-sizes don't matter, just move them to the end
1659            return cmp(sizes[a] == 1, sizes[b] == 1)
1660
1661        # Take abs, otherwise flipped dimensions are treated as smaller
1662        # strides than contiguous dims
1663        stride_len_a = [abs(sl[a]) for sl in stride_lengths]
1664        stride_len_b = [abs(sl[b]) for sl in stride_lengths]
1665
1666        # equivalent to
1667        # np.logical_or(stride_lengths[:, b] == 0, stride_lengths[:, a] < stride_lengths[:, b]).all()
1668        a_first = sum(
1669            sl_b == 0 or sl_a < sl_b for sl_a, sl_b in zip(stride_len_a, stride_len_b)
1670        )
1671        b_first = sum(
1672            sl_a == 0 or sl_b < sl_a for sl_a, sl_b in zip(stride_len_a, stride_len_b)
1673        )
1674        if a_first > b_first:
1675            return -1
1676        if b_first > a_first:
1677            return 1
1678
1679        # otherwise contiguous
1680        return cmp(b, a)
1681
1682    order = list(reversed(range(len(stride_lengths[0]))))
1683    if len(priority_idx) > 0:
1684        # if we have priority node, only use that node's order
1685        stride_lengths = [stride_lengths[pi] for pi in priority_idx]
1686    if config.pick_loop_orders:
1687        order.sort(key=index_cmp)
1688    return order
1689
1690
1691@dataclasses.dataclass
1692class NodeUser:
1693    node: Union[BaseSchedulerNode, OutputNode]
1694    can_inplace: bool = False
1695
1696    # A weak user must be scheduled after a given node, but doesn't actually
1697    # use the result
1698    is_weak: bool = False
1699
1700    def __hash__(self) -> int:
1701        return hash((self.node.get_name(), self.can_inplace, self.is_weak))
1702
1703    def __eq__(self, other: object) -> bool:
1704        return (
1705            isinstance(other, NodeUser)
1706            and self.get_name() == other.get_name()
1707            and self.can_inplace == other.can_inplace
1708            and self.is_weak == other.is_weak
1709        )
1710
1711    def get_name(self) -> str:
1712        return self.node.get_name()
1713
1714    def merge(self, other: NodeUser) -> NodeUser:
1715        assert self.node is other.node
1716        return NodeUser(
1717            self.node,
1718            self.can_inplace and other.can_inplace,
1719            self.is_weak and other.is_weak,
1720        )
1721
1722
1723_post_grad_graph_counter = itertools.count()
1724
1725
1726class Scheduler:
1727    __dep_size_hint_cache: Dict[Dep, int]
1728
1729    def __init__(self, nodes: List[ir.Operation]) -> None:
1730        with dynamo_timed("Scheduler.__init__"):
1731            self._init(nodes)
1732
1733    def _init(self, nodes: List[ir.Operation]) -> None:
1734        super().__init__()
1735        self.__dep_size_hint_cache = {}
1736        V.graph.scheduler = self
1737        self.backends: Dict[torch.device, BaseScheduling] = {}
1738        self.post_grad_graph_id = next(_post_grad_graph_counter)
1739
1740        self.completed_operations: OrderedSet[str] = OrderedSet()
1741        self.available_buffer_names = OrderedSet(
1742            [
1743                *V.graph.graph_inputs.keys(),
1744                *V.graph.constants.keys(),
1745                *V.graph.torchbind_constants.keys(),
1746            ]
1747        )
1748
1749        self.nodes = [self.create_scheduler_node(n) for n in nodes]
1750        self.update_zero_dim_cpu_tensor()
1751        # some new constants could have been created above
1752        self.available_buffer_names.update(V.graph.constants.keys())
1753        for node in self.nodes:
1754            node.prune_deps()
1755
1756        self.name_to_node: Dict[str, BaseSchedulerNode] = {
1757            n.get_name(): n for n in self.nodes
1758        }
1759        self.name_to_buf: Dict[str, SchedulerBuffer] = {
1760            buf.get_name(): buf for node in self.nodes for buf in node.get_outputs()
1761        }
1762        self.name_to_fused_node: Dict[str, BaseSchedulerNode] = self.name_to_node.copy()
1763
1764        # mutation_real_name: Maps back to the original name for codegen
1765        # Example:
1766        # If you mutate buf0 inside of buf1's kernel, then:
1767        # mutation_real_name = {"buf0" : "buf1"}
1768        # all subsequent uses of buf0 become buf1's usage in dependency graph
1769        self.mutation_real_name: Dict[str, str] = {}
1770
1771        # We handle mutation by renaming modified versions of the same
1772        # buffer in the dependency graph to prevent cycles.
1773        # mutation_renames: tracks the current name for a given buffer
1774        #                   (changed once per mutation)
1775        # Example:
1776        # If you mutate buf0 inside of buf1's kernel, then:
1777        # mutation_renames = {"buf1" : "buf0"}
1778        # in codegen we only use buf0, never buf1
1779        self.mutation_renames: Dict[str, str] = {}
1780
1781        self.compute_dependencies()
1782        self.nodes = self.topological_sort_schedule(self.nodes)
1783        self.dead_node_elimination()
1784        self.name_to_fused_node = {n.get_name(): n for n in self.nodes}
1785        self.compute_ancestors()
1786        if config.reorder_for_compute_comm_overlap:
1787            self.nodes = comms.decide_global_ordering_of_comms(
1788                self.nodes,
1789                self.name_to_buf,
1790                self.name_to_fused_node,
1791            )
1792
1793        metrics.ir_nodes_pre_fusion += len(self.nodes)
1794        V.debug.ir_pre_fusion(self.nodes)
1795        self.num_orig_nodes = len(self.nodes)
1796        self.create_foreach_nodes()
1797        self.nodes = self.topological_sort_schedule(self.nodes)
1798        self.logged_slow_fusion: OrderedSet[Tuple[str, str]] = OrderedSet()
1799        if config._pre_fusion_custom_pass is not None:
1800            self.nodes = config._pre_fusion_custom_pass(self.nodes)
1801        self.nodes = self.fuse_nodes(self.nodes)
1802        self.merge_loops()
1803        self.finalize_multi_template_buffers()
1804        if config.reorder_for_compute_comm_overlap:
1805            self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes)
1806        if config.combo_kernels:
1807            self.create_combo_kernel_nodes(num_ck_nodes=None)
1808        self.process_grouped_nodes()
1809        self.compute_last_usage()
1810        V.debug.ir_post_fusion(self.nodes)
1811        V.debug.graph_diagram(self.nodes)
1812        self.debug_draw_graph()
1813
1814        # used during codegen:
1815        self.current_device: Optional[torch.device] = None
1816        self.buffer_names_to_free: OrderedSet[str] = OrderedSet()
1817
1818        # fx graph node to the position it appears in the graph
1819        # for debug attribution
1820        self.origin_to_index: Dict[torch.fx.Node, int] = {}
1821
1822        get_metric_table("graph_stats").add_row(
1823            lambda: {
1824                "graph_id": self.post_grad_graph_id,
1825                "num_nodes_before_fusion": self.num_orig_nodes,
1826                "num_nodes_after_fusion": len(self.nodes),
1827            }
1828        )
1829
1830    def get_current_device_or_throw(self) -> torch.device:
1831        if device := self.current_device:
1832            return device
1833        else:
1834            raise RuntimeError("No current device")
1835
1836    def debug_draw_graph(self) -> None:
1837        """Generate an image of the graph for debugging"""
1838        if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1":
1839            from .debug import draw_buffers
1840
1841            draw_buffers(self.nodes, print_graph=True)
1842
1843    def debug_print_nodes(self, label: str) -> None:
1844        if log.isEnabledFor(logging.INFO):
1845            log.info("%s:", label)
1846            for node in self.nodes:
1847                node.log_details()
1848
1849    def create_scheduler_node(self, node: ir.Operation) -> BaseSchedulerNode:
1850        assert (
1851            node.get_origins() is not None
1852        ), "All nodes passed to scheduling must have an origin"
1853        if node.is_no_op():
1854            return NopKernelSchedulerNode(self, node)
1855        elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)):
1856            return SchedulerNode(self, node)
1857        elif isinstance(node, ir.ExternKernel):
1858            return ExternKernelSchedulerNode(self, node)
1859        else:
1860            raise NotImplementedError(node)
1861
1862    def create_foreach_nodes(self) -> None:
1863        removed_node_names: OrderedSet[str] = OrderedSet()
1864        fe_nodes = []
1865        kept_node_names = self.name_to_fused_node.keys()
1866
1867        for names in V.graph.lists.values():
1868            names = [
1869                name
1870                for name in names
1871                if name in kept_node_names
1872                and not isinstance(self.name_to_node[name], NopKernelSchedulerNode)
1873            ]
1874            if not names:
1875                # All nodes eliminated
1876                continue
1877
1878            removed_node_names.update(names)
1879            snodes = [self.name_to_node[name] for name in names]
1880
1881            enable_autotune = config.combo_kernels_autotune > 1
1882            fe_node = ForeachKernelSchedulerNode(
1883                self,
1884                snodes,
1885                use_custom_partition_algo=False,
1886                enable_autotune=enable_autotune,
1887            )
1888
1889            fe_nodes.append(fe_node)
1890
1891            for name in names:
1892                self.name_to_fused_node[name] = fe_node
1893
1894        self.nodes = [
1895            node for node in self.nodes if node.get_name() not in removed_node_names
1896        ] + list(fe_nodes)
1897
1898    def compute_dependencies(self) -> None:
1899        """
1900        Create dependency edges between nodes, handling aliasing and
1901        mutation properly.
1902        """
1903
1904        T = TypeVar("T")
1905
1906        class DedupList(Generic[T]):
1907            """
1908            This data structure behaves like a list except it makes sure the
1909            elements remain unique.
1910            Normally one could use an OrderedSet/dict for this purpose; however,
1911            the list in question gets elements appended as it is being
1912            iterated over, which means that we need to keep the list
1913            semantics.
1914            """
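            # Illustrative behavior (not executed here): duplicates are dropped
            # while append order is preserved, e.g.
            #   d = DedupList()
            #   d.append("a"); d.append("b"); d.append("a")
            #   assert d.items == ["a", "b"]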
1915
1916            def __init__(
1917                self,
1918                items: Optional[List[T]] = None,
1919                membership: Optional[OrderedSet[T]] = None,
1920            ) -> None:
1921                self.items = items or []
1922                self.membership = membership or OrderedSet()
1923
1924            def append(self, node_user: T) -> None:
1925                if node_user in self.membership:
1926                    return
1927                self.items.append(node_user)
1928                self.membership.add(node_user)
1929
1930            def __add__(self, other: DedupList[T]) -> DedupList[T]:
1931                new_membership = OrderedSet.union(self.membership, other.membership)
1932                new_items = self.items + [
1933                    x for x in other.items if x not in self.membership
1934                ]
1935                return DedupList(new_items, new_membership)
1936
1937        name_to_users: DefaultDict[str, DedupList[NodeUser]] = collections.defaultdict(
1938            DedupList
1939        )
1940
1941        # handle aliasing by using python aliasing in name_to_users
1942        # if foo aliases bar then we will make name_to_users["foo"] point
1943        # to the same python list as name_to_users["bar"]
1944        for node in self.nodes:
1945            for buf1 in node.get_outputs():
1946                buf1_name = buf1.get_name()
1947                for buf2_name in buf1.get_aliases():
1948                    if buf1_name in name_to_users and buf2_name in name_to_users:
1949                        # merge the two
1950                        list1 = name_to_users[buf1_name]
1951                        list2 = name_to_users[buf2_name]
1952                        combined = list1 + list2
1953                        for key in name_to_users.keys():
1954                            if (
1955                                name_to_users[key] is list1
1956                                or name_to_users[key] is list2
1957                            ):
1958                                name_to_users[key] = combined
1959                    elif buf1_name in name_to_users:
1960                        name_to_users[buf2_name] = name_to_users[buf1_name]
1961                    else:
1962                        name_to_users[buf1_name] = name_to_users[buf2_name]
1963
1964        def rename(n: str) -> str:
1965            if n in self.mutation_renames:
1966                return rename(self.mutation_renames[n])
1967            return n
1968
1969        def add_user(
1970            used_by_name: str,
1971            user_node: Union[BaseSchedulerNode, OutputNode],
1972            can_inplace: bool = False,
1973            is_weak: bool = False,
1974        ) -> None:
1975            name_to_users[rename(used_by_name)].append(
1976                NodeUser(user_node, can_inplace, is_weak)
1977            )
1978
1979        unbacked_symbol_to_origin_node: Dict[sympy.Symbol, Optional[str]] = {}
1980
1981        # NB: None means that the dependency is on an input.  Don't actually
1982        # generate a dependency because if we do, Inductor will start trying
1983        # to free the unbacked int but that's pointless
1984        for name, val in V.graph.graph_inputs.items():
1985            if isinstance(val, sympy.Expr):
1986                for fs in val.free_symbols:
1987                    unbacked_symbol_to_origin_node[fs] = None
1988
1989        for node in self.nodes:
1990            log.debug("scheduling %s", node.node)
1991
1992            # unbacked symbols don't follow ordinary buffer dependencies, so
1993            # we track their def/uses separately
1994            assert node.node is not None
1995            unbacked_symbol_defs = sorted(
1996                node.node.get_unbacked_symbol_defs(), key=lambda x: x.name
1997            )
1998            for s in unbacked_symbol_defs:
1999                assert isinstance(s, sympy.Symbol)
2000                # Pick the first definer as canonical.  There may be multiple
2001                # because if a MultiOutputLayout buffer propagates an unbacked
2002                # symint to multiple outputs, they will all claim to def it.
2003                if s not in unbacked_symbol_to_origin_node:
2004                    unbacked_symbol_to_origin_node[s] = node.get_name()
2005
2006            unbacked_symbol_uses = sorted(
2007                node.node.get_unbacked_symbol_uses(), key=lambda x: x.name
2008            )
2009            # if a kernel takes unbacked symints, register dependencies
2010            for s in unbacked_symbol_uses:
2011                assert (
2012                    s in unbacked_symbol_to_origin_node
2013                ), f"{s} not in {unbacked_symbol_to_origin_node}"
2014                if (r := unbacked_symbol_to_origin_node[s]) is not None:
2015                    for buf in self.name_to_node[r].get_outputs():
2016                        node.add_fake_dep(StarDep(buf.get_name()))
2017
2018            if (
2019                len(node.read_writes.writes) == 1
2020                and (dep := next(iter(node.read_writes.writes)))
2021                and isinstance(dep, MemoryDep)
2022            ):
2023                node_mode = dep.mode
2024            else:
2025                node_mode = None
2026
2027            # Handle output mutations
2028            for buf in node.get_outputs():
2029                # a node will mutate either 0 or 1 buffers
2030                assert len(buf.get_mutations()) <= 1
2031                for alt_name in buf.get_mutations():
2032                    alt_name = rename(alt_name)
2033                    # this node must run after the prior writer
2034                    add_user(alt_name, node)
2035                    node.add_fake_dep(StarDep(alt_name, mode=node_mode))
2036                    for user in name_to_users[alt_name].items:
2037                        if user.get_name() == node.get_name():
2038                            continue
2039
2040                        assert isinstance(user.node, BaseSchedulerNode)
2041                        for other_name in user.node.get_buffer_names():
2042                            # this node must run after all prior readers
2043                            other_name = rename(other_name)
2044                            node.add_fake_dep(
2045                                WeakDep(other_name, mutating_buf=buf.get_name())
2046                            )
2047                            add_user(other_name, node, is_weak=True)
2048
2049            # add normal non-mutation dependencies
2050            for read in node.read_writes.reads:
2051                if not isinstance(read, WeakDep):
2052                    add_user(read.name, node, node.can_inplace(read))
2053
2054            node.update_mutated_names(self.mutation_renames)
2055
2056            # update our renaming scheme for the next iteration
2057            for buf in node.get_outputs():
2058                for alt_name in buf.get_mutations():
2059                    self.mutation_renames[rename(alt_name)] = buf.get_name()
2060                    self.mutation_renames[alt_name] = buf.get_name()
2061                    self.mutation_real_name[
2062                        buf.get_name()
2063                    ] = self.mutation_real_name.get(alt_name, alt_name)
2064
2065        # make sure outputs aren't dead-code-eliminated
2066        for buf_name in V.graph.get_output_names():
2067            log.debug("scheduling output %s", buf_name)
2068            add_user(buf_name, OutputNode(StarDep(buf_name)))
2069
2070        # make sure unbacked symints aren't dead-code-eliminated
2071        for out in V.graph.graph_outputs:
2072            for s in out.get_unbacked_symbol_uses():
2073                assert (
2074                    s in unbacked_symbol_to_origin_node
2075                ), f"{s} not in {unbacked_symbol_to_origin_node.keys()}"
2076                if r := unbacked_symbol_to_origin_node[s]:
2077                    for buf_name in self.name_to_node[r].get_buffer_names():
2078                        log.debug(
2079                            "scheduling output %s for unbacked symint %s", buf_name, s
2080                        )
2081                        add_user(buf_name, OutputNode(StarDep(buf_name)))
2082
2083        # make sure input mutation isn't dead-code-eliminated
2084        for name in self.mutation_renames:
2085            if name in V.graph.graph_inputs:
2086                add_user(name, OutputNode(StarDep(name)))
2087                V.graph.mutated_inputs.add(name)
2088            elif name in V.graph.constants:
2089                # In AOTI, module parameters and buffers are not lifted as graph inputs
2090                add_user(name, OutputNode(StarDep(name)))
2091
2092        inp_names = {
2093            name: index for index, name in enumerate(V.graph.graph_inputs.keys())
2094        }
2095        V.graph.mutated_input_idxs = [
2096            inp_names[name] for name in V.graph.mutated_inputs
2097        ]
2098
2099        # copy users information onto the nodes
2100        for node in self.nodes:
2101            for buf in node.get_outputs():
2102                buf.set_users(name_to_users[buf.get_name()].items)
2103
2104    def dead_node_elimination(self) -> None:
2105        """
2106        Remove any nodes without users
2107        """
2108        # self.nodes is in topological order, so by iterating in reverse order
2109        # we have visited (and potentially removed) all users before visiting a
2110        # given node.
2111        updated_nodes = []
2112        for node in reversed(self.nodes):
2113
2114            def can_eliminate_user(user: NodeUser) -> bool:
2115                return user.is_weak or user.get_name() in V.graph.removed_operations
2116
2117            active_buffers = False
2118            for buf in node.get_outputs():
2119                can_eliminate = all(can_eliminate_user(u) for u in buf.users)
2120                if can_eliminate:
2121                    log.debug("removed dead buffer: %s", buf.get_name())
2122                    V.graph.removed_buffers.add(buf.get_name())
2123                else:
2124                    active_buffers = True
2125
2126            can_eliminate = not node.has_side_effects() and not active_buffers
2127
2128            if not can_eliminate:
2129                updated_nodes.append(node)
2130            else:
2131                # dead code
2132                log.debug("removed dead operation: %s", node.get_name())
2133                V.graph.removed_operations.add(node.get_name())
2134
2135        self.nodes = list(reversed(updated_nodes))
2136
2137        # Prune any WeakDeps no longer needed
2138        for node in self.nodes:
2139            node.prune_weak_deps()
2140
2141    def topological_sort_schedule(
2142        self, nodes: List[BaseSchedulerNode]
2143    ) -> List[BaseSchedulerNode]:
2144        """
2145        Return the given nodes in topologically sorted order.
2146        """
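        # DFS post-order: a node is emitted only after the producers of its
        # unmet dependencies (restricted to `nodes`) have been emitted.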
2147        seen: OrderedSet[BaseSchedulerNode] = OrderedSet()
2148        name_to_node: Dict[str, BaseSchedulerNode] = dict()
2149        result: List[BaseSchedulerNode] = []
2150
2151        def visit(n: BaseSchedulerNode) -> None:
2152            if n not in seen:
2153                seen.add(n)
2154                for dep in sorted(n.unmet_dependencies, key=lambda d: d.name):
2155                    # We only care about doing toposort within `nodes`
2156                    if dep.name not in name_to_node:
2157                        continue
2158                    visit(name_to_node[dep.name])
2159                result.append(n)
2160
2161        for node in nodes:
2162            for name in node.get_buffer_names():
2163                name_to_node[name] = node
2164        for node in nodes:
2165            visit(node)
2166        return result
2167
2168    def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> List[BaseSchedulerNode]:
2169        unmet_deps = set()
2170        if isinstance(
2171            snode,
2172            (
2173                SchedulerNode,
2174                ExternKernelSchedulerNode,
2175                NopKernelSchedulerNode,
2176                FusedSchedulerNode,
2177            ),
2178        ):
2179            for dep in snode.unmet_dependencies:
2180                unmet_deps.add(dep.name)
2181        else:
2182            raise RuntimeError(
2183                f"get_unmet_dep_nodes is not implemented for {type(snode)}."
2184            )
2185        unmet_dep_ops = (self.name_to_buf[dep].defining_op for dep in unmet_deps)
2186        return list({self.name_to_fused_node[n.get_name()] for n in unmet_dep_ops})
2187
2188    def _topological_sort_nodes(self) -> List[List[BaseSchedulerNode]]:
2189        """
2190        Topologically sort the nodes and return them as a list of levels (lists of nodes).
2191        """
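        # Kahn's algorithm by levels.  Illustrative example (hypothetical
        # nodes): for a diamond dependency graph a -> {b, c} -> d, the result
        # would be [[a], [b, c], [d]]; nodes within a level have no unmet
        # dependencies on one another.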
2192        order = []
2193        nodes = dict.fromkeys(self.nodes, 0)
2194        children: Dict[Any, Any] = {}
2195        for node in self.nodes:
2196            deps = self._get_unmet_dep_nodes(node)
2197            nodes[node] = len(deps)
2198            for dep in deps:
2199                c = children.get(dep, [])
2200                c.append(node)
2201                children[dep] = c
2202
2203        zero_deg_nodes = [n for n, v in nodes.items() if v == 0]
2204        while zero_deg_nodes:
2205            order.append(zero_deg_nodes)
2206            for n in zero_deg_nodes:
2207                for user in children.get(n, []):
2208                    nodes[user] -= 1
2209                nodes.pop(n)
2210            zero_deg_nodes = [n for n, v in nodes.items() if v == 0]
2211        assert not nodes, "Topological sort failed!"
2212        return order
2213
2214    def compute_ancestors(self) -> None:
2215        """
2216        Populate each node.ancestors
2217        """
2218        # note self.nodes is topologically sorted
2219        name_to_ancestors: Dict[str, OrderedSet[str]] = {}
2220        for node in self.nodes:
2221            ancestors: OrderedSet[str] = OrderedSet()
2222            for dep in node.unmet_dependencies:
2223                dep_node_name = self.name_to_buf[dep.name].defining_op.get_name()
2224                ancestors.add(dep_node_name)
2225                ancestors |= name_to_ancestors[dep_node_name]
2226            name_to_ancestors[node.get_name()] = ancestors
2227            node.ancestors = ancestors
2228
2229        for order, node in enumerate(self.nodes):
2230            node.min_order = order
2231            node.max_order = order
2232
2233    def merge_loops(self) -> None:
2234        for node in self.nodes:
2235            if not config.loop_ordering_after_fusion:
2236                continue
2237
2238            # Even for CPU, if we are using the halide backend, we still need
2239            # the merge loops steps below
2240            if not isinstance(node, (SchedulerNode, FusedSchedulerNode)) or (
2241                node.get_device().type != "cuda" and config.cpu_backend != "halide"
2242            ):
2243                continue
2244            for snode in node.get_nodes():
2245                # merge loops for the scheduler node
2246                if not isinstance(snode, SchedulerNode) or snode.is_template():
2247                    continue
2248
2249                snode._body = snode._body.merge_loops()
2250                snode._sizes = snode._body.sizes
2251
2252                # merge_loops is called after loop reordering.
2253                # We still need to retain fake dependencies since codegen's
2254                # estimate of the amount of memory accessed relies on them.
2255                snode.refresh_dependencies(normalize=True)
2256
2257                # Note that for CPU backend, merging loops will change
2258                # snode.group. It's fine for Triton backend.
2259                # But if we simply update snode.group like this:
2260                #   group_fn = self.get_backend(snode.node.get_device()).group_fn
2261                #   snode.group = (snode.node.get_device(), group_fn(snode._sizes))
2262                # There is still an issue due to different snodes in a
2263                # FusedSchedulerNode having different merged loops.
2264                # Skip CPU backend for now.
2265
2266    def fuse_nodes(self, nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
2267        """
2268        Combine eligible nodes into FusedSchedulerNodes.
2269        """
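        # Repeatedly apply single fusion passes, stopping early once a pass
        # makes no progress or everything has collapsed into one node; capped
        # at 10 rounds.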
2270        for i in range(10):
2271            old_len = len(nodes)
2272            fusion_log.debug(
2273                "===== attempting fusion (%d/10): %d nodes =====",
2274                i + 1,
2275                old_len,
2276            )
2277            nodes = self.fuse_nodes_once(nodes)
2278            new_len = len(nodes)
2279            fusion_log.debug(
2280                "completed fusion round (%d/10): fused %d nodes into %d nodes\n",
2281                i + 1,
2282                old_len,
2283                new_len,
2284            )
2285            if new_len == old_len or new_len == 1:
2286                fusion_log.debug("===== fusion complete (%d iterations) =====", i + 1)
2287                break
2288        return nodes
2289
2290    def process_grouped_nodes(self) -> None:
2291        """
2292        Unpack GroupedSchedulerNode into regular nodes.
2293        """
2294        new_nodes: List[BaseSchedulerNode] = []
2295        for node in self.nodes:
2296            new_nodes.extend(
2297                node.unpack() if isinstance(node, GroupedSchedulerNode) else [node]
2298            )
2299        self.nodes = new_nodes
2300
2301    def benchmark_fused_nodes(
2302        self, nodes: Sequence[BaseSchedulerNode]
2303    ) -> Tuple[float, str]:
2304        """
2305        Benchmark the fused list of nodes and return the execution time
2306        in milliseconds on randomly generated inputs.
2307        """
2308        assert len(nodes) > 0
2309        device = nodes[0].get_device()
2310        self.current_device = device
2311        backend = self.get_backend(device)
2312        return backend.benchmark_fused_nodes(nodes)
2313
2314    def finalize_multi_template_buffers(self) -> None:
2315        def replace_operation_buffer(
2316            orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
2317        ) -> None:
2318            replaced_buf_name = new_node.get_name()
2319            orig_buf_name = orig_node.get_name()
2320            assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str)
2321
2322            replaced_op_name = new_node.get_operation_name()
2323            orig_op_name = orig_node.get_operation_name()
2324            assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str)
2325
2326            del V.graph.name_to_buffer[replaced_buf_name]
2327            new_node.name = orig_buf_name
2328
2329            del V.graph.name_to_op[replaced_op_name]
2330            new_node.operation_name = orig_op_name
2331
2332            orig = V.graph.buffers.index(orig_node)
2333            V.graph.buffers.remove(new_node)
2334            V.graph.buffers[orig] = new_node
2335            V.graph.name_to_buffer[orig_buf_name] = new_node
2336
2337            orig = V.graph.operations.index(orig_node)
2338            V.graph.operations.remove(new_node)
2339            V.graph.operations[orig] = new_node
2340            V.graph.name_to_op[orig_op_name] = new_node
2341
2342        for i, node in enumerate(self.nodes):
2343            if isinstance(node, SchedulerNode) and isinstance(
2344                node.node, ir.MultiTemplateBuffer
2345            ):
2346                multi_node = node.node
2347                min_node_unfused, _ = multi_node.get_min_choice()
2348
2349                if isinstance(
2350                    min_node_unfused,
2351                    torch._inductor.ir.TritonTemplateCallerBase,
2352                ):
2353                    node.node.finalize_as_triton_caller(min_node_unfused)
2354                    continue
2355
2356                out_tensorbox = min_node_unfused.output_node()
2357                out_storage = out_tensorbox.data
2358                assert isinstance(out_storage, ir.StorageBox)
2359                out_buffer = out_storage.data
2360                assert isinstance(out_buffer, ir.OperationBuffer)
2361
2362                out_buffer.layout = multi_node.layout
2363                replace_operation_buffer(multi_node, out_buffer)
2364                new_scheduler_node = self.create_scheduler_node(out_buffer)
2365
2366                self.nodes[i] = new_scheduler_node
2367                self.name_to_node[node.get_name()] = new_scheduler_node
2368                self.name_to_fused_node[node.get_name()] = new_scheduler_node
2369
2370                for new_out, old_out in zip(
2371                    new_scheduler_node.get_outputs(), node.get_outputs()
2372                ):
2373                    self.name_to_buf[old_out.get_name()] = new_out
2374                    new_out.users = old_out.users
2375
2376                new_scheduler_node.min_order = node.min_order
2377                new_scheduler_node.max_order = node.max_order
2378                new_scheduler_node.last_usage = node.last_usage
2379
2380    def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool:
2381        return any(
2382            hasattr(n.node, "data")
2383            and n.node is not None
2384            and hasattr(n.node.data, "scatter_mode")
2385            and n.node.data.scatter_mode == "atomic_add"
2386            for n in node_list
2387        )
2388
2389    def speedup_by_fusion(
2390        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
2391    ) -> bool:
2392        """
2393        If config.benchmark_fusion is False, always return True.
2394        Otherwise, return True only if the fusion brings a speedup.
2395        """
2396
2397        is_multi_template = node1.is_template() and isinstance(
2398            node1.get_template_node(), ir.MultiTemplateBuffer
2399        )
2400        if not config.benchmark_fusion and not is_multi_template:
2401            return True
2402
2403        if (
2404            node1.is_template()
2405            and not isinstance(node1.get_template_node(), ir.TritonTemplateBuffer)
2406            or node1.is_foreach()
2407            or node2.is_foreach()
2408        ):
2409            # TODO support benchmarking epilogue fusion
2410            return True
2411
2412        node_list_1 = node1.get_nodes()
2413        device = node_list_1[0].get_device()
2414
2415        # don't support benchmark fusion for CPU right now.
2416        if device.type == "cpu":
2417            return True
2418
2419        node_list_2 = node2.get_nodes()
2420        node_list_fused = list(itertools.chain(node_list_1, node_list_2))
2421
2422        # We cannot accurately benchmark kernels using atomic_add
2423        # due to how we generate random integer inputs.
2424        # Skip benchmarking them by allowing fusion.
2425        if self._any_atomic_add(node_list_fused):
2426            return True
2427
2428        from triton.compiler.errors import CompilationError
2429
2430        why = WhyNoFuse(node1, node2)
2431
2432        def log_fusion(ms_fused: float, ms1: float, ms2: float) -> None:
2433            if fusion_log.isEnabledFor(logging.DEBUG):
2434                if ms_fused < ms1 + ms2:
2435                    fusion_log.debug(
2436                        "can fuse (benchmark): fusing %s with %s causes a %sx speedup",
2437                        node1.get_buffer_names(),
2438                        node2.get_buffer_names(),
2439                        green_text(f"{(ms1 + ms2) / ms_fused:.3f}"),
2440                    )
2441                else:
2442                    fusion_log.debug(
2443                        "cannot fuse (benchmark): fusing %s with %s causes a %sx slowdown",
2444                        node1.get_buffer_names(),
2445                        node2.get_buffer_names(),
2446                        red_text(f"{ms_fused / (ms1 + ms2):.3f}"),
2447                    )
2448
2449        if isinstance(node1, SchedulerNode) and isinstance(
2450            node1.node, ir.MultiTemplateBuffer
2451        ):
2452            multi_node = node1.node
2453            choice_timings = multi_node.choice_timings
2454
2455            _, ms1 = multi_node.get_min_choice()
2456            ms2, path2 = self.benchmark_fused_nodes(node_list_2)
2457
2458            min_ms_fused = float("inf")
2459            ms_fused_choice = None
2460
2461            triton_choices = 0
2462
2463            for choice, unfused_time in sorted(
2464                choice_timings.items(), key=lambda x: x[1]
2465            ):
2466                if not isinstance(choice, torch._inductor.ir.TritonTemplateCallerBase):
2467                    continue
2468
2469                if unfused_time >= ms1 + ms2:
2470                    break
2471
2472                triton_choices += 1
2473                if triton_choices > config.max_epilogue_benchmarked_choices:
2474                    break
2475
2476                # TODO - parallel compile triton templates
2477                # TODO - should prune/skip choices that are not within certain % of best choice
2478                with node1.node.swap_as_triton_caller(choice):
2479                    ms_fused, _ = self.benchmark_fused_nodes(node_list_fused)
2480
2481                    if ms_fused < min_ms_fused:
2482                        min_ms_fused = ms_fused
2483                        ms_fused_choice = choice
2484
2485            log_fusion(min_ms_fused, ms1, ms2)
2486
2487            # after we do a fusion, we finalize a triton template.
2488            # TODO - could preserve multi template and choices for subsequent fusions
2489            if min_ms_fused < (ms1 + ms2) and ms_fused_choice is not None:
2490                node1.node.finalize_as_triton_caller(ms_fused_choice)
2491                return True
2492            else:
2493                return False
2494        else:
2495            try:
2496                ms1, path1 = self.benchmark_fused_nodes(node_list_1)
2497                if math.isinf(ms1):
2498                    why("register spilling of the first kernel")
2499                    return False
2500                ms2, path2 = self.benchmark_fused_nodes(node_list_2)
2501                if math.isinf(ms2):
2502                    why("register spilling of the second kernel")
2503                    return False
2504                ms_fused, path_fused = self.benchmark_fused_nodes(node_list_fused)
2505                if math.isinf(ms_fused):
2506                    why("register spilling of the fused kernel")
2507                    return False
2508            except CompilationError as e:
2509                # workaround triton issue: https://github.com/openai/triton/issues/2151
2510                if "Loop-carried variable" in str(e):
2511                    return True  # allow fusion
2512                else:
2513                    raise
2514
2515        log_fusion(ms_fused, ms1, ms2)
2516        if (
2517            is_metric_table_enabled("slow_fusion")
2518            and ms_fused >= ms1 + ms2
2519            and (path1, path2) not in self.logged_slow_fusion
2520        ):
2521            self.logged_slow_fusion.add((path1, path2))
2522            get_metric_table("slow_fusion").add_row(
2523                lambda: {
2524                    "kernel1_path": path1,
2525                    "kernel1_latency": ms1,
2526                    "kernel2_path": path2,
2527                    "kernel2_latency": ms2,
2528                    "fused_kernel_path": path_fused,
2529                    "fused_kernel_latency": ms_fused,
2530                    "slow_down_ratio": ms_fused / (ms1 + ms2),
2531                }
2532            )
2533        return ms_fused < ms1 + ms2
2534
2535    def fuse_nodes_once(
2536        self, nodes: List[BaseSchedulerNode]
2537    ) -> List[BaseSchedulerNode]:
2538        """
2539        Combine eligible nodes into FusedSchedulerNodes.
2540
2541        This relies on two key functions to control the logic:
2542            - self.can_fuse(): checks if a fusion is legal
2543            - self.score_fusion(): assigns priority to a given fusion
2544        """
2545        fused_nodes = OrderedSet(nodes)
2546        if fusion_log.isEnabledFor(logging.DEBUG):
2547            fusion_log.debug("fuse_nodes_once, candidates:")
2548            for node in fused_nodes:
2549                fusion_log.debug("  " + node.debug_str_short())  # noqa: G003
2550        for node1, node2 in self.get_possible_fusions(nodes):
2551            node1 = self.name_to_fused_node[node1.get_first_name()]
2552            node2 = self.name_to_fused_node[node2.get_first_name()]
2553            if self.can_fuse(node1, node2) and not self.will_fusion_create_cycle(
2554                node1, node2
2555            ):
2556                if not self.speedup_by_fusion(node1, node2):
2557                    continue
2558                fusion_log.debug(
2559                    "fusing %s with %s", node1.get_name(), node2.get_name()
2560                )
2561
2562                # above can_fuse asserts that node2 has the same device
2563                device = node1.get_device()
2564                node3 = self.get_backend(device).fuse(node1, node2)
2565                fused_nodes.remove(node1)
2566                fused_nodes.remove(node2)
2567                fused_nodes.add(node3)
2568                self.name_to_fused_node.update(
2569                    {n.get_name(): node3 for n in node3.get_nodes()}
2570                )
2571        nodes = sorted(fused_nodes, key=lambda x: x.min_order)
2572        nodes = self.topological_sort_schedule(nodes)
2573        self.prune_redundant_deps(nodes)
2574        return nodes
2575
2576    def create_combo_kernel_nodes(self, num_ck_nodes: Optional[int] = None) -> None:
2577        """
2578        Group parallel (mutually independent) nodes into combo kernel nodes.
2579        """
2580        fused_nodes = set(self.nodes)
2581        count = 0
2582        num_nodes_orig = len(self.nodes)
2583        log.debug("ComboKernels: Generating with num_ck_nodes = %s...", num_ck_nodes)
2584        for num, node_list in enumerate(
2585            ForeachKernelSchedulerNode.group_nodes_for_combo_kernels(self)
2586        ):
2587            node_list = ForeachKernelSchedulerNode.combinable_nodes(node_list)
2588            if len(node_list) < 2:
2589                continue
2590            if num_ck_nodes is not None and count > num_ck_nodes:
2591                break
2592            if not self.speedup_by_combo_kernel(node_list):
2593                log.debug("ComboKernels: Not speeding up %d-th group", num)
2594                continue
2595            count += 1
2596            enable_autotune = config.combo_kernels_autotune > 0
2597            group_snode = ForeachKernelSchedulerNode(
2598                node_list[0].scheduler,
2599                node_list,
2600                use_custom_partition_algo=True,
2601                enable_autotune=enable_autotune,
2602            )
2603            log.info(
2604                "ComboKernels: Combining %d nodes for %d-th group",
2605                len(node_list),
2606                num,
2607            )
2608            for node in node_list:
2609                fused_nodes.remove(node)
2610            fused_nodes.add(group_snode)
2611            self.name_to_fused_node.update(
2612                {n.get_name(): group_snode for n in group_snode.get_nodes()}
2613            )
2614        self.nodes = sorted(fused_nodes, key=lambda x: x.min_order)
2615        self.nodes = self.topological_sort_schedule(self.nodes)
2616        log.info(
2617            "Generated ComboKernel nodes: %d ComboKernels, %d -> %d nodes in total",
2618            count,
2619            num_nodes_orig,
2620            len(self.nodes),
2621        )
2622        self.prune_redundant_deps(self.nodes)
2623
2624    def prune_redundant_deps(self, nodes: List[BaseSchedulerNode]) -> None:
2625        for node in nodes:
2626            node.prune_redundant_deps(self.name_to_fused_node)
2627
2628    def get_possible_fusions(
2629        self, nodes: List[BaseSchedulerNode]
2630    ) -> List[Tuple[BaseSchedulerNode, BaseSchedulerNode]]:
2631        """
2632        Helper to find all legal fusion opportunities, sorted by self.score_fusion()
2633        """
2634        possible_fusions = []
2635        seen: OrderedSet[Tuple[BaseSchedulerNode, BaseSchedulerNode]] = OrderedSet()
2636
2637        def check_all_pairs(nodes: List[BaseSchedulerNode]) -> None:
2638            for node1_index, node1 in enumerate(nodes):
2639                for node2 in nodes[node1_index + 1 :]:
2640                    key = (node1, node2)
2641                    if key in seen:
2642                        continue
2643                    seen.add(key)
2644
2645                    if self.can_fuse(node1, node2):
2646                        possible_fusions.append(key)
2647                    elif (node2.is_template() or node2.is_foreach()) and self.can_fuse(
2648                        node2, node1
2649                    ):
2650                        # foreach fusions and epilogue fusions are order dependent
2651                        possible_fusions.append((node2, node1))
2652
2653        buffer_names_grouping = collections.defaultdict(list)
2654        for node in nodes:
2655            for buf in node.used_buffer_names():
2656                buffer_names_grouping[buf].append(node)
2657        for node_grouping in buffer_names_grouping.values():
2658            check_all_pairs(node_grouping)
2659
2660        if config.aggressive_fusion:
2661            group_grouping = collections.defaultdict(list)
2662            for node in nodes:
2663                group = getattr(node, "group", None)
2664                if group:
2665                    group_grouping[group].append(node)
2666            for node_grouping in group_grouping.values():
2667                check_all_pairs(node_grouping)
2668
2669        possible_fusions = self.get_possible_fusions_with_highest_priority(
2670            possible_fusions
2671        )
2672        possible_fusions.sort(key=self.score_fusion_key, reverse=True)
2673        fusion_log.debug("found %d possible fusions", len(possible_fusions))
2674        return possible_fusions
2675
2676    def will_fusion_create_cycle(
2677        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
2678    ) -> bool:
2679        """
2680        Finds whether there's a path from node1 to node2 (or vice-versa)
2681        caused indirectly by other fusions.
2682        """
2683        # since we are just returning boolean here, use slightly faster, unordered set
2684        visited: Set[FusedSchedulerNode] = set()
2685
2686        def found_path(node: BaseSchedulerNode) -> bool:
2687            # only fused nodes can introduce new ancestors.
2688            if isinstance(node, FusedSchedulerNode) and node not in visited:
2689                visited.add(node)
2690                if node.get_operation_names().issubset(combined_ancestors):
2691                    # All fusion outputs are in the ancestors of node1 and node2, thus
2692                    # they cannot introduce a new path:
2693                    #
2694                    # 1. if the output is a descendant of neither node1 nor node2, the
2695                    #        output cannot introduce a path
2696                    # 2. due to [can_fuse]: if WLOG the output is a descendant of node1, it cannot be
2697                    #        on path(node1->node2), hence it cannot be an ancestor of node2
2698                    # 3. due to [acyclic]: if WLOG the output is a descendant of node1, it cannot be
2699                    #        an ancestor of node1
2700                    return False
2701                else:
2702                    # continue DFS of new ancestors introduced by the fusion
2703                    return bool(combined_names & node.ancestors) or any(
2704                        found_path(self.name_to_fused_node[n])
2705                        for n in node.ancestors - combined_ancestors
2706                    )
2707            return False
2708
2709        # as above - use slightly faster, unordered set
2710        combined_names = (
2711            node1.get_operation_names()._dict.keys()
2712            | node2.get_operation_names()._dict.keys()
2713        )
2714        combined_ancestors = (
2715            node1.ancestors._dict.keys() | node2.ancestors._dict.keys()
2716        ) - combined_names
2717        cycle = any(found_path(self.name_to_fused_node[n]) for n in combined_ancestors)
2718        if cycle:
2719            WhyNoFuse(node1, node2)("will create cycle")
2720        return cycle
2721
2722    def can_fusion_increase_peak_memory(
2723        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
2724    ) -> bool:
2725        """
2726        This function prevents fusion for nodes that can increase memory
2727        footprint. This problem is more common in horizontal fusion, where nodes
2728        that are far apart in the original order get fused, lengthening the live
2729        intervals of tensors. This is very evident in models with activation
2730        checkpointing, where the recomputed nodes from different checkpointed
2731        regions get fused and significantly increase the memory footprint.
2732
2733        The current attempt is a quick, possibly hacky, heuristic to prevent the
2734        fusion of nodes that are far away in the original order.
2735
2736        A better, but harder to implement, heuristic would be to use the live
2737        intervals of the buffers, find the region of peak pressure in the original
2738        program, and prevent fusions that cross that peak region. We might need
2739        special care or a good approximation in this implementation, as fusing
2740        nodes changes live intervals, and re-computing live intervals and peak
2741        memory after each fusion can introduce large compilation overhead.
2742        """
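        # Worked example (hypothetical order values): if node1 spans original
        # positions min_order=3, max_order=5 and node2 spans min_order=120,
        # max_order=125, then
        #     proximity_score = max(abs(3 - 125), abs(120 - 5)) = 122 > 64
        # so this returns True and the horizontal fusion is rejected.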
2743        proximity_score = max(
2744            abs(node1.min_order - node2.max_order),
2745            abs(node2.min_order - node1.max_order),
2746        )
2747        return proximity_score > 64
2748
2749    def decide_fusion_fail_reason(
2750        self,
2751        node1: BaseSchedulerNode,
2752        node2: BaseSchedulerNode,
2753        common_buf_names: Tuple[str, ...],
2754    ) -> str:
2755        """
2756        Try to decide the reason why a fusion fails due to no shared memory, even
2757        though the two nodes have common buffers.
2758        """
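        # Example of the returned string (hypothetical buffer name): if node1
        # and node2 both read "buf3" but at offsets 0 and 512 (e.g. different
        # splits of a concatenated QKV projection), the result would look like
        #     "{'buf3': 'different offset: 0 v.s. 512'}"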
2759        reasons = {}
2760        node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()}
2761        node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()}
2762
2763        for buf_name in common_buf_names:
2764            buf = V.graph.get_buffer(buf_name)
2765            lhs_dep = node1_name2dep[buf_name]
2766            rhs_dep = node2_name2dep[buf_name]
2767
2768            if lhs_dep.get_numel() != rhs_dep.get_numel():
2769                reasons[
2770                    buf_name
2771                ] = f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}"
2772                continue
2773
2774            # same numel but different MemoryDep.size. Should be broadcasting
2775            if sympy_product(lhs_dep.size) != sympy_product(rhs_dep.size):
2776                reasons[buf_name] = "broadcast"
2777                continue
2778
2779            if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep):
2780                reasons[
2781                    buf_name
2782                ] = f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}"
2783                continue
2784
2785            lhs_off = lhs_dep.get_offset()
2786            rhs_off = rhs_dep.get_offset()
2787            if lhs_off != rhs_off:
2788                # One example is in transformer, we use a concatenated linear layer
2789                # to project Q/K/V and then split the result. The 3 splits will
2790                # point to the same buffer with different offsets.
2791                reasons[buf_name] = f"different offset: {lhs_off} v.s. {rhs_off}"
2792                continue
2793
2794            if (
2795                lhs_dep.normalize_with_stride_order()
2796                == rhs_dep.normalize_with_stride_order()
2797            ):
2798                reasons[buf_name] = f"Mismatch loop orders: {lhs_dep} v.s. {rhs_dep}"
2799                continue
2800
2801            # Add more rules here
2802            reasons[
2803                buf_name
2804            ] = f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. Layout: {buf.layout}"
2805
2806        return str(reasons)
2807
2808    def has_shared_data_after_reordering_loop(
2809        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
2810    ) -> bool:
2811        """
2812        Right now we just greedily reorder the loop of node1 to be compatible with node2,
2813        but ideally we should have some heuristics to reorder the loop of node2
2814        to be compatible with node1 if that's more efficient.
2815        """
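        # Sketch of the intended case (hypothetical indexing): node1 reads a
        # buffer with index i0 * 32 + i1 while node2 reads the same buffer with
        # index i1 * 32 + i0 (same elements, different loop order).  Both deps
        # normalize to the same thing under normalize_with_stride_order(), so we
        # reorder the loops of the non-reduction node and then re-check the
        # shared-memory score.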
2816
2817        # TODO Don't do loop reordering for CPU for now.
2818        # Should debug more why it does not work for CPU codegen
2819        if not config.loop_ordering_after_fusion or any(
2820            n.get_device().type == "cpu" for n in [node1, node2]
2821        ):
2822            return False
2823
2824        node1_buffer_names = node1.read_writes.buffer_names()
2825        node2_buffer_names = node2.read_writes.buffer_names()
2826        # Fast path: no common buffers.
2827        common_buffer_names = node1_buffer_names & node2_buffer_names
2828        if not common_buffer_names:
2829            return False
2830
2831        node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()}
2832        node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()}
2833
2834        # Find the common buffers that have different loop orders
2835        candidates = []
2836        for buffer_name in common_buffer_names:
2837            lhs_dep = node1_name2dep[buffer_name]
2838            rhs_dep = node2_name2dep[buffer_name]
2839            if (
2840                lhs_dep.normalize_with_stride_order()
2841                == rhs_dep.normalize_with_stride_order()
2842            ):
2843                candidates.append(
2844                    (
2845                        V.graph.sizevars.size_hint(lhs_dep.get_numel(), fallback=0),
2846                        lhs_dep,
2847                        rhs_dep,
2848                    )
2849                )
2850
2851        if len(candidates) == 0:
2852            return False
2853
2854        # Pick the largest buffer to guide the loop reordering
2855        numel, lhs_dep, rhs_dep = sorted(candidates, reverse=True, key=lambda x: x[0])[
2856            0
2857        ]
2858
2859        if lhs_dep.num_vars != rhs_dep.num_vars:
2860            # This can happen because we don't merge loops.
2861            # We cannot do loop reordering in this case right now.
2862            # Simply return True if the two Deps are the same after
2863            # normalization (merging loops).
2864            return lhs_dep.normalize() == rhs_dep.normalize()
2865
2866        # Only reorder loops for pointwise for now
2867        if not node1.is_reduction():
2868            node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep)
2869        elif not node2.is_reduction():
2870            node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep)
2871        else:
2872            loop_ordering_log.debug(
2873                "Don't reorder loops since both nodes are reductions: %s v.s. %s",
2874                node1.get_name(),
2875                node2.get_name(),
2876            )
2877
2878        return self.score_fusion_memory(node1, node2) > 0
2879
2880    def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool:
2881        """
2882        Determine if it is possible to combine node1 and node2 into a
2883        single fused node.
2884        """
2885
2886        if node1 is node2:
2887            return False
2888
2889        why = WhyNoFuse(node1, node2)
2890
2891        if isinstance(node1, GroupedSchedulerNode) or isinstance(
2892            node2, GroupedSchedulerNode
2893        ):
2894            why("grouped node must not be fused with other nodes")
2895            return False
2896        if (
2897            isinstance(node1, (ExternKernelSchedulerNode, NopKernelSchedulerNode))
2898            and not node1.is_template()
2899        ):
2900            why("node1 is extern or nop")
2901            return False
2902        if (
2903            isinstance(node2, (ExternKernelSchedulerNode, NopKernelSchedulerNode))
2904            and not node2.is_template()
2905        ):
2906            why("node2 is extern or nop")
2907            return False
2908
2909        if node2.get_operation_names() & node1.ancestors:
2910            why("node1 must go before node2")
2911            return False
2912
2913        if node2.is_template():
2914            why("templates can only fuse epilogues")
2915            return False
2916        if node1.is_template() and (
2917            node2.has_aliasing_or_mutation()
2918            or node2.is_reduction()
2919            or not config.epilogue_fusion
2920        ):
2921            why("template epilogue not satisfied")
2922            return False
2923
2924        if (
2925            node1.get_buffer_names() | node2.get_buffer_names()
2926        ) & V.graph.no_fuse_buffer_names:
2927            why("fusion for buffer explicitly disabled")
2928            return False
2929
2930        device = node1.get_device()
2931        device2 = node2.get_device()
2932        if device != device2:
2933            why("device mismatch (%s vs %s)", device, device2)
2934            return False
2935        del device2
2936
2937        no_shared_data = self.score_fusion_memory(node1, node2) == 0
2938        if no_shared_data:
2939            no_shared_data = not self.has_shared_data_after_reordering_loop(
2940                node1, node2
2941            )
2942
2943        loop_ordering_log.debug(
2944            "%s and %s have%s shared data",
2945            node1.get_name(),
2946            node2.get_name(),
2947            " no" if no_shared_data else "",
2948        )
2949        if no_shared_data and (
2950            not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction()
2951        ):
2952            if is_metric_table_enabled("fusion_failure_due_to_indexing_mismatch"):
2953                common_buf_names = (
2954                    node1.read_writes.buffer_names() & node2.read_writes.buffer_names()
2955                )
2956                if len(common_buf_names) > 0:
2957                    get_metric_table("fusion_failure_due_to_indexing_mismatch").add_row(
2958                        lambda: {
2959                            "pre_grad_graph_id": V.graph.graph_id,
2960                            "post_grad_graph_id": V.graph.post_grad_graph_id,
2961                            "node1_name": node1.get_name(),
2962                            "node2_name": node2.get_name(),
2963                            "node1_debug_str": write_text(node1.debug_str()),
2964                            "node2_debug_str": write_text(node2.debug_str()),
2965                            "common_buffer_names": list(common_buf_names),
2966                            "failure_reason": self.decide_fusion_fail_reason(
2967                                node1, node2, common_buf_names
2968                            ),
2969                        }
2970                    )
2971
2972                    why("no shared data due to indexing mismatch")
2973                    return False
2974            why("no shared data")
2975            return False  # heuristic not needed for correctness
2976
2977        if (
2978            not node1.is_foreach()
2979            and not node2.is_foreach()
2980            and len(node1.get_nodes()) + len(node2.get_nodes()) > config.max_fusion_size
2981        ):
2982            why("exceeds max fusion")
2983            return False  # heuristic not needed for correctness
2984
2985        if node1.get_operation_names() & node2.ancestors:
2986            # node2 depends on node1 outputs
2987            if not self.can_fuse_vertical(node1, node2):
2988                return False
2989            return self.get_backend(device).can_fuse_vertical(node1, node2)
2990        else:  # nodes don't depend on each other, but may have common reads
2991            if self.can_fusion_increase_peak_memory(node1, node2):
2992                why("will increase peak memory")
2993                return False
2994            return self.get_backend(device).can_fuse_horizontal(node1, node2)
2995
2996    def can_fuse_vertical(
2997        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
2998    ) -> bool:
2999        """
3000        Check if it is legal to fuse a consumer (node2) into a producer (node1).
3001
3002        We can fuse them if all the reads of node2 either match
3003        corresponding writes in node1, or are written by nodes that can
3004        be scheduled before the fusion of node1 and node2.
3005        """
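        # Minimal example (informal dep notation): if node1 writes
        # MemoryDep("buf0", index=x, size=[s0]) and node2's only unmet
        # dependency is a read of the same MemoryDep, the read is satisfied by
        # the fusion and we can fuse.  A read at index x + 1, or a StarDep on
        # "buf0", would leave a remaining dep on node1's buffer and block the
        # fusion instead.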
3006        node1_buf_names = node1.get_buffer_names()
3007        node1_op_names = node1.get_operation_names()
3008        computed_deps: OrderedSet[Dep] = OrderedSet()
3009        why = WhyNoFuse(node1, node2)
3010
3011        for cd in node1.read_writes.writes:
3012            if not isinstance(cd, MemoryDep):
3013                continue
3014            for rd in node2.unmet_dependencies:
3015                if self.fusable_read_and_write(rd, cd):
3016                    computed_deps.add(rd)
3017
3018        for dep in node2.unmet_dependencies:
3019            if isinstance(dep, WeakDep) and self.fusable_weak_dep(dep, node1, node2):
3020                computed_deps.add(dep)
3021
3022        remaining_deps = OrderedSet(
3023            dep.name for dep in node2.unmet_dependencies - computed_deps
3024        )
3025        if remaining_deps & node1_buf_names:
3026            # MemoryDeps didn't match and read different locations of the same buffer.
3027            # Examples here include:
3028            #   - MemoryDep("foo", x) != MemoryDep("foo", x + 1)
3029            #   - MemoryDep("foo", x) != StarDep("foo")
3030            why("memory deps did not match")
3031            return False
3032        for name in remaining_deps:
3033            op_name = self.name_to_buf[name].defining_op.get_name()
3034            if node1_op_names & self.name_to_fused_node[op_name].ancestors:
3035                why("intermediate nodes between node1 & node2")
3036                return False
3037
3038        return True
3039
3040    def fusable_weak_dep(
3041        self, weak_dep: WeakDep, node1: BaseSchedulerNode, node2: BaseSchedulerNode
3042    ) -> bool:
3043        if weak_dep.name not in node1.get_buffer_names():
3044            return False
3045
3046        # A weak dep can be fused if and only if the fused operation acts inplace
3047        # on the buffer being mutated, i.e. the same index is read and then mutated.
3048        mutating_writes = [
3049            write
3050            for write in node2.read_writes.writes
3051            if write.name == weak_dep.mutating_buf
3052        ]
3053        if len(mutating_writes) != 1:
3054            return False
3055        write = mutating_writes[0]
3056        assert isinstance(write, MemoryDep)
3057
3058        if free_symbol_is_type(write.index, SymT.TMP):
3059            return False
3060
3061        real_name = self.mutation_real_name[weak_dep.mutating_buf]
3062        relevant_reads = [
3063            read for read in node1.read_writes.reads if read.name == real_name
3064        ]
3065        return all(
3066            isinstance(read, MemoryDep)
3067            and not free_symbol_is_type(read.index, SymT.TMP)
3068            and read.index == write.index
3069            and read.size == write.size
3070            for read in relevant_reads
3071        )
3072
3073    # A StarDep doesn't match a MemoryDep, and different indices don't match.
3074    # However, broadcasting sometimes strips dimensions, and if that's the case
3075    # we can still match the unmet dep.
3076    # If there's indirect indexing, don't match it.
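    # Example (informal dep notation): a write of ("buf0", index=x, size=[s0])
    # matches a read of ("buf0", index=x, size=[s0, s1]) because the read's
    # leading sizes equal the write's sizes (broadcasting stripped a trailing
    # dimension), but it does not match ("buf0", index=x + 1, size=[s0]) or any
    # dep whose index contains an indirect (tmp) symbol.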
3077    def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool:
3078        if isinstance(read, MemoryDep):
3079            if read.mode == write.mode and write.mode is not None:
3080                return True
3081            read_name = self.mutation_renames.get(read.name, read.name)
3082
3083            if (
3084                read_name != write.name
3085                or free_symbol_is_type(read.index, SymT.TMP)
3086                or free_symbol_is_type(write.index, SymT.TMP)
3087            ):
3088                return False
3089
3090            if config.loop_ordering_after_fusion and read.num_vars != write.num_vars:
3091                # Need to merge loops if we do loop ordering after fusion, since
3092                # we have not merged the loops yet when creating the scheduler
3093                # nodes.
3094                read = read.normalize()
3095                write = write.normalize()
3096
3097            return (
3098                read.index == write.index
3099                and len(read.size) >= len(write.size)
3100                and read.size[: len(write.size)] == write.size
3101            )
3102        elif isinstance(read, StarDep):
3103            read_name = self.mutation_renames.get(read.name, read.name)
3104            write_name = self.mutation_renames.get(write.name, write.name)
3105            if (
3106                read.mode == write.mode
3107                and write.mode is not None
3108                and read_name == write_name
3109            ):
3110                return True
3111        return False
3112
3113    def score_fusion(
3114        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
3115    ) -> Tuple[bool, bool, int, int]:
3116        """
3117        Assign a score (higher comes first) to the fusion of node1
3118        and node2.  When different fusions conflict with each other,
3119        this is the way we decide what order to run them in.
3120
3121        Our current score is based on:
3122        - Estimate of the saved memory operations
3123        - Fusions closer together in original order
3124        """
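        # Example (hypothetical scores, non-template nodes, default config): a
        # pointwise/pointwise pair with memory_score=1024 sorts ahead of a
        # pointwise/reduction pair with memory_score=4096, because the
        # "same reduction type" boolean is compared before the raw byte count.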
3125        memory_score = self.score_fusion_memory(node1, node2)
3126        proximity_score = -max(
3127            abs(node1.min_order - node2.max_order),
3128            abs(node2.min_order - node1.max_order),
3129        )
3130        return (
3131            node1.is_template() == config.epilogue_fusion_first and memory_score > 0,
3132            node1.is_reduction() == node2.is_reduction() and memory_score > 0,
3133            memory_score,
3134            proximity_score,
3135        )
3136
3137    def dep_size_hint(self, dep: Dep) -> int:
3138        res = 0
3139        if dep not in self.__dep_size_hint_cache:
3140            try:
3141                if not dep.has_unbacked_symbols():
3142                    res = dep.numbytes_hint()
3143            except KeyError:
3144                # In at least one test (test/inductor/test_torchbind.py) we
3145                # create a StarDep that doesn't exist in the graph and calling
3146                # `has_unbacked_symbols()` throws an error.
3147                pass
3148            self.__dep_size_hint_cache[dep] = res
3149        else:
3150            res = self.__dep_size_hint_cache[dep]
3151        return res
3152
3153    def score_fusion_memory(
3154        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
3155    ) -> int:
3156        """
3157        The first term in our fusion score that estimates number of saved
3158        memory operations.
3159        """
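        # Sketch (hypothetical deps): if node1's reads/writes touch {buf0, buf1}
        # and node2's touch {buf1, buf2}, only the dep on buf1 is common, so the
        # score is dep_size_hint() of that dep -- roughly the memory traffic the
        # fusion would avoid.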
3160        node1_dep_len = len(node1.read_writes.reads) + len(node1.read_writes.writes)
3161        node2_dep_len = len(node2.read_writes.reads) + len(node2.read_writes.writes)
3162
3163        # optimization: iterate over the smaller set when the sizes are lopsided
3164        if max(node1_dep_len, node2_dep_len) > min(node1_dep_len, node2_dep_len) * 4:
3165            if node1_dep_len > node2_dep_len:
3166                tmp = node1
3167                node1 = node2
3168                node2 = tmp
3169
3170            deps = []
3171            for dep in node1.read_writes.reads | node1.read_writes.writes:
3172                if dep in node2.read_writes.reads or dep in node2.read_writes.writes:
3173                    deps.append(dep)
3174
3175            return sum(self.dep_size_hint(dep) for dep in deps)
3176
3177        common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & (
3178            node2.read_writes.reads | node2.read_writes.writes
3179        )
3180        return sum(self.dep_size_hint(dep) for dep in common_memory_deps)
3181
3182    def get_possible_fusions_with_highest_priority(
3183        self, possible_fusions: List[Tuple[BaseSchedulerNode, BaseSchedulerNode]]
3184    ) -> List[Tuple[BaseSchedulerNode, BaseSchedulerNode]]:
3185        # Group the possible fusions based on their priority from the backend.
3186        # Only return the group of possible fusions with highest priority.
3187        if len(possible_fusions) == 0:
3188            return possible_fusions
3189        possible_fusions_group_by_priority: Dict[
3190            int, List[Tuple[BaseSchedulerNode, BaseSchedulerNode]]
3191        ] = {}
3192
3193        for node1, node2 in possible_fusions:
3194            assert node1.get_device() == node2.get_device()
3195            device = node1.get_device()
3196            fusion_pair_priority = int(
3197                self.get_backend(device).get_fusion_pair_priority(node1, node2)
3198            )
3199            if fusion_pair_priority not in possible_fusions_group_by_priority:
3200                possible_fusions_group_by_priority[fusion_pair_priority] = [
3201                    (node1, node2),
3202                ]
3203            else:
3204                possible_fusions_group_by_priority[fusion_pair_priority].append(
3205                    (node1, node2)
3206                )
3207        # return the possible fusions with highest priority
3208        possible_fusions_with_highest_priority = min(
3209            possible_fusions_group_by_priority.items(), key=operator.itemgetter(0)
3210        )[1]
3211        assert len(possible_fusions_with_highest_priority) > 0
3212        return possible_fusions_with_highest_priority
3213
3214    def score_fusion_key(
3215        self, nodes: Tuple[BaseSchedulerNode, BaseSchedulerNode]
3216    ) -> Tuple[bool, bool, int, int]:
3217        """
3218        Shim for list.sort(key=...)
3219        """
3220        node1, node2 = nodes
3221        return self.score_fusion(node1, node2)
3222
3223    def compute_last_usage(self) -> None:
3224        """
3225        Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode)
3226        """
3227
3228        future_used_buffers: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
3229
3230        for node in reversed(self.nodes):
3231            node.set_last_usage(future_used_buffers, self.mutation_real_name)
3232            future_used_buffers.update(node.last_usage)
3233
3234    def free_buffers(self) -> None:
3235        """Free any buffers that are no longer needed"""
3236        for name in sorted(
3237            self.buffer_names_to_free
3238            - V.graph.removed_buffers
3239            - V.graph.wrapper_code.freed
3240        ):
3241            if name in self.name_to_buf:
3242                buf = self.name_to_buf[name]
3243                if buf.can_free():
3244                    V.graph.wrapper_code.codegen_free(buf.node)
3245            elif name in V.graph.graph_inputs:
3246                storage = V.graph.graph_inputs[name].data
3247                assert isinstance(storage, ir.StorageBox) and storage.is_input_buffer()
3248                V.graph.wrapper_code.codegen_free(storage.data)
3249
3250        self.buffer_names_to_free.clear()
3251
3252    def remove_kernel_local_buffers(self) -> None:
3253        """
3254        Any buffers that are both created and have a last use in the
3255        same kernel can be removed.
3256        """
3257
3258        fused_node_names = OrderedSet(
3259            self.name_to_buf[buf].defining_op.get_name()
3260            for buf in V.kernel.store_buffer_names
3261            if buf in self.name_to_buf
3262        )
3263        names_to_remove = []
3264        for out_buf in V.kernel.store_buffer_names:
3265            if out_buf not in self.name_to_buf:
3266                # Aux buffers created during kernel codegen
3267                names_to_remove.append(out_buf)
3268                continue
3269            users = self.name_to_buf[out_buf].users
3270            assert users is not None
3271            users = OrderedSet(user.get_name() for user in users if not user.is_weak)
3272            if users.issubset(fused_node_names):
3273                names_to_remove.append(out_buf)
3274
3275        def remove_filter(n: str) -> bool:
3276            return (
3277                n not in V.kernel.must_keep_buffers
3278                and n not in V.kernel.args.input_buffers
3279                and n not in self.mutation_renames
3280                and n not in self.mutation_real_name
3281            )
3282
3283        names_to_remove = list(filter(remove_filter, names_to_remove))
3284
3285        for name in names_to_remove:
3286            if name in V.kernel.args.inplace_buffers:
3287                buf = V.kernel.args.inplace_buffers[name]
3288                if isinstance(buf, str) and buf.startswith("REMOVED"):
3289                    continue
3290                remove = all(n in names_to_remove for n in buf.other_names)
3291                if remove:
3292                    self.remove_inplace_buffer(name)
3293                V.kernel.inplaced_to_remove.add(name)
3294            else:
3295                self.remove_buffer(name)
3296
3297    def remove_buffer(self, name: str) -> None:
3298        # Assign a special value instead of deleting the entry,
3299        # because we still rely on the length of output_buffers to
3300        # generate unique arg names.
3301        log.debug("remove_buffer(%r)", name)
3302        V.kernel.args.output_buffers[name] = "REMOVED"
3303        V.kernel.removed_buffers.add(name)
3304
3305    def remove_inplace_buffer(self, name: str) -> None:
3306        log.debug("removing_inplace_buffer(%r)", name)
3307        inner_name = V.kernel.args.inplace_buffers[name].inner_name
3308        V.kernel.args.inplace_buffers[name] = inner_name.replace(
3309            "in_out_ptr", "REMOVED"
3310        )
3311        V.kernel.removed_buffers.add(name)
3312
3313    def flush(self) -> None:
3314        for backend in self.backends.values():
3315            backend.flush()
3316        self.free_buffers()
3317
3318    def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode) -> None:
3319        assert isinstance(scheduler_node, ExternKernelSchedulerNode)
3320        # 'decide_inplace_update' stores the inplace update decisions in
3321        # the current kernel, from where 'allocate' retrieves those decisions.
3322        # We have to make sure there is a non-NULL kernel handler to store
3323        # those inplace update decisions.
3324        counters["inductor"]["extern_calls"] += 1
3325        with V.set_kernel_handler(Kernel(increase_kernel_count=False)):
3326            scheduler_node.decide_inplace_update()
3327            scheduler_node.mark_run()
3328        node = scheduler_node.node
3329        assert isinstance(node, ir.ExternKernel), f"{type(node)=}"
3330        node.codegen(V.graph.wrapper_code)
3331        self.free_buffers()
3332
3333    def create_backend(self, device: torch.device) -> BaseScheduling:
3334        assert (
3335            not is_gpu(device.type) or device.index is not None
3336        ), f"{device} should have been normalized in lowering"
3337        V.graph.add_device_info(device)
3338
3339        device_scheduling = get_scheduling_for_device(device.type)
3340        if device_scheduling is None:
3341            raise RuntimeError(f"Unsupported device type: {device.type}")
3342
3343        if not has_triton():
3344            if (
3345                device.type == "cuda"
3346                and (device_props := torch.cuda.get_device_properties(device)).major < 7
3347            ):
3348                raise RuntimeError(
3349                    f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 7.0, but your device is of CUDA capability {device_props.major}.{device_props.minor}"  # noqa: B950
3350                )
3351            elif is_gpu(device.type):
3352                raise RuntimeError(
3353                    "Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at https://github.com/openai/triton"  # noqa: B950
3354                )
3355
3356        return device_scheduling(self)
3357
3358    def get_backend(self, device: torch.device) -> BaseScheduling:
3359        if device not in self.backends:
3360            self.backends[device] = self.create_backend(device)
3361        return self.backends[device]
3362
3363    def enter_context(self, node: BaseSchedulerNode) -> None:
3364        def get_order(n: torch.fx.Node) -> int:
3365            if n not in self.origin_to_index:
3366                self.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)})
3367            return self.origin_to_index[n]
3368
3369        # Use a dict to have ordering
3370        origins = {
3371            (get_order(e), e): None
3372            for n in node.get_nodes()
3373            if n.node is not None
3374            for e in n.node.get_origins()
3375        }
3376        origins = list(origins.keys())
3377        if origins:
3378            _, last = max(origins, key=operator.itemgetter(0))
3379            V.graph.wrapper_code.enter_context(last)
3380
3381    def codegen(self) -> None:
3382        with dynamo_timed("Scheduler.codegen"):
3383            return self._codegen()
3384
3385    def _codegen(self) -> None:
3386        if config.check_stack_no_cycles_TESTING_ONLY:
3387            import torch._dynamo.convert_frame
3388
3389            stack = traceback.extract_stack()
3390            seen = set()
3391            for frame in reversed(stack):
3392                # This is where maybe_cprofile is
3393                if (
3394                    frame.name == "_compile_inner"
3395                    and frame.filename == torch._dynamo.convert_frame.__file__
3396                ):
3397                    break
3398                key = (frame.filename, frame.lineno)
3399                assert key not in seen, (
3400                    f"Duplicate stack frame {frame.filename}:{frame.lineno}; "
3401                    "did you add a decorator to one of the functions in this stack "
3402                    "trace?  If so, try using a context manager instead."
3403                )
3404                seen.add(key)
3405
3406        for node in self.nodes:
3407            try:
3408                log.debug(
3409                    "Generating code for node %s with estimated runtime %f",
3410                    node.get_name(),
3411                    node.get_estimated_runtime(),
3412                )
3413            except Exception as e:
3414                log.debug(
3415                    "Generating code for node %s with estimated runtime 0.0",
3416                    node.get_name(),
3417                )
3418
3419            self.enter_context(node)
3420
3421            if not isinstance(node, NopKernelSchedulerNode) and (
3422                device := node.get_device()
3423            ):
3424                if (
3425                    device != self.current_device
3426                    or node.is_extern()
3427                    or node.is_template()
3428                ):
3429                    self.flush()
3430                if device != self.current_device:
3431                    if self.current_device and device_need_guard(
3432                        self.current_device.type
3433                    ):
3434                        V.graph.wrapper_code.codegen_device_guard_exit()
3435                    if device_need_guard(device.type):
3436                        assert device.index is not None, "device should have an index"
3437                        V.graph.wrapper_code.codegen_device_guard_enter(device.index)
3438
3439                    self.current_device = device
3440
3441            self.buffer_names_to_free.update(node.last_usage)
3442
3443            if node.is_template():
3444                node, *epilogue = node.get_nodes()
3445                self.get_backend(device).codegen_template(node, epilogue)
3446            elif node.is_extern():
3447                node = typing.cast(ExternKernelSchedulerNode, node)
3448                self.codegen_extern_call(node)
3449            elif node.is_foreach():
3450                node = typing.cast(ForeachKernelSchedulerNode, node)
3451                backend_ = self.get_backend(device)
3452                from .codegen.cuda_combined_scheduling import CUDACombinedScheduling
3453                from .codegen.simd import SIMDScheduling
3454
3455                if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)):
3456                    backend = backend_
3457                else:
3458                    raise AssertionError(f"{type(backend_)=}")
3459                backend.codegen_combo_kernel(node)
3460            elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
3461                self.get_backend(device).codegen_node(node)
3462            else:
3463                assert isinstance(node, NopKernelSchedulerNode)
3464                node.mark_run()
3465
3466            if config.triton.debug_sync_kernel:
3467                self.get_backend(device).codegen_sync()
3468
3469            self.available_buffer_names.update(node.get_buffer_names())
3470            self.completed_operations.update(node.get_operation_names())
3471
3472            if not isinstance(node, NopKernelSchedulerNode):
3473                device = node.get_device()
3474                if device is not None and self.get_backend(device).ready_to_flush():
3475                    self.flush()
3476
3477        if self.current_device and device_need_guard(self.current_device.type):
3478            # exit the outermost CUDA device guard. this is
3479            # important for nested indentation codegen-ing.
3480            V.graph.wrapper_code.codegen_device_guard_exit()
3481
3482        self.flush()
3483
3484    def benchmark_combo_kernel(
3485        self, node_list: Sequence[BaseSchedulerNode]
3486    ) -> Tuple[float, float, str]:
3487        """
3488        Benchmark fused list of nodes and return the execution time
3489        in milliseconds on randomly generated inputs.
3490        """
3491        device = node_list[0].get_device()
3492        V.graph.scheduler = self
3493        self.current_device = device
3494        backend = self.get_backend(device)
3495        return backend.benchmark_combo_kernel(node_list)
3496
3497    def speedup_by_combo_kernel(self, nodes: List[BaseSchedulerNode]) -> bool:
3498        """
3499        If config.benchmark_combo_kernel is False, always return True.
3500        Otherwise, return True only if the combo kernel brings a speedup.
3501        """
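        # Decision sketch (hypothetical timings): if the separate subkernels
        # benchmark at ms1 = 0.9 ms total and the combo kernel at ms2 = 1.0 ms
        # with ms2_clone = 0.3 ms spent cloning inputs, then
        # ms2 - ms2_clone = 0.7 < 0.9 == ms1, so the combo kernel is accepted.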
3502        if not config.benchmark_combo_kernel:
3503            return True
3504
3505        subkernel_nodes = nodes
3506        device = subkernel_nodes[0].get_device()
3507
3508        # don't support benchmark fusion for CPU right now.
3509        if device.type == "cpu":
3510            return True
3511
3512        from triton.compiler.errors import CompilationError
3513
3514        ms1, path1_list = 0.0, []
3515        for i, snode in enumerate(subkernel_nodes):
3516            node_list = snode.get_nodes()
3517            # We can not accurately benchmark kernel using atomic_add
3518            # due to how we generate random integer inputs.
3519            if self._any_atomic_add(node_list):
3520                fusion_log.debug(
3521                    "ComboKernel: benchmarking may not be accurate due to atomic_add"
3522                )
3523
3524            try:
3525                ms, path = self.benchmark_fused_nodes(node_list)
3526                if math.isinf(ms):
3527                    fusion_log.debug(
3528                        "ComboKernel benchmark: register spilling of %d-th subkernel",
3529                        i,
3530                    )
3531                    return False
3532            except CompilationError as e:
3533                # workaround triton issue: https://github.com/openai/triton/issues/2151
3534                if "Loop-carried variable" in str(e):
3535                    fusion_log.debug(
3536                        "ComboKernel benchmark: return True because of loop-carried variable"
3537                    )
3538                    return True  # allow fusion
3539                else:
3540                    raise
3541            ms1 += ms
3542            path1_list.append(path)
3543
3544        try:
3545            ms2, ms2_clone, path2_list = self.benchmark_combo_kernel(subkernel_nodes)
3546        except CompilationError as e:
3547            # workaround triton issue: https://github.com/openai/triton/issues/2151
3548            if "Loop-carried variable" in str(e):
3549                fusion_log.debug(
3550                    "ComboKernel benchmark: return True because of loop-carried variable"
3551                )
3552                return True  # allow fusion
3553            else:
3554                raise
3555
3556        # small kernels are very likely to have a speedup but are hard to benchmark, so treat them as fusible regardless of the measurement.
3557        small_kernel = ms2 - ms2_clone < 0.3 or ms1 < 0.3
3558        if fusion_log.isEnabledFor(logging.DEBUG):
3559            if ms1 > ms2 or small_kernel:
3560                fusion_log.debug(
3561                    "can fuse (benchmark): fusing causes %sx speedup",
3562                    green_text(f"{ms1 / ms2:.3f}"),
3563                )
3564            else:
3565                fusion_log.debug(
3566                    "cannot fuse (benchmark): fusing causes %sx slowdown",
3567                    red_text(f"{ms1 / ms2:.3f}"),
3568                )
3569        # ms1 returned by benchmark_fused_nodes already has the clone time discounted
3570        return ms2 - ms2_clone < ms1 or small_kernel
3571
3572    def get_buffer_layout(self, buf_name: str) -> ir.Layout:
3573        buf = self.name_to_buf[buf_name]
3574        assert buf.node is not None
3575        return buf.node.get_layout()
3576
3577    def update_zero_dim_cpu_tensor(self) -> None:
3578        for node in self.nodes:
3579            if node.get_device() and is_gpu(node.get_device().type):
3580                for read in node.read_writes.reads:
3581                    buffer = V.graph.name_to_buffer.get(read.name)
3582                    if (
3583                        buffer
3584                        and buffer.get_device()
3585                        and buffer.get_device().type == "cpu"
3586                        and not isinstance(buffer.layout, MultiOutputLayout)
3587                        and buffer.get_size() == []
3588                    ):
3589                        V.graph.zero_dim_cpu_tensor_list.add(read.name)
3590
3591
3592class BaseScheduling:
3593    @classmethod
3594    def get_backend_features(cls, device: torch.device) -> Sequence[BackendFeature]:
3595        """Return a set of .codegen.common.BackendFeature()"""
3596        return ()
3597
3598    def can_fuse_vertical(
3599        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
3600    ) -> bool:
3601        """
3602        Check whether node1 and node2 can be vertically fused or not.
3603        """
3604        raise NotImplementedError
3605
3606    def can_fuse_horizontal(
3607        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
3608    ) -> bool:
3609        """
3610        Check whether node1 and node2 can be horizontally fused or not.
3611        """
3612        raise NotImplementedError
3613
3614    def fuse(
3615        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
3616    ) -> FusedSchedulerNode:
3617        """
3618        Fuse two nodes
3619        """
3620        if node1.is_foreach() or node2.is_foreach():
3621            return ForeachKernelSchedulerNode.fuse(node1, node2)
3622        else:
3623            return FusedSchedulerNode.fuse(node1, node2)
3624
3625    def group_fn(
3626        self, sizes: Sequence[Sequence[sympy.Expr]]
3627    ) -> Tuple[Tuple[sympy.Expr, ...], ...]:
3628        """
3629        Process the iteration sizes in case a transformation needs to be applied.
3630        """
3631        raise NotImplementedError
3632
3633    def codegen_template(
3634        self,
3635        template_node: BaseSchedulerNode,
3636        epilogue_nodes: Sequence[BaseSchedulerNode],
3637    ) -> Optional[str]:
3638        """
3639        Given a template node, generate a kernel.
3640
3641        This function is only available for triton now. If the third-party backend behaves as a sub-class
3642        This function is currently only available for Triton. If a third-party backend is implemented
3643        as a subclass of TritonScheduling, it can override or reuse this method.
3644        raise NotImplementedError
3645
3646    def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]) -> None:
3647        """
3648        Generate a kernel given a list of pre-fused nodes.
3649        """
3650        raise NotImplementedError
3651
3652    def codegen_sync(self) -> None:
3653        """
3654        Generate synchronization code for the kernel. This method depends on the hardware characteristics.
3655        """
3656        raise NotImplementedError
3657
3658    def ready_to_flush(self) -> bool:
3659        """
3660        Check whether the backend is requesting the scheduler to flush the generated kernel.
3661        If not supported, please return False.
3662        """
3663        return False
3664
3665    def flush(self) -> None:
3666        """
3667        Flush the generated kernel and python wrapper code to the source code file.
3668        """
3669        raise NotImplementedError
3670
3671    def benchmark_fused_nodes(
3672        self, nodes: Sequence[BaseSchedulerNode]
3673    ) -> Tuple[float, str]:
3674        """
3675        Benchmark fused list of nodes and return the execution time
3676        in milliseconds on randomly generated inputs.
3677        """
3678        raise NotImplementedError
3679
3680    def get_fusion_pair_priority(
3681        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
3682    ) -> int:
3683        """
3684        Return an unsigned integer which represents the priority of this fusion pair.
3685        Smaller values have higher priority.
3686        """
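        # A backend can override this to bias which candidate fusions are tried
        # first; a hypothetical sketch:
        #
        #     class MyScheduling(BaseScheduling):
        #         def get_fusion_pair_priority(self, node1, node2):
        #             # prefer template/epilogue fusions over everything else
        #             return 0 if node1.is_template() else 1
        #
        # get_possible_fusions_with_highest_priority() keeps only the group with
        # the smallest returned value.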
3687        return 0
3688
3689    def benchmark_combo_kernel(
3690        self, node_list: Sequence[BaseSchedulerNode]
3691    ) -> Tuple[float, float, str]:
3692        """
3693        Benchmark the list of nodes to combine and return the execution time
3694        and memory copy time in milliseconds on randomly generated inputs.
3695        """
3696        raise NotImplementedError
3697
3698
3699def debug_triton_code(node: Union[SchedulerNode, FusedSchedulerNode]) -> List[str]:
3700    lines = []
3701    multi_template = node.get_template_node()
3702    assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer)
3703    if multi_template and multi_template.make_kernel_render is None:
3704        lines.append(f"{node.get_name()} Unfinalized multi template buffer")
3705    else:
3706        from torch._inductor.codegen.cuda_combined_scheduling import (
3707            CUDACombinedScheduling,
3708        )
3709
3710        from .codegen.simd import SIMDScheduling
3711
3712        snodes = (node,) if isinstance(node, SchedulerNode) else node.snodes
3713        device = snodes[0].get_device()
3714        backend = node.scheduler.get_backend(device)
3715        assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling))
3716        V.graph.scheduler.current_device = device
3717
3718        # Don't increment kernel count when generating debug string.
3719        # This will confuse some unit tests that check the number of
3720        # generated kernels.
3721        old_generated_kernel_count = metrics.generated_kernel_count
3722        triton_code = backend.generate_kernel_code_from_nodes(snodes).strip()
3723        metrics.generated_kernel_count = old_generated_kernel_count
3724
3725        lines.append(f"{node.get_name()} Triton code:")
3726        lines.append(textwrap.indent(triton_code, "    "))
3727    return lines
3728